From 3bcf81404146fa411793c860bfb45bef06aab547 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Sat, 4 Oct 2008 01:04:48 -0700 Subject: [PATCH] Allow .pxd file to set c signatures for .py files. --- Cython/Compiler/Main.py | 22 ++++- Cython/Compiler/Nodes.py | 124 +++++++++++++++++++++---- Cython/Compiler/ParseTreeTransforms.py | 48 ++++++++++ Cython/Compiler/Symtab.py | 3 + 4 files changed, 177 insertions(+), 20 deletions(-) diff --git a/Cython/Compiler/Main.py b/Cython/Compiler/Main.py index 3c7e9cd4..015146d2 100644 --- a/Cython/Compiler/Main.py +++ b/Cython/Compiler/Main.py @@ -74,12 +74,13 @@ class Context: os.path.join(os.path.dirname(__file__), '..', 'Includes')) self.include_directories = include_directories + [standard_include_path] - def create_pipeline(self, pxd): + def create_pipeline(self, pxd, py=False): from Visitor import PrintTree from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods + from ParseTreeTransforms import AlignFunctionDefinitions from AutoDocTransforms import EmbedSignature from Optimize import FlattenInListTransform, SwitchTransform, FinalOptimizePhase from Buffer import IntroduceBufferAuxiliaryVars @@ -91,11 +92,17 @@ class Context: else: _check_c_classes = check_c_classes _specific_post_parse = None + + if py and not pxd: + _align_function_definitions = AlignFunctionDefinitions(self) + else: + _align_function_definitions = None return [ NormalizeTree(self), PostParse(self), _specific_post_parse, + _align_function_definitions, InterpretCompilerDirectives(self, self.pragma_overrides), FlattenInListTransform(), WithTransform(self), @@ -112,7 +119,7 @@ class Context: # CreateClosureClasses(context), ] - def create_pyx_pipeline(self, options, result): + def create_pyx_pipeline(self, options, result, py=False): def generate_pyx_code(module_node): module_node.process_implementation(options, result) result.compilation_source = module_node.compilation_source @@ -134,7 +141,7 @@ class Context: return ([ create_parse(self), - ] + self.create_pipeline(pxd=False) + [ + ] + self.create_pipeline(pxd=False, py=py) + [ inject_pxd_code, generate_pyx_code, ]) @@ -154,6 +161,10 @@ class Context: return [parse_pxd] + self.create_pipeline(pxd=True) + [ ExtractPxdCode(self), ] + + def create_py_pipeline(self, options, result): + return self.create_pyx_pipeline(options, result, py=True) + def process_pxd(self, source_desc, scope, module_name): pipeline = self.create_pxd_pipeline(scope, module_name) @@ -504,7 +515,10 @@ def run_pipeline(source, options, full_module_name = None): result = create_default_resultobj(source, options) # Get pipeline - pipeline = context.create_pyx_pipeline(options, result) + if source_desc.filename.endswith(".py"): + pipeline = context.create_py_pipeline(options, result) + else: + pipeline = context.create_pyx_pipeline(options, result) context.setup_errors(options) err, enddata = context.run_pipeline(pipeline, source) diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index b815ec11..846ab5b6 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -412,7 +412,6 @@ class CNameDeclaratorNode(CDeclaratorNode): elif base_type.is_void: error(self.pos, "Use spam() rather than spam(void) to declare a function with no arguments.") else: - print "here" self.name = base_type.declaration_code("", for_display=1, pyrex=1) base_type = py_object_type self.type = base_type @@ -569,23 +568,28 @@ class CArgDeclNode(Node): is_self_arg = 0 is_generic = 1 + type = None + name_declarator = None def analyse(self, env, nonempty = 0): #print "CArgDeclNode.analyse: is_self_arg =", self.is_self_arg ### - # The parser may missinterpret names as types... - # We fix that here. - if isinstance(self.declarator, CNameDeclaratorNode) and self.declarator.name == '': - if nonempty: - self.declarator.name = self.base_type.name - self.base_type.name = None - self.base_type.is_basic_c_type = False - could_be_name = True + if self.type is None: + # The parser may missinterpret names as types... + # We fix that here. + if isinstance(self.declarator, CNameDeclaratorNode) and self.declarator.name == '': + if nonempty: + self.declarator.name = self.base_type.name + self.base_type.name = None + self.base_type.is_basic_c_type = False + could_be_name = True + else: + could_be_name = False + base_type = self.base_type.analyse(env, could_be_name = could_be_name) + if self.base_type.arg_name: + self.declarator.name = self.base_type.arg_name + return self.declarator.analyse(base_type, env, nonempty = nonempty) else: - could_be_name = False - base_type = self.base_type.analyse(env, could_be_name = could_be_name) - if self.base_type.arg_name: - self.declarator.name = self.base_type.arg_name - return self.declarator.analyse(base_type, env, nonempty = nonempty) + return self.name_declarator, self.type def annotate(self, code): if self.default: @@ -601,6 +605,14 @@ class CBaseTypeNode(Node): # Returns the type. pass + +class CAnalysedBaseTypeNode(Node): + # type type + + child_attrs = [] + + def analyse(self, env, could_be_name = False): + return self.type class CSimpleBaseTypeNode(CBaseTypeNode): # name string @@ -1429,6 +1441,8 @@ class DefNode(FuncDefNode): reqd_kw_flags_cname = "0" is_wrapper = 0 decorators = None + entry = None + def __init__(self, pos, **kwds): FuncDefNode.__init__(self, pos, **kwds) @@ -1443,8 +1457,45 @@ class DefNode(FuncDefNode): self.num_kwonly_args = k self.num_required_kw_args = rk self.num_required_args = r - - entry = None + + def as_cfunction(self, cfunc): + if self.star_arg: + error(self.star_arg.pos, "cdef function cannot have star argument") + if self.starstar_arg: + error(self.starstar_arg.pos, "cdef function cannot have starstar argument") + if len(self.args) != len(cfunc.type.args) or cfunc.type.has_varargs: + error(self.pos, "wrong number of arguments") + error(declarator.pos, "previous declaration here") + for formal_arg, type_arg in zip(self.args, cfunc.type.args): + name_declarator, type = formal_arg.analyse(cfunc.scope, nonempty=1) + if type is PyrexTypes.py_object_type or formal_arg.is_self: + formal_arg.type = type_arg.type + formal_arg.name_declarator = name_declarator + import ExprNodes + if cfunc.type.exception_value is None: + exception_value = None + else: + exception_value = ExprNodes.ConstNode(self.pos, value=cfunc.type.exception_value, type=cfunc.type.return_type) + declarator = CFuncDeclaratorNode(self.pos, + base = CNameDeclaratorNode(self.pos, name=self.name, cname=None), + args = self.args, + has_varargs = False, + exception_check = cfunc.type.exception_check, + exception_value = exception_value, + with_gil = cfunc.type.with_gil, + nogil = cfunc.type.nogil) + return CFuncDefNode(self.pos, + modifiers = [], + base_type = CAnalysedBaseTypeNode(self.pos, type=cfunc.type.return_type), + declarator = declarator, + body = self.body, + doc = self.doc, + overridable = cfunc.type.is_overridable, + type = cfunc.type, + with_gil = cfunc.type.with_gil, + nogil = cfunc.type.nogil, + visibility = 'private', + api = False) def analyse_declarations(self, env): if 'locals' in env.directives: @@ -2235,6 +2286,43 @@ class PyClassDefNode(ClassDefNode): bases = bases, dict = self.dict, doc = doc_node) self.target = ExprNodes.NameNode(pos, name = name) + def as_cclass(self): + """ + Return this node as if it were declared as an extension class" + """ + bases = self.classobj.bases.args + if len(bases) == 0: + base_class_name = None + base_class_module = None + elif len(bases) == 1: + base = bases[0] + path = [] + while isinstance(base, ExprNodes.AttributeNode): + path.insert(0, base.attribute) + base = base.obj + if isinstance(base, ExprNodes.NameNode): + path.insert(0, base.name) + base_class_name = path[-1] + if len(path) > 1: + base_class_module = u'.'.join(path[:-1]) + else: + base_class_module = None + else: + error(self.classobj.bases.args.pos, "Invalid base class") + else: + error(self.classobj.bases.args.pos, "C class may only have one base class") + return None + + return CClassDefNode(self.pos, + visibility = 'private', + module_name = None, + class_name = self.name, + base_class_module = base_class_module, + base_class_name = base_class_name, + body = self.body, + in_pxd = False, + doc = self.doc) + def create_scope(self, env): genv = env while env.is_py_class_scope or env.is_c_class_scope: @@ -2297,6 +2385,10 @@ class CClassDefNode(ClassDefNode): child_attrs = ["body"] buffer_defaults_node = None buffer_defaults_pos = None + typedef_flag = False + api = False + objstruct_name = None + typeobj_name = None def analyse_declarations(self, env): #print "CClassDefNode.analyse_declarations:", self.class_name diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index 19894538..1c3ab75e 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -573,6 +573,54 @@ class AnalyseExpressionsTransform(CythonTransform): node.body.analyse_expressions(node.local_scope) self.visitchildren(node) return node + +class AlignFunctionDefinitions(CythonTransform): + """ + This class takes the signatures from a .pxd file and applies them to + the def methods in a .py file. + """ + + def visit_ModuleNode(self, node): + self.scope = node.scope + self.visitchildren(node) + return node + + def visit_PyClassDefNode(self, node): + pxd_def = self.scope.lookup(node.name) + if pxd_def: + if pxd_def.is_cclass: + return self.visit_CClassDefNode(node.as_cclass(), pxd_def) + else: + error(node.pos, "'%s' redeclared" % node.name) + error(pxd_def.pos, "previous declaration here") + return None + self.visitchildren(node) + return node + + def visit_CClassDefNode(self, node, pxd_def=None): + if pxd_def is None: + pxd_def = self.scope.lookup(node.class_name) + if pxd_def: + outer_scope = self.scope + self.scope = pxd_def.type.scope + self.visitchildren(node) + if pxd_def: + self.scope = outer_scope + return node + + def visit_DefNode(self, node): + pxd_def = self.scope.lookup(node.name) + if pxd_def: + if pxd_def.is_cfunction: + node = node.as_cfunction(pxd_def) + else: + error(node.pos, "'%s' redeclared" % node.name) + error(pxd_def.pos, "previous declaration here") + return None + # Enable this when internal def functions are allowed. + # self.visitchildren(node) + return node + class MarkClosureVisitor(CythonTransform): diff --git a/Cython/Compiler/Symtab.py b/Cython/Compiler/Symtab.py index 759a566a..ac144ee7 100644 --- a/Cython/Compiler/Symtab.py +++ b/Cython/Compiler/Symtab.py @@ -56,6 +56,7 @@ class Entry: # is_cmethod boolean Is a C method of an extension type # is_unbound_cmethod boolean Is an unbound C method of an extension type # is_type boolean Is a type definition + # is_cclass boolean Is an extension class # is_const boolean Is a constant # is_property boolean Is a property of an extension type: # doc_cname string or None C const holding the docstring @@ -108,6 +109,7 @@ class Entry: is_cmethod = 0 is_unbound_cmethod = 0 is_type = 0 + is_cclass = 0 is_const = 0 is_property = 0 doc_cname = None @@ -989,6 +991,7 @@ class ModuleScope(Scope): type.typeptr_cname = self.mangle(Naming.typeptr_prefix, name) entry = self.declare_type(name, type, pos, visibility = visibility, defining = 0) + entry.is_cclass = True if objstruct_cname: type.objstruct_cname = objstruct_cname elif not entry.in_cinclude: -- 2.26.2