decorators for cdef functions, remove strange pxd locals syntax
authorRobert Bradshaw <robertwb@math.washington.edu>
Sun, 1 Mar 2009 11:04:37 +0000 (03:04 -0800)
committerRobert Bradshaw <robertwb@math.washington.edu>
Sun, 1 Mar 2009 11:04:37 +0000 (03:04 -0800)
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Parsing.py

index 7990c335828be5c5c77f1ad960a556e5d59f8161..76cef6e484ec867031bf69cab02a220d6f370998 100644 (file)
@@ -771,11 +771,12 @@ class CVarDefNode(StatNode):
     #  in_pxd        boolean
     #  api           boolean
     #  need_properties [entry]
-    #  pxd_locals    [CVarDefNode]  (used for functions declared in pxd)
+
+    #  directive_locals { string : NameNode } locals defined by cython.locals(...)
 
     child_attrs = ["base_type", "declarators"]
     need_properties = ()
-    pxd_locals = []
+    directive_locals = {}
     
     def analyse_declarations(self, env, dest_scope = None):
         if not dest_scope:
@@ -812,8 +813,10 @@ class CVarDefNode(StatNode):
                     cname = cname, visibility = self.visibility, in_pxd = self.in_pxd,
                     api = self.api)
                 if entry is not None:
-                    entry.pxd_locals = self.pxd_locals
+                    entry.directive_locals = self.directive_locals
             else:
+                if self.directive_locals:
+                    s.error("Decorators can only be followed by functions")
                 if self.in_pxd and self.visibility != 'extern':
                     error(self.pos, 
                         "Only 'extern' C variable declaration allowed in .pxd file")
@@ -969,12 +972,11 @@ class FuncDefNode(StatNode, BlockNode):
     #  #filename        string        C name of filename string const
     #  entry           Symtab.Entry
     #  needs_closure   boolean        Whether or not this function has inner functions/classes/yield
-    #  pxd_locals      [CVarDefNode]   locals defined in the pxd
+    #  directive_locals { string : NameNode } locals defined by cython.locals(...)
     
     py_func = None
     assmt = None
     needs_closure = False
-    pxd_locals = []
     
     def analyse_default_values(self, env):
         genv = env.global_scope()
@@ -1280,6 +1282,7 @@ class CFuncDefNode(FuncDefNode):
     #  declarator    CDeclaratorNode
     #  body          StatListNode
     #  api           boolean
+    #  decorators    [DecoratorNode]        list of decorators
     #
     #  with_gil      boolean    Acquire GIL around body
     #  type          CFuncType
@@ -1290,16 +1293,16 @@ class CFuncDefNode(FuncDefNode):
     child_attrs = ["base_type", "declarator", "body", "py_func"]
 
     inline_in_pxd = False
+    decorators = None
+    directive_locals = {}
 
     def unqualified_name(self):
         return self.entry.name
         
     def analyse_declarations(self, env):
-        if 'locals' in env.directives:
-            directive_locals = env.directives['locals']
-        else:
-            directive_locals = {}
-        self.directive_locals = directive_locals
+        if 'locals' in env.directives and env.directives['locals']:
+            self.directive_locals = env.directives['locals']
+        directive_locals = self.directive_locals
         base_type = self.base_type.analyse(env)
         # The 2 here is because we need both function and argument names. 
         name_declarator, type = self.declarator.analyse(base_type, env, nonempty = 2 * (self.body is not None))
@@ -1595,7 +1598,7 @@ class DefNode(FuncDefNode):
                                               nogil = False,
                                               with_gil = False,
                                               is_overridable = True)
-            cfunc = CVarDefNode(self.pos, type=cfunc_type, pxd_locals=[])
+            cfunc = CVarDefNode(self.pos, type=cfunc_type)
         else:
             cfunc_type = cfunc.type
             if len(self.args) != len(cfunc_type.args) or cfunc_type.has_varargs:
@@ -1631,7 +1634,7 @@ class DefNode(FuncDefNode):
                             nogil = cfunc_type.nogil,
                             visibility = 'private',
                             api = False,
-                            pxd_locals = cfunc.pxd_locals)
+                            directive_locals = cfunc.directive_locals)
     
     def analyse_declarations(self, env):
         if 'locals' in env.directives:
index fbd1036010e82d8cf6b3906d293fc3084bfb0287..dc0560c1acaf109713d752546f3eba5943b06667 100644 (file)
@@ -294,19 +294,8 @@ class PxdPostParse(CythonTransform, SkipDeclarations):
                 else:
                     err = None # allow inline function
             else:
-                err = None
-                for stat in node.body.stats:
-                    if not isinstance(stat, CVarDefNode):
-                        err = self.ERR_INLINE_ONLY
-                        break
-                node = CVarDefNode(node.pos, 
-                                   visibility = node.visibility,
-                                   base_type = node.base_type,
-                                   declarators = [node.declarator],
-                                   in_pxd = True,
-                                   api = node.api,
-                                   overridable = node.overridable, 
-                                   pxd_locals = node.body.stats)
+                err = self.ERR_INLINE_ONLY
+
         if err:
             self.context.nonfatal_error(PostParseError(node.pos, err))
             return None
@@ -462,7 +451,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
         return directive
  
     # Handle decorators
-    def visit_DefNode(self, node):
+    def visit_FuncDefNode(self, node):
         options = []
         
         if node.decorators:
@@ -474,7 +463,10 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
                     options.append(option)
                 else:
                     realdecs.append(dec)
-            node.decorators = realdecs
+            if realdecs and isinstance(node, CFuncDefNode):
+                raise PostParseError(realdecs[0].pos, "Cdef functions cannot take arbitrary decorators.")
+            else:
+                node.decorators = realdecs
         
         if options:
             optdict = {}
@@ -486,6 +478,16 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
             return self.visit_with_options(body, optdict)
         else:
             return self.visit_Node(node)
+    
+    def visit_CVarDefNode(self, node):
+        if node.decorators:
+            for dec in node.decorators:
+                option = self.try_to_parse_option(dec.decorator)
+                if option is not None and option[0] == u'locals':
+                    node.directive_locals = option[1]
+                else:
+                    raise PostParseError(dec.pos, "Cdef functions can only take cython.locals() decorator.")
+        return node
                                    
     # Handle with statements
     def visit_WithStatNode(self, node):
@@ -686,8 +688,6 @@ property NAME:
                     lenv.declare_var(var, type, type_node.pos)
                 else:
                     error(type_node.pos, "Not a type")
-        for stat in node.pxd_locals:
-            stat.analyse_declarations(lenv)
         node.body.analyse_declarations(lenv)
         self.env_stack.append(lenv)
         self.visitchildren(node)
index 65cd6ca2c41893a8fdd46a0de64428bf7adf4a78..e6415fb50e6d98bfb8da46338596579930ffb248 100644 (file)
@@ -1479,6 +1479,7 @@ def p_IF_statement(s, ctx):
 
 def p_statement(s, ctx, first_statement = 0):
     cdef_flag = ctx.cdef_flag
+    decorators = []
     if s.sy == 'ctypedef':
         if ctx.level not in ('module', 'module_pxd'):
             s.error("ctypedef statement not allowed here")
@@ -1490,63 +1491,67 @@ def p_statement(s, ctx, first_statement = 0):
     elif s.sy == 'IF':
         return p_IF_statement(s, ctx)
     elif s.sy == 'DECORATOR':
-        if ctx.level not in ('module', 'class', 'c_class', 'property'):
+        if ctx.level not in ('module', 'class', 'c_class', 'property', 'module_pxd', 'class_pxd'):
             s.error('decorator not allowed here')
         s.level = ctx.level
         decorators = p_decorators(s)
-        if s.sy != 'def':
+        if s.sy not in ('def', 'cdef', 'cpdef'):
             s.error("Decorators can only be followed by functions ")
-        return p_def_statement(s, decorators)
+
+    overridable = 0
+    if s.sy == 'cdef':
+        cdef_flag = 1
+        s.next()
+    if s.sy == 'cpdef':
+        cdef_flag = 1
+        overridable = 1
+        s.next()
+    if cdef_flag:
+        if ctx.level not in ('module', 'module_pxd', 'function', 'c_class', 'c_class_pxd'):
+            s.error('cdef statement not allowed here')
+        s.level = ctx.level
+        node = p_cdef_statement(s, ctx(overridable = overridable))
+        if decorators is not None:
+            if not isinstance(node, (Nodes.CFuncDefNode, Nodes.CVarDefNode)):
+                s.error("Decorators can only be followed by functions ")
+            node.decorators = decorators
+        return node
     else:
-        overridable = 0
-        if s.sy == 'cdef':
-            cdef_flag = 1
-            s.next()
-        if s.sy == 'cpdef':
-            cdef_flag = 1
-            overridable = 1
-            s.next()
-        if cdef_flag:
-            if ctx.level not in ('module', 'module_pxd', 'function', 'c_class', 'c_class_pxd'):
-                s.error('cdef statement not allowed here')
+        if ctx.api:
+            error(s.pos, "'api' not allowed with this statement")
+        elif s.sy == 'def':
+            if ctx.level not in ('module', 'class', 'c_class', 'c_class_pxd', 'property'):
+                s.error('def statement not allowed here')
             s.level = ctx.level
-            return p_cdef_statement(s, ctx(overridable = overridable))
+            return p_def_statement(s, decorators)
+        elif s.sy == 'class':
+            if ctx.level != 'module':
+                s.error("class definition not allowed here")
+            return p_class_statement(s)
+        elif s.sy == 'include':
+            if ctx.level not in ('module', 'module_pxd'):
+                s.error("include statement not allowed here")
+            return p_include_statement(s, ctx)
+        elif ctx.level == 'c_class' and s.sy == 'IDENT' and s.systring == 'property':
+            return p_property_decl(s)
+        elif s.sy == 'pass' and ctx.level != 'property':
+            return p_pass_statement(s, with_newline = 1)
         else:
-            if ctx.api:
-                error(s.pos, "'api' not allowed with this statement")
-            elif s.sy == 'def':
-                if ctx.level not in ('module', 'class', 'c_class', 'c_class_pxd', 'property'):
-                    s.error('def statement not allowed here')
-                s.level = ctx.level
-                return p_def_statement(s)
-            elif s.sy == 'class':
-                if ctx.level != 'module':
-                    s.error("class definition not allowed here")
-                return p_class_statement(s)
-            elif s.sy == 'include':
-                if ctx.level not in ('module', 'module_pxd'):
-                    s.error("include statement not allowed here")
-                return p_include_statement(s, ctx)
-            elif ctx.level == 'c_class' and s.sy == 'IDENT' and s.systring == 'property':
-                return p_property_decl(s)
-            elif s.sy == 'pass' and ctx.level != 'property':
-                return p_pass_statement(s, with_newline = 1)
+            if ctx.level in ('c_class_pxd', 'property'):
+                s.error("Executable statement not allowed here")
+            if s.sy == 'if':
+                return p_if_statement(s)
+            elif s.sy == 'while':
+                return p_while_statement(s)
+            elif s.sy == 'for':
+                return p_for_statement(s)
+            elif s.sy == 'try':
+                return p_try_statement(s)
+            elif s.sy == 'with':
+                return p_with_statement(s)
             else:
-                if ctx.level in ('c_class_pxd', 'property'):
-                    s.error("Executable statement not allowed here")
-                if s.sy == 'if':
-                    return p_if_statement(s)
-                elif s.sy == 'while':
-                    return p_while_statement(s)
-                elif s.sy == 'for':
-                    return p_for_statement(s)
-                elif s.sy == 'try':
-                    return p_try_statement(s)
-                elif s.sy == 'with':
-                    return p_with_statement(s)
-                else:
-                    return p_simple_statement_list(
-                        s, ctx, first_statement = first_statement)
+                return p_simple_statement_list(
+                    s, ctx, first_statement = first_statement)
 
 def p_statement_list(s, ctx, first_statement = 0):
     # Parse a series of statements separated by newlines.