Module-level cpdef functions
authorRobert Bradshaw <robertwb@math.washington.edu>
Wed, 13 Feb 2008 12:13:44 +0000 (04:13 -0800)
committerRobert Bradshaw <robertwb@math.washington.edu>
Wed, 13 Feb 2008 12:13:44 +0000 (04:13 -0800)
Cython/Compiler/Nodes.py
Cython/Compiler/Options.py
Cython/Compiler/Parsing.py
Cython/Compiler/Symtab.py

index 3360f82703dd0f5af651087ee8448a72d217bbe6..051bfd7e856c66779663c13cd0a93cb8fddcf665 100644 (file)
@@ -867,7 +867,7 @@ class CFuncDefNode(FuncDefNode):
         
         if self.overridable:
             import ExprNodes
-            py_func_body = self.call_self_node()
+            py_func_body = self.call_self_node(is_module_scope = env.is_module_scope)
             self.py_func = DefNode(pos = self.pos, 
                                    name = self.declarator.base.name,
                                    args = self.declarator.args,
@@ -875,23 +875,30 @@ class CFuncDefNode(FuncDefNode):
                                    starstar_arg = None,
                                    doc = self.doc,
                                    body = py_func_body)
+            self.py_func.is_module_scope = env.is_module_scope
             self.py_func.analyse_declarations(env)
+            self.entry.as_variable = self.py_func.entry
             # Reset scope entry the above cfunction
             env.entries[name] = self.entry
             if Options.intern_names:
                 self.py_func.interned_attr_cname = env.intern(self.py_func.entry.name)
-            self.override = OverrideCheckNode(self.pos, py_func = self.py_func)
-            self.body = StatListNode(self.pos, stats=[self.override, self.body])
+            if not env.is_module_scope or Options.lookup_module_cpdef:
+                self.override = OverrideCheckNode(self.pos, py_func = self.py_func)
+                self.body = StatListNode(self.pos, stats=[self.override, self.body])
     
-    def call_self_node(self, omit_optional_args=0):
+    def call_self_node(self, omit_optional_args=0, is_module_scope=0):
         import ExprNodes
         args = self.type.args
         if omit_optional_args:
             args = args[:len(args) - self.type.optional_arg_count]
         arg_names = [arg.name for arg in args]
-        self_arg = ExprNodes.NameNode(self.pos, name=arg_names[0])
-        cfunc = ExprNodes.AttributeNode(self.pos, obj=self_arg, attribute=self.declarator.base.name)
-        c_call = ExprNodes.SimpleCallNode(self.pos, function=cfunc, args=[ExprNodes.NameNode(self.pos, name=n) for n in arg_names[1:]], wrapper_call=True)
+        if is_module_scope:
+            cfunc = ExprNodes.NameNode(self.pos, name=self.declarator.base.name)
+        else:
+            self_arg = ExprNodes.NameNode(self.pos, name=arg_names[0])
+            cfunc = ExprNodes.AttributeNode(self.pos, obj=self_arg, attribute=self.declarator.base.name)
+        skip_dispatch = not is_module_scope or Options.lookup_module_cpdef
+        c_call = ExprNodes.SimpleCallNode(self.pos, function=cfunc, args=[ExprNodes.NameNode(self.pos, name=n) for n in arg_names[1-is_module_scope:]], wrapper_call=skip_dispatch)
         return ReturnStatNode(pos=self.pos, return_type=PyrexTypes.py_object_type, value=c_call)
                     
     def declare_arguments(self, env):
@@ -1667,12 +1674,16 @@ class OverrideCheckNode(StatNode):
     
     def analyse_expressions(self, env):
         self.args = env.arg_entries
+        if self.py_func.is_module_scope:
+            first_arg = 0
+        else:
+            first_arg = 1
         import ExprNodes
         self.func_node = ExprNodes.PyTempNode(self.pos, env)
-        call_tuple = ExprNodes.TupleNode(self.pos, args=[ExprNodes.NameNode(self.pos, name=arg.name) for arg in self.args[1:]])
+        call_tuple = ExprNodes.TupleNode(self.pos, args=[ExprNodes.NameNode(self.pos, name=arg.name) for arg in self.args[first_arg:]])
         call_node = ExprNodes.SimpleCallNode(self.pos,
                                              function=self.func_node, 
-                                             args=[ExprNodes.NameNode(self.pos, name=arg.name) for arg in self.args[1:]])
+                                             args=[ExprNodes.NameNode(self.pos, name=arg.name) for arg in self.args[first_arg:]])
         self.body = ReturnStatNode(self.pos, value=call_node)
 #        self.func_temp = env.allocate_temp_pyobject()
         self.body.analyse_expressions(env)
@@ -1680,11 +1691,17 @@ class OverrideCheckNode(StatNode):
         
     def generate_execution_code(self, code):
         # Check to see if we are an extension type
-        self_arg = "((PyObject *)%s)" % self.args[0].cname
+        if self.py_func.is_module_scope:
+            self_arg = "((PyObject *)%s)" % Naming.module_cname
+        else:
+            self_arg = "((PyObject *)%s)" % self.args[0].cname
         code.putln("/* Check if called by wrapper */")
         code.putln("if (unlikely(%s)) %s = 0;" % (Naming.skip_dispatch_cname, Naming.skip_dispatch_cname))
         code.putln("/* Check if overriden in Python */")
-        code.putln("else if (unlikely(%s->ob_type->tp_dictoffset != 0)) {" % self_arg)
+        if self.py_func.is_module_scope:
+            code.putln("else {")
+        else:
+            code.putln("else if (unlikely(%s->ob_type->tp_dictoffset != 0)) {" % self_arg)
         err = code.error_goto_if_null(self_arg, self.pos)
         # need to get attribute manually--scope would return cdef method
         if Options.intern_names:
index 9a2120e912f002b69714e266692191b83e1f673b..d7cc0974add5b66f30823e2ac141a9245204373f 100644 (file)
@@ -33,3 +33,9 @@ annotate = 0
 # raised before the loop is entered, wheras without this option the loop
 # will execute util a overflowing value is encountered. 
 convert_range = 0
+
+# Enable this to allow one to write your_module.foo = ... to overwrite the 
+# definition if the cpdef function foo, at the cost of an extra dictionary 
+# lookup on every call. 
+# If this is 0 it simply creates a wrapper. 
+lookup_module_cpdef = 0
index 8635642c81cbb407efc54546a751c635fc6cfc86..a430a0c9b157c57587aa30949f492de38f9547a3 100644 (file)
@@ -1722,8 +1722,6 @@ def p_api(s):
 def p_cdef_statement(s, level, visibility = 'private', api = 0,
                      overridable = False):
     pos = s.position()
-    if overridable and level not in ('c_class', 'c_class_pxd'):
-            error(pos, "Overridable cdef function not allowed here")
     visibility = p_visibility(s, visibility)
     api = api or p_api(s)
     if api:
index aadf184358fc9378d79fa4c6528c9d9b38356350..740f9f23f38c276ad894b260078408d8229bab38 100644 (file)
@@ -148,6 +148,7 @@ class Scope:
 
     is_py_class_scope = 0
     is_c_class_scope = 0
+    is_module_scope = 0
     scope_prefix = ""
     in_cinclude = 0
     
@@ -673,6 +674,8 @@ class ModuleScope(Scope):
     # interned_names       [string]           Interned names pending generation of declarations
     # all_pystring_entries [Entry]            Python string consts from all scopes
     # types_imported       {PyrexType : 1}    Set of types for which import code generated
+    
+    is_module_scope = 1
 
     def __init__(self, name, parent_module, context):
         self.parent_module = parent_module