generate expected code when for-looping over type-declared list or tuple
authorStefan Behnel <scoder@users.berlios.de>
Wed, 19 Nov 2008 07:00:52 +0000 (08:00 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Wed, 19 Nov 2008 07:00:52 +0000 (08:00 +0100)
Cython/Compiler/ExprNodes.py

index 64fe891520cdc6041e3332842e0ee7e52bd641d6..92c39544d0ef8d01015f4dbf099d13ba2e272a63 100644 (file)
@@ -1340,10 +1340,16 @@ class IteratorNode(ExprNode):
         self.counter.release_temp(env)
     
     def generate_result_code(self, code):
-        code.putln(
-            "if (PyList_CheckExact(%s) || PyTuple_CheckExact(%s)) {" % (
-                self.sequence.py_result(),
-                self.sequence.py_result()))
+        is_builtin_sequence = self.sequence.type is list_type or \
+            self.sequence.type is tuple_type
+        if is_builtin_sequence:
+            code.putln(
+                "if (likely(%s != Py_None)) {" % self.sequence.py_result())
+        else:
+            code.putln(
+                "if (PyList_CheckExact(%s) || PyTuple_CheckExact(%s)) {" % (
+                    self.sequence.py_result(),
+                    self.sequence.py_result()))
         code.putln(
             "%s = 0; %s = %s; Py_INCREF(%s);" % (
                 self.counter.result(),
@@ -1351,11 +1357,16 @@ class IteratorNode(ExprNode):
                 self.sequence.py_result(),
                 self.result()))
         code.putln("} else {")
-        code.putln("%s = -1; %s = PyObject_GetIter(%s); %s" % (
-                self.counter.result(),
-                self.result(),
-                self.sequence.py_result(),
-                code.error_goto_if_null(self.result(), self.pos)))
+        if is_builtin_sequence:
+            code.putln(
+                'PyErr_SetString(PyExc_TypeError, "\'NoneType\' object is not iterable"); %s' %
+                code.error_goto(self.pos))
+        else:
+            code.putln("%s = -1; %s = PyObject_GetIter(%s); %s" % (
+                    self.counter.result(),
+                    self.result(),
+                    self.sequence.py_result(),
+                    code.error_goto_if_null(self.result(), self.pos)))
         code.putln("}")
 
 
@@ -1374,23 +1385,35 @@ class NextNode(AtomicExprNode):
         self.is_temp = 1
     
     def generate_result_code(self, code):
-        for py_type in ["List", "Tuple"]:
-            code.putln(
-                "if (likely(Py%s_CheckExact(%s))) {" % (py_type, self.iterator.py_result()))
+        if self.iterator.sequence.type is list_type:
+            type_checks = [(list_type, "List")]
+        elif self.iterator.sequence.type is tuple_type:
+            type_checks = [(tuple_type, "Tuple")]
+        else:
+            type_checks = [(list_type, "List"), (tuple_type, "Tuple")]
+
+        for py_type, prefix in type_checks:
+            if len(type_checks) > 1:
+                code.putln(
+                    "if (likely(Py%s_CheckExact(%s))) {" % (
+                        prefix, self.iterator.py_result()))
             code.putln(
                 "if (%s >= Py%s_GET_SIZE(%s)) break;" % (
                     self.iterator.counter.result(),
-                    py_type,
+                    prefix,
                     self.iterator.py_result()))
             code.putln(
                 "%s = Py%s_GET_ITEM(%s, %s); Py_INCREF(%s); %s++;" % (
                     self.result(),
-                    py_type,
+                    prefix,
                     self.iterator.py_result(),
                     self.iterator.counter.result(),
                     self.result(),
                     self.iterator.counter.result()))
-            code.put("} else ")
+            if len(type_checks) > 1:
+                code.put("} else ")
+        if len(type_checks) == 1:
+            return
         code.putln("{")
         code.putln(
             "%s = PyIter_Next(%s);" % (