Default argument literals, better True/False coercion
authorRobert Bradshaw <robertwb@math.washington.edu>
Fri, 8 Feb 2008 04:05:38 +0000 (20:05 -0800)
committerRobert Bradshaw <robertwb@math.washington.edu>
Fri, 8 Feb 2008 04:05:38 +0000 (20:05 -0800)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Nodes.py
Cython/Compiler/Parsing.py

index 4219f400cbf647b741e532c8bf6c498c13888e8e..c06d82a6eed42732347fb574ad40fc3032bd30a0 100644 (file)
@@ -550,6 +550,8 @@ class AtomicExprNode(ExprNode):
 class PyConstNode(AtomicExprNode):
     #  Abstract base class for constant Python values.
     
+    is_literal = 1
+    
     def is_simple(self):
         return 1
     
@@ -571,6 +573,24 @@ class NoneNode(PyConstNode):
     def compile_time_value(self, denv):
         return None
     
+class BoolNode(PyConstNode):
+    #  The constant value True or False
+    
+    def compile_time_value(self, denv):
+        return None
+    
+    def calculate_result_code(self):
+        if self.value:
+            return "Py_True"
+        else:
+            return "Py_False"
+
+    def coerce_to(self, dst_type, env):
+        value = self.value
+        if dst_type.is_numeric:
+            return IntNode(self.pos, value=self.value).coerce_to(dst_type, env)
+        else:
+            return PyConstNode.coerce_to(self, dst_type, env)
 
 class EllipsisNode(PyConstNode):
     #  '...' in a subscript list.
@@ -2148,6 +2168,7 @@ class TupleNode(SequenceNode):
         if len(self.args) == 0:
             self.type = py_object_type
             self.is_temp = 0
+            self.is_literal = 1
         else:
             SequenceNode.analyse_types(self, env)
             
index 9181a23d6f765398a1f2258f9b1ced9805cafbad..641db93605e5af2b455009681310de585f5f2758 100644 (file)
@@ -339,13 +339,15 @@ class CFuncDeclaratorNode(CDeclaratorNode):
                 PyrexTypes.CFuncTypeArg(name, type, arg_node.pos))
             if arg_node.default:
                 self.optional_arg_count += 1
+            elif self.optional_arg_count:
+                error(self.pos, "Non-default argument follows default argument")
         
         if self.optional_arg_count:
             scope = StructOrUnionScope()
             scope.declare_var('n', PyrexTypes.c_int_type, self.pos)
             for arg in func_type_args[len(func_type_args)-self.optional_arg_count:]:
                 scope.declare_var(arg.name, arg.type, arg.pos, allow_pyobject = 1)
-            struct_cname = Naming.opt_arg_prefix + env.mangle(self.base.name)
+            struct_cname = env.mangle(Naming.opt_arg_prefix, self.base.name)
             self.op_args_struct = env.global_scope().declare_struct_or_union(name = struct_cname,
                                         kind = 'struct',
                                         scope = scope,
@@ -396,6 +398,7 @@ class CArgDeclNode(Node):
     # not_none       boolean            Tagged with 'not None'
     # default        ExprNode or None
     # default_entry  Symtab.Entry       Entry for the variable holding the default value
+    # default_result_code string        cname or code fragment for default value
     # is_self_arg    boolean            Is the "self" arg of an extension type method
     # is_kw_only     boolean            Is a keyword-only argument
 
@@ -639,9 +642,14 @@ class FuncDefNode(StatNode, BlockNode):
                     if not hasattr(arg, 'default_entry'):
                         arg.default.analyse_types(genv)
                         arg.default = arg.default.coerce_to(arg.type, genv)
-                        arg.default.allocate_temps(genv)
-                        arg.default_entry = genv.add_default_value(arg.type)
-                        arg.default_entry.used = 1
+                        if arg.default.is_literal:
+                            arg.default_entry = arg.default
+                            arg.default_result_code = arg.default.calculate_result_code()
+                        else:
+                            arg.default.allocate_temps(genv)
+                            arg.default_entry = genv.add_default_value(arg.type)
+                            arg.default_entry.used = 1
+                            arg.default_result_code = arg.default_entry.cname
                 else:
                     error(arg.pos,
                         "This argument cannot have a default value")
@@ -791,16 +799,17 @@ class FuncDefNode(StatNode, BlockNode):
         for arg in self.args:
             default = arg.default
             if default:
-                default.generate_evaluation_code(code)
-                default.make_owned_reference(code)
-                code.putln(
-                    "%s = %s;" % (
-                        arg.default_entry.cname,
-                        default.result_as(arg.default_entry.type)))
-                if default.is_temp and default.type.is_pyobject:
+                if not default.is_literal:
+                    default.generate_evaluation_code(code)
+                    default.make_owned_reference(code)
                     code.putln(
-                        "%s = 0;" %
-                            default.result_code)
+                        "%s = %s;" % (
+                            arg.default_entry.cname,
+                            default.result_as(arg.default_entry.type)))
+                    if default.is_temp and default.type.is_pyobject:
+                        code.putln(
+                            "%s = 0;" %
+                                default.result_code)
         # For Python class methods, create and store function object
         if self.assmt:
             self.assmt.generate_execution_code(code)
@@ -930,7 +939,7 @@ class CFuncDefNode(FuncDefNode):
     def generate_argument_declarations(self, env, code):
         for arg in self.declarator.args:
             if arg.default:
-                code.putln('%s = %s;' % (arg.type.declaration_code(arg.cname), arg.default_entry.cname))
+                    code.putln('%s = %s;' % (arg.type.declaration_code(arg.cname), arg.default_result_code))
 
     def generate_keyword_list(self, code):
         pass
@@ -943,7 +952,10 @@ class CFuncDefNode(FuncDefNode):
             for arg in rev_args:
                 if arg.default:
                     code.putln('if (%s->n > %s) {' % (Naming.optional_args_cname, i))
-                    code.putln('%s = %s->%s;' % (arg.cname, Naming.optional_args_cname, arg.declarator.name))
+                    declarator = arg.declarator
+                    while not hasattr(declarator, 'name'):
+                        declarator = declarator.base
+                    code.putln('%s = %s->%s;' % (arg.cname, Naming.optional_args_cname, declarator.name))
                     i += 1
             for _ in range(self.type.optional_arg_count):
                 code.putln('}')
@@ -1307,7 +1319,7 @@ class DefNode(FuncDefNode):
                         code.putln(
                             "%s = %s;" % (
                                 arg_entry.cname,
-                                arg.default_entry.cname))
+                                arg.default_result_code))
                         if not default_seen:
                             arg_formats.append("|")
                         default_seen = 1
index 99769e3db5ef8a05ec5f0741003f00519352df2f..3489c3dd803b23563b1cb0b614eb66d66980b47a 100644 (file)
@@ -468,6 +468,10 @@ def p_atom(s):
         s.next()
         if name == "None":
             return ExprNodes.NoneNode(pos)
+        elif name == "True":
+            return ExprNodes.BoolNode(pos, value=1)
+        elif name == "False":
+            return ExprNodes.BoolNode(pos, value=0)
         else:
             return p_name(s, name)
     elif sy == 'NULL':