Allow .pxd file to set c signatures for .py files.
authorRobert Bradshaw <robertwb@math.washington.edu>
Sat, 4 Oct 2008 08:04:48 +0000 (01:04 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Sat, 4 Oct 2008 08:04:48 +0000 (01:04 -0700)
Cython/Compiler/Main.py
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Symtab.py

index 3c7e9cd41fcbc75167e50b915a08d8a2f23af137..015146d29ac2c84cb1385b17958e227655dc649b 100644 (file)
@@ -74,12 +74,13 @@ class Context:
             os.path.join(os.path.dirname(__file__), '..', 'Includes'))
         self.include_directories = include_directories + [standard_include_path]
 
-    def create_pipeline(self, pxd):
+    def create_pipeline(self, pxd, py=False):
         from Visitor import PrintTree
         from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse
         from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
         from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
         from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
+        from ParseTreeTransforms import AlignFunctionDefinitions
         from AutoDocTransforms import EmbedSignature
         from Optimize import FlattenInListTransform, SwitchTransform, FinalOptimizePhase
         from Buffer import IntroduceBufferAuxiliaryVars
@@ -91,11 +92,17 @@ class Context:
         else:
             _check_c_classes = check_c_classes
             _specific_post_parse = None
+            
+        if py and not pxd:
+            _align_function_definitions = AlignFunctionDefinitions(self)
+        else:
+            _align_function_definitions = None
  
         return [
             NormalizeTree(self),
             PostParse(self),
             _specific_post_parse,
+            _align_function_definitions,
             InterpretCompilerDirectives(self, self.pragma_overrides),
             FlattenInListTransform(),
             WithTransform(self),
@@ -112,7 +119,7 @@ class Context:
             #        CreateClosureClasses(context),
             ]
 
-    def create_pyx_pipeline(self, options, result):
+    def create_pyx_pipeline(self, options, result, py=False):
         def generate_pyx_code(module_node):
             module_node.process_implementation(options, result)
             result.compilation_source = module_node.compilation_source
@@ -134,7 +141,7 @@ class Context:
 
         return ([
                 create_parse(self),
-            ] + self.create_pipeline(pxd=False) + [
+            ] + self.create_pipeline(pxd=False, py=py) + [
                 inject_pxd_code,
                 generate_pyx_code,
             ])
@@ -154,6 +161,10 @@ class Context:
         return [parse_pxd] + self.create_pipeline(pxd=True) + [
             ExtractPxdCode(self),
             ]
+            
+    def create_py_pipeline(self, options, result):
+        return self.create_pyx_pipeline(options, result, py=True)
+
 
     def process_pxd(self, source_desc, scope, module_name):
         pipeline = self.create_pxd_pipeline(scope, module_name)
@@ -504,7 +515,10 @@ def run_pipeline(source, options, full_module_name = None):
     result = create_default_resultobj(source, options)
     
     # Get pipeline
-    pipeline = context.create_pyx_pipeline(options, result)
+    if source_desc.filename.endswith(".py"):
+        pipeline = context.create_py_pipeline(options, result)
+    else:
+        pipeline = context.create_pyx_pipeline(options, result)
 
     context.setup_errors(options)
     err, enddata = context.run_pipeline(pipeline, source)
index b815ec11f1a6ca3c578e6cea42dca79c8596939c..846ab5b6729e5569d9be672b289c0cc784fdcf66 100644 (file)
@@ -412,7 +412,6 @@ class CNameDeclaratorNode(CDeclaratorNode):
             elif base_type.is_void:
                 error(self.pos, "Use spam() rather than spam(void) to declare a function with no arguments.")
             else:
-                print "here"
                 self.name = base_type.declaration_code("", for_display=1, pyrex=1)
                 base_type = py_object_type
         self.type = base_type
@@ -569,23 +568,28 @@ class CArgDeclNode(Node):
 
     is_self_arg = 0
     is_generic = 1
+    type = None
+    name_declarator = None
 
     def analyse(self, env, nonempty = 0):
         #print "CArgDeclNode.analyse: is_self_arg =", self.is_self_arg ###
-        # The parser may missinterpret names as types...
-        # We fix that here.
-        if isinstance(self.declarator, CNameDeclaratorNode) and self.declarator.name == '':
-            if nonempty:
-                self.declarator.name = self.base_type.name
-                self.base_type.name = None
-                self.base_type.is_basic_c_type = False
-            could_be_name = True
+        if self.type is None:
+            # The parser may missinterpret names as types...
+            # We fix that here.
+            if isinstance(self.declarator, CNameDeclaratorNode) and self.declarator.name == '':
+                if nonempty:
+                    self.declarator.name = self.base_type.name
+                    self.base_type.name = None
+                    self.base_type.is_basic_c_type = False
+                could_be_name = True
+            else:
+                could_be_name = False
+            base_type = self.base_type.analyse(env, could_be_name = could_be_name)
+            if self.base_type.arg_name:
+                self.declarator.name = self.base_type.arg_name
+            return self.declarator.analyse(base_type, env, nonempty = nonempty)
         else:
-            could_be_name = False
-        base_type = self.base_type.analyse(env, could_be_name = could_be_name)
-        if self.base_type.arg_name:
-            self.declarator.name = self.base_type.arg_name
-        return self.declarator.analyse(base_type, env, nonempty = nonempty)
+            return self.name_declarator, self.type
 
     def annotate(self, code):
         if self.default:
@@ -601,6 +605,14 @@ class CBaseTypeNode(Node):
     #     Returns the type.
     
     pass
+    
+class CAnalysedBaseTypeNode(Node):
+    # type            type
+    
+    child_attrs = []
+    
+    def analyse(self, env, could_be_name = False):
+        return self.type
 
 class CSimpleBaseTypeNode(CBaseTypeNode):
     # name             string
@@ -1429,6 +1441,8 @@ class DefNode(FuncDefNode):
     reqd_kw_flags_cname = "0"
     is_wrapper = 0
     decorators = None
+    entry = None
+    
 
     def __init__(self, pos, **kwds):
         FuncDefNode.__init__(self, pos, **kwds)
@@ -1443,8 +1457,45 @@ class DefNode(FuncDefNode):
         self.num_kwonly_args = k
         self.num_required_kw_args = rk
         self.num_required_args = r
-    
-    entry = None
+        
+    def as_cfunction(self, cfunc):
+        if self.star_arg:
+            error(self.star_arg.pos, "cdef function cannot have star argument")
+        if self.starstar_arg:
+            error(self.starstar_arg.pos, "cdef function cannot have starstar argument")
+        if len(self.args) != len(cfunc.type.args) or cfunc.type.has_varargs:
+            error(self.pos, "wrong number of arguments")
+            error(declarator.pos, "previous declaration here")
+        for formal_arg, type_arg in zip(self.args, cfunc.type.args):
+            name_declarator, type = formal_arg.analyse(cfunc.scope, nonempty=1)
+            if type is PyrexTypes.py_object_type or formal_arg.is_self:
+                formal_arg.type = type_arg.type
+                formal_arg.name_declarator = name_declarator
+        import ExprNodes
+        if cfunc.type.exception_value is None:
+            exception_value = None
+        else:
+            exception_value = ExprNodes.ConstNode(self.pos, value=cfunc.type.exception_value, type=cfunc.type.return_type)
+        declarator = CFuncDeclaratorNode(self.pos, 
+                                         base = CNameDeclaratorNode(self.pos, name=self.name, cname=None),
+                                         args = self.args,
+                                         has_varargs = False,
+                                         exception_check = cfunc.type.exception_check,
+                                         exception_value = exception_value,
+                                         with_gil = cfunc.type.with_gil,
+                                         nogil = cfunc.type.nogil)
+        return CFuncDefNode(self.pos, 
+                            modifiers = [],
+                            base_type = CAnalysedBaseTypeNode(self.pos, type=cfunc.type.return_type),
+                            declarator = declarator,
+                            body = self.body,
+                            doc = self.doc,
+                            overridable = cfunc.type.is_overridable,
+                            type = cfunc.type,
+                            with_gil = cfunc.type.with_gil,
+                            nogil = cfunc.type.nogil,
+                            visibility = 'private',
+                            api = False)
     
     def analyse_declarations(self, env):
         if 'locals' in env.directives:
@@ -2235,6 +2286,43 @@ class PyClassDefNode(ClassDefNode):
             bases = bases, dict = self.dict, doc = doc_node)
         self.target = ExprNodes.NameNode(pos, name = name)
         
+    def as_cclass(self):
+        """
+        Return this node as if it were declared as an extension class"
+        """
+        bases = self.classobj.bases.args
+        if len(bases) == 0:
+            base_class_name = None
+            base_class_module = None
+        elif len(bases) == 1:
+            base = bases[0]
+            path = []
+            while isinstance(base, ExprNodes.AttributeNode):
+                path.insert(0, base.attribute)
+                base = base.obj
+            if isinstance(base, ExprNodes.NameNode):
+                path.insert(0, base.name)
+                base_class_name = path[-1]
+                if len(path) > 1:
+                    base_class_module = u'.'.join(path[:-1])
+                else:
+                    base_class_module = None
+            else:
+                error(self.classobj.bases.args.pos, "Invalid base class")
+        else:
+            error(self.classobj.bases.args.pos, "C class may only have one base class")
+            return None
+        
+        return CClassDefNode(self.pos, 
+                             visibility = 'private',
+                             module_name = None,
+                             class_name = self.name,
+                             base_class_module = base_class_module,
+                             base_class_name = base_class_name,
+                             body = self.body,
+                             in_pxd = False,
+                             doc = self.doc)
+        
     def create_scope(self, env):
         genv = env
         while env.is_py_class_scope or env.is_c_class_scope:
@@ -2297,6 +2385,10 @@ class CClassDefNode(ClassDefNode):
     child_attrs = ["body"]
     buffer_defaults_node = None
     buffer_defaults_pos = None
+    typedef_flag = False
+    api = False
+    objstruct_name = None
+    typeobj_name = None
 
     def analyse_declarations(self, env):
         #print "CClassDefNode.analyse_declarations:", self.class_name
index 19894538e08dc4edd243aa908a86e9c77656f151..1c3ab75e7f235e454c0b85d2ad26ec49bd4d204b 100644 (file)
@@ -573,6 +573,54 @@ class AnalyseExpressionsTransform(CythonTransform):
         node.body.analyse_expressions(node.local_scope)
         self.visitchildren(node)
         return node
+        
+class AlignFunctionDefinitions(CythonTransform):
+    """
+    This class takes the signatures from a .pxd file and applies them to 
+    the def methods in a .py file. 
+    """
+    
+    def visit_ModuleNode(self, node):
+        self.scope = node.scope
+        self.visitchildren(node)
+        return node
+    
+    def visit_PyClassDefNode(self, node):
+        pxd_def = self.scope.lookup(node.name)
+        if pxd_def:
+            if pxd_def.is_cclass:
+                return self.visit_CClassDefNode(node.as_cclass(), pxd_def)
+            else:
+                error(node.pos, "'%s' redeclared" % node.name)
+                error(pxd_def.pos, "previous declaration here")
+                return None
+        self.visitchildren(node)
+        return node
+        
+    def visit_CClassDefNode(self, node, pxd_def=None):
+        if pxd_def is None:
+            pxd_def = self.scope.lookup(node.class_name)
+        if pxd_def:
+            outer_scope = self.scope
+            self.scope = pxd_def.type.scope
+        self.visitchildren(node)
+        if pxd_def:
+            self.scope = outer_scope
+        return node
+        
+    def visit_DefNode(self, node):
+        pxd_def = self.scope.lookup(node.name)
+        if pxd_def:
+            if pxd_def.is_cfunction:
+                node = node.as_cfunction(pxd_def)
+            else:
+                error(node.pos, "'%s' redeclared" % node.name)
+                error(pxd_def.pos, "previous declaration here")
+                return None
+        # Enable this when internal def functions are allowed. 
+        # self.visitchildren(node)
+        return node
+        
 
 class MarkClosureVisitor(CythonTransform):
     
index 759a566a7c8830fbda6888a3684363e385760297..ac144ee750e47c77fb8af422c6690c29693be4dd 100644 (file)
@@ -56,6 +56,7 @@ class Entry:
     # is_cmethod       boolean    Is a C method of an extension type
     # is_unbound_cmethod boolean  Is an unbound C method of an extension type
     # is_type          boolean    Is a type definition
+    # is_cclass        boolean    Is an extension class
     # is_const         boolean    Is a constant
     # is_property      boolean    Is a property of an extension type:
     # doc_cname        string or None  C const holding the docstring
@@ -108,6 +109,7 @@ class Entry:
     is_cmethod = 0
     is_unbound_cmethod = 0
     is_type = 0
+    is_cclass = 0
     is_const = 0
     is_property = 0
     doc_cname = None
@@ -989,6 +991,7 @@ class ModuleScope(Scope):
             type.typeptr_cname = self.mangle(Naming.typeptr_prefix, name)
             entry = self.declare_type(name, type, pos, visibility = visibility,
                 defining = 0)
+            entry.is_cclass = True
             if objstruct_cname:
                 type.objstruct_cname = objstruct_cname
             elif not entry.in_cinclude: