Conditional expressions
authorRobert Bradshaw <robertwb@math.washington.edu>
Fri, 23 Feb 2007 04:52:31 +0000 (20:52 -0800)
committerRobert Bradshaw <robertwb@math.washington.edu>
Fri, 23 Feb 2007 04:52:31 +0000 (20:52 -0800)
Changes in grammar required change for this, see http://www.python.org/dev/peps/pep-0308/
Most noteably for list comprehensions (dissambiguate the if)

Cython/Compiler/ExprNodes.py
Cython/Compiler/Parsing.py

index 4681089932980ef8487c2e8397aaeab0cefc237a..cdf8ff562ff664c8921708e84705d8a202b20e09 100644 (file)
@@ -2695,6 +2695,88 @@ class BoolBinopNode(ExprNode):
         return test_result
 
 
+class CondExprNode(ExprNode):
+    #  Short-circuiting conditional expression.
+    #
+    #  test        ExprNode
+    #  true_val    ExprNode
+    #  false_val   ExprNode
+    
+    temp_bool = None
+    
+    subexprs = ['test', 'true_val', 'false_val']
+    
+    def analyse_types(self, env):
+        self.test.analyse_types(env)
+        self.test = self.test.coerce_to_boolean(env)
+        self.true_val.analyse_types(env)
+        self.false_val.analyse_types(env)
+        self.type = self.compute_result_type(self.true_val.type, self.false_val.type)
+        if self.type:
+            if self.true_val.type.is_pyobject or self.false_val.type.is_pyobject:
+                self.true_val = self.true_val.coerce_to(self.type, env)
+                self.false_val = self.false_val.coerce_to(self.type, env)
+            # must be tmp variables so they can share a result
+            self.true_val = self.true_val.coerce_to_temp(env)
+            self.false_val = self.false_val.coerce_to_temp(env)
+            self.is_temp = 1
+        else:
+            self.type_error()
+    
+    def allocate_temps(self, env, result_code = None):
+        #  We only ever evaluate one side, and this is 
+        #  after evaluating the truth value, so we may
+        #  use an allocation strategy here which results in
+        #  this node and both its operands sharing the same
+        #  result variable. This allows us to avoid some 
+        #  assignments and increfs/decrefs that would otherwise
+        #  be necessary.
+        self.allocate_temp(env, result_code)
+        self.test.allocate_temps(env, result_code)
+        self.true_val.allocate_temps(env, self.result_code)
+        self.false_val.allocate_temps(env, self.result_code)
+        #  We haven't called release_temp on either value,
+        #  because although they are temp nodes, they don't own 
+        #  their result variable. And because they are temp
+        #  nodes, any temps in their subnodes will have been
+        #  released before their allocate_temps returned.
+        #  Therefore, they contain no temp vars that need to
+        #  be released.
+        
+    def compute_result_type(self, type1, type2):
+        if type1 == type2:
+            return type1
+        elif type1.is_pyobject or type2.is_pyobject:
+            return py_object_type
+        elif type1.is_numeric and type2.is_numeric:
+            return PyrexTypes.widest_numeric_type(type1, type2)
+        elif type1.is_extension_type and type1.subtype_of_resolved_type(type2):
+            return type2
+        elif type2.is_extension_type and type2.subtype_of_resolved_type(type1):
+            return type1
+        else:
+            return None
+        
+    def type_error(self):
+        if not (self.true_val.type.is_error or self.false_val.type.is_error):
+            error(self.pos, "Incompatable types in conditional expression (%s; %s)" %
+                (self.true_val.type, self.false_val.type))
+        self.type = PyrexTypes.error_type
+
+    def check_const(self):
+        self.test.check_const()
+        self.true_val.check_const()
+        self.false_val.check_const()
+    
+    def generate_evaluation_code(self, code):
+        self.test.generate_evaluation_code(code)
+        code.putln("if (%s) {" % self.test.result_code )
+        self.true_val.generate_evaluation_code(code)
+        code.putln("} else {")
+        self.false_val.generate_evaluation_code(code)
+        code.putln("}")
+        self.test.generate_disposal_code(code)
+
 class CmpNode:
     #  Mixin class containing code common to PrimaryCmpNodes
     #  and CascadedCmpNodes.
index fa6bb2049c472e9e2f4f1a0454979892e88a9c22..5ef8bfa1c5892338dd403c49ff6abe82cf7fb5af 100644 (file)
@@ -46,9 +46,31 @@ def p_binop_expr(s, ops, p_sub_expr):
         n1 = ExprNodes.binop_node(pos, op, n1, n2)
     return n1
 
-#test: and_test ('or' and_test)* | lambdef
+#expression: or_test [if or_test else test] | lambda_form
 
 def p_simple_expr(s):
+    pos = s.position()
+    expr = p_or_test(s)
+    if s.sy == 'if':
+        s.next()
+        test = p_or_test(s)
+        if s.sy == 'else':
+            s.next()
+            other = p_test(s)
+            return ExprNodes.CondExprNode(pos, test=test, true_val=expr, false_val=other)
+        else:
+            s.error("Expected 'else'")
+    else:
+        return expr
+        
+#test: or_test | lambda_form
+        
+def p_test(s):
+    return p_or_test(s)
+
+#or_test: and_test ('or' and_test)*
+
+def p_or_test(s):
     #return p_binop_expr(s, ('or',), p_and_test)
     return p_rassoc_binop_expr(s, ('or',), p_and_test)
 
@@ -627,7 +649,7 @@ def p_list_if(s):
     # s.sy == 'if'
     pos = s.position()
     s.next()
-    test = p_simple_expr(s)
+    test = p_test(s)
     return Nodes.IfStatNode(pos, 
         if_clauses = [Nodes.IfClauseNode(pos, condition = test, body = p_list_iter(s))],
         else_clause = None )
@@ -658,8 +680,6 @@ def p_backquote_expr(s):
     s.expect('`')
     return ExprNodes.BackquoteNode(pos, arg = arg)
 
-#testlist: test (',' test)* [',']
-
 def p_simple_expr_list(s):
     exprs = []
     while s.sy not in expr_terminators:
@@ -679,6 +699,22 @@ def p_expr(s):
     else:
         return expr
 
+
+#testlist: test (',' test)* [',']
+# differs from p_expr only in the fact that it cannot contain conditional expressions
+
+def p_testlist(s):
+    pos = s.position()
+    expr = p_test(s)
+    if s.sy == ',':
+        exprs = [expr]
+        while s.sy == ',':
+            s.next()
+            exprs.append(p_test(s))
+        return ExprNodes.TupleNode(pos, args = exprs)
+    else:
+        return expr
+        
 expr_terminators = (')', ']', '}', ':', '=', 'NEWLINE')
 
 #-------------------------------------------------------
@@ -1053,7 +1089,7 @@ def p_for_target(s):
 
 def p_for_iterator(s):
     pos = s.position()
-    expr = p_expr(s)
+    expr = p_testlist(s)
     return ExprNodes.IteratorNode(pos, sequence = expr)
 
 def p_try_statement(s):