List comprehension
authorRobert Bradshaw <robertwb@math.washington.edu>
Wed, 10 Jan 2007 09:06:30 +0000 (01:06 -0800)
committerRobert Bradshaw <robertwb@math.washington.edu>
Wed, 10 Jan 2007 09:06:30 +0000 (01:06 -0800)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Nodes.py
Cython/Compiler/Parsing.py

index 80ac7abfe8331519e50a1f206188521be2c6fb59..1569f5e0356a6f9eac10d21b22fd716296d21793 100644 (file)
@@ -916,10 +916,23 @@ class IteratorNode(ExprNode):
         self.sequence = self.sequence.coerce_to_pyobject(env)
         self.type = py_object_type
         self.is_temp = 1
+        
+        self.counter = TempNode(self.pos, PyrexTypes.c_py_ssize_t_type, env)
+        self.counter.allocate_temp(env)
+        
+    def release_temp(self, env):
+        env.release_temp(self.result_code)
+        self.counter.release_temp(env)
     
     def generate_result_code(self, code):
         code.putln(
-            "%s = PyObject_GetIter(%s); if (!%s) %s" % (
+            "if (PyList_CheckExact(%s)) { %s = 0; %s = %s; Py_INCREF(%s); }" % (
+                self.sequence.py_result(),
+                self.counter.result_code,
+                self.result_code,
+                self.sequence.py_result(),
+                self.result_code))
+        code.putln("else { %s = PyObject_GetIter(%s); if (!%s) %s }" % (
                 self.result_code,
                 self.sequence.py_result(),
                 self.result_code,
@@ -941,6 +954,16 @@ class NextNode(AtomicExprNode):
         self.is_temp = 1
     
     def generate_result_code(self, code):
+        code.putln(
+            "if (PyList_CheckExact(%s)) { if (%s >= PyList_GET_SIZE(%s)) break; %s = PyList_GET_ITEM(%s, %s++); Py_INCREF(%s); }" % (
+                self.iterator.py_result(),
+                self.iterator.counter.result_code,
+                self.iterator.py_result(),
+                self.result_code,
+                self.iterator.py_result(),
+                self.iterator.counter.result_code,
+                self.result_code))
+        code.putln("else {")
         code.putln(
             "%s = PyIter_Next(%s);" % (
                 self.result_code,
@@ -951,10 +974,9 @@ class NextNode(AtomicExprNode):
         code.putln(
                 "if (PyErr_Occurred()) %s" %
                     code.error_goto(self.pos))
-        code.putln(
-                "break;")
-        code.putln(
-            "}")
+        code.putln("break;")
+        code.putln("}")
+        code.putln("}")
         
 
 class ExcValueNode(AtomicExprNode):
@@ -1832,6 +1854,51 @@ class ListNode(SequenceNode):
         for arg in self.args:
             arg.generate_post_assignment_code(code)            
 
+            
+class ListComprehensionNode(SequenceNode):
+
+    subexprs = []
+
+    def analyse_types(self, env): 
+        self.type = py_object_type
+        self.is_temp = 1
+        self.append.target = self # this is a CloneNode used in the PyList_Append in the inner loop
+        
+    def allocate_temps(self, env, result = None): 
+        if debug_temp_alloc:
+            print self, "Allocating temps"
+        self.allocate_temp(env, result)
+        self.loop.analyse_declarations(env)
+        self.loop.analyse_expressions(env)
+        
+    def generate_operation_code(self, code):
+        code.putln("%s = PyList_New(%s); if (!%s) %s" %
+            (self.result_code,
+            0,
+            self.result_code,
+            code.error_goto(self.pos)))
+        self.loop.generate_execution_code(code)
+
+
+class ListComprehensionAppendNode(ExprNode):
+
+    subexprs = ['expr']
+    
+    def analyse_types(self, env):
+        self.expr.analyse_types(env)
+        if self.expr.type != py_object_type:
+            self.expr = self.expr.coerce_to_pyobject(env)
+        self.type = PyrexTypes.c_int_type
+        self.is_temp = 1
+    
+    def generate_result_code(self, code):
+        code.putln("%s = PyList_Append(%s, %s); if (%s) %s" %
+            (self.result_code,
+            self.target.result_code,
+            self.expr.result_code,
+            self.result_code, 
+            code.error_goto(self.pos)))
+
 
 class DictNode(ExprNode):
     #  Dictionary constructor.
@@ -2979,6 +3046,11 @@ class CloneNode(CoercionNode):
     
     def calculate_result_code(self):
         return self.arg.result_code
+        
+    def analyse_types(self, env):
+        self.type = self.arg.type
+        self.result_ctype = self.arg.result_ctype
+        self.is_temp = 1
     
     #def result_as_extension_type(self):
     #  return self.arg.result_as_extension_type()
index 54f6b934a9b6575131ec06e949d0c4fcbd3b2d1b..b45909f9b5943770b0b5b20b75a7705e405a23c2 100644 (file)
@@ -3065,6 +3065,12 @@ class WhileStatNode(StatNode):
         code.put_label(break_label)
 
 
+def ForStatNode(pos, **kw):
+    if kw.has_key('iterator'):
+        return ForInStatNode(pos, **kw)
+    else:
+        return ForFromStatNode(pos, **kw)
+
 class ForInStatNode(StatNode):
     #  for statement
     #
index 66154419057bc5c4ef04ead429587f6149123237..3bbac250bf14f1f861f10fc3f8fd8dbfeeaf65fa 100644 (file)
@@ -573,14 +573,65 @@ def unquote(s):
         s = "".join(l2)
     return s
         
+# list_display         ::=     "[" [listmaker] "]"
+# listmaker    ::=     expression ( list_for | ( "," expression )* [","] )
+# list_iter    ::=     list_for | list_if
+# list_for     ::=     "for" expression_list "in" testlist [list_iter]
+# list_if      ::=     "if" test [list_iter]
+        
 def p_list_maker(s):
     # s.sy == '['
     pos = s.position()
     s.next()
-    exprs = p_simple_expr_list(s)
-    s.expect(']')
-    return ExprNodes.ListNode(pos, args = exprs)
+    if s.sy == ']':
+        s.expect(']')
+        return ExprNodes.ListNode(pos, args = [])
+    expr = p_simple_expr(s)
+    if s.sy == 'for':
+        loop = p_list_for(s)
+        s.expect(']')
+        inner_loop = loop
+        while not isinstance(inner_loop.body, Nodes.PassStatNode):
+            inner_loop = inner_loop.body
+            if isinstance(inner_loop, Nodes.IfStatNode):
+                 inner_loop = inner_loop.if_clauses[0]
+        append = ExprNodes.ListComprehensionAppendNode( pos, expr = expr )
+        inner_loop.body = Nodes.ExprStatNode(pos, expr = append)
+        return ExprNodes.ListComprehensionNode(pos, loop = loop, append = append)
+    else:
+        exprs = [expr]
+        if s.sy == ',':
+            s.next()
+            exprs += p_simple_expr_list(s)
+        s.expect(']')
+        return ExprNodes.ListNode(pos, args = exprs)
+        
+def p_list_iter(s):
+    if s.sy == 'for':
+        return p_list_for(s)
+    elif s.sy == 'if':
+        return p_list_if(s)
+    else:
+        return Nodes.PassStatNode(s.position())
 
+def p_list_for(s):
+    # s.sy == 'for'
+    pos = s.position()
+    s.next()
+    kw = p_for_bounds(s)
+    kw['else_clause'] = None
+    kw['body'] = p_list_iter(s)
+    return Nodes.ForStatNode(pos, **kw)
+        
+def p_list_if(s):
+    # s.sy == 'if'
+    pos = s.position()
+    s.next()
+    test = p_simple_expr(s)
+    return Nodes.IfStatNode(pos, 
+        if_clauses = [Nodes.IfClauseNode(pos, condition = test, body = p_list_iter(s))],
+        else_clause = None )
+    
 #dictmaker: test ':' test (',' test ':' test)* [',']
 
 def p_dict_maker(s):
@@ -931,17 +982,17 @@ def p_for_statement(s):
     # s.sy == 'for'
     pos = s.position()
     s.next()
+    kw = p_for_bounds(s)
+    kw['body'] = p_suite(s)
+    kw['else_clause'] = p_else_clause(s)
+    return Nodes.ForStatNode(pos, **kw)
+            
+def p_for_bounds(s):
     target = p_for_target(s)
     if s.sy == 'in':
         s.next()
         iterator = p_for_iterator(s)
-        body = p_suite(s)
-        else_clause = p_else_clause(s)
-        return Nodes.ForInStatNode(pos, 
-            target = target,
-            iterator = iterator,
-            body = body,
-            else_clause = else_clause)
+        return { 'target': target, 'iterator': iterator }
     elif s.sy == 'from':
         s.next()
         bound1 = p_bit_expr(s)
@@ -960,16 +1011,11 @@ def p_for_statement(s):
         if rel1[0] <> rel2[0]:
             error(rel2_pos,
                 "Relation directions in for-from do not match")
-        body = p_suite(s)
-        else_clause = p_else_clause(s)
-        return Nodes.ForFromStatNode(pos,
-            target = target,
-            bound1 = bound1,
-            relation1 = rel1,
-            relation2 = rel2,
-            bound2 = bound2,
-            body = body,
-            else_clause = else_clause)
+        return {'target': target, 
+                'bound1': bound1, 
+                'relation1': rel1, 
+                'relation2': rel2,
+                'bound2': bound2 }
 
 def p_for_from_relation(s):
     if s.sy in inequality_relations: