decorator support (partly by Fabrizio Milo)
authorStefan Behnel <scoder@users.berlios.de>
Thu, 10 Jul 2008 21:45:33 +0000 (23:45 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Thu, 10 Jul 2008 21:45:33 +0000 (23:45 +0200)
Cython/CodeWriter.py
Cython/Compiler/Lexicon.py
Cython/Compiler/Main.py
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Parsing.py
Cython/Compiler/Tests/TestDecorators.py [new file with mode: 0644]
tests/run/decorators.pyx [new file with mode: 0644]

index bf9e3cc6e8190551e023cf44209d70b5d9637822..2d1928569cf8fa6271a159742e5e07c5a242a7de 100644 (file)
@@ -274,6 +274,16 @@ class CodeWriter(TreeVisitor):
         self.visit(node.body)
         self.dedent()
 
+    def visit_ReturnStatNode(self, node):
+        self.startline("return ")
+        self.visit(node.value)
+        self.endline()
+
+    def visit_DecoratorNode(self, node):
+        self.startline("@")
+        self.visit(node.decorator)
+        self.endline()
+
     def visit_ReraiseStatNode(self, node):
         self.line("raise")
 
index bfc7ea97e629919ad01f8ddab53faca6422cf230..e42bb0377d11365f9668f767464315935001ee11 100644 (file)
@@ -65,6 +65,7 @@ def make_lexicon():
     escapeseq = Str("\\") + (two_oct | three_oct | two_hex |
                              Str('u') + four_hex | Str('x') + two_hex | AnyChar)
     
+    deco = Str("@")
     bra = Any("([{")
     ket = Any(")]}")
     punct = Any(":,;+-*/|&<>=.%`~^?")
@@ -82,6 +83,7 @@ def make_lexicon():
         (longconst, 'LONG'),
         (fltconst, 'FLOAT'),
         (imagconst, 'IMAG'),
+        (deco, 'DECORATOR'),
         (punct | diphthong, TEXT),
         
         (bra, Method('open_bracket_action')),
index dd2243bacd3a6cf05e5b8f1810124c2d1df37bad..3cea71dd713d4d5f1a55b174cb9cdbff26948b41 100644 (file)
@@ -356,7 +356,7 @@ def create_generate_code(context, options, result):
 def create_default_pipeline(context, options, result):
     from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
     from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
-    from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor
+    from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
     from Buffer import BufferTransform
     from ModuleNode import check_c_classes
     
@@ -365,6 +365,7 @@ def create_default_pipeline(context, options, result):
         NormalizeTree(context),
         PostParse(context),
         WithTransform(context),
+        DecoratorTransform(context),
         AnalyseDeclarationsTransform(context),
         check_c_classes,
         AnalyseExpressionsTransform(context),
index 3a2768f85581159d99834c5ffbe31933cc1c6307..d389bd6c387cbcf41a7e261c6174f5a96b758c8f 100644 (file)
@@ -1235,13 +1235,19 @@ class PyArgDeclNode(Node):
     # entry  Symtab.Entry
     child_attrs = []
     
-    pass
-    
+
+class DecoratorNode(Node):
+    # A decorator
+    #
+    # decorator    NameNode or CallNode
+    child_attrs = ['decorator']
+
 
 class DefNode(FuncDefNode):
     # A Python function definition.
     #
     # name          string                 the Python name of the function
+    # decorators    [DecoratorNode]        list of decorators
     # args          [CArgDeclNode]         formal arguments
     # star_arg      PyArgDeclNode or None  * argument
     # starstar_arg  PyArgDeclNode or None  ** argument
@@ -1253,14 +1259,15 @@ class DefNode(FuncDefNode):
     #
     #  assmt   AssignmentNode   Function construction/assignment
     
-    child_attrs = ["args", "star_arg", "starstar_arg", "body"]
+    child_attrs = ["args", "star_arg", "starstar_arg", "body", "decorators"]
 
     assmt = None
     num_kwonly_args = 0
     num_required_kw_args = 0
     reqd_kw_flags_cname = "0"
     is_wrapper = 0
-    
+    decorators = None
+
     def __init__(self, pos, **kwds):
         FuncDefNode.__init__(self, pos, **kwds)
         k = rk = r = 0
index ed59d62f2eba49314b308acc96687e0837d62f77..1e5b2ee05f5b10fed4de48cc26cde876e00dd7e2 100644 (file)
@@ -217,6 +217,26 @@ class WithTransform(CythonTransform):
         
         return result.stats
 
+class DecoratorTransform(CythonTransform):
+
+    def visit_FuncDefNode(self, func_node):
+        if not func_node.decorators:
+            return func_node
+
+        decorator_result = NameNode(func_node.pos, name = func_node.name)
+        for decorator in func_node.decorators[::-1]:
+            decorator_result = SimpleCallNode(
+                decorator.pos,
+                function = decorator.decorator,
+                args = [decorator_result])
+
+        func_name_node = NameNode(func_node.pos, name = func_node.name)
+        reassignment = SingleAssignmentNode(
+            func_node.pos,
+            lhs = func_name_node,
+            rhs = decorator_result)
+        return [func_node, reassignment]
+
 class AnalyseDeclarationsTransform(CythonTransform):
 
     def __call__(self, root):
index 22b042202f292e73c106f66ca0c89e87e67eea39..8839009d2f6b7dee92333b089e5cedf30974cff5 100644 (file)
@@ -1372,6 +1372,14 @@ def p_statement(s, ctx, first_statement = 0):
         return p_DEF_statement(s)
     elif s.sy == 'IF':
         return p_IF_statement(s, ctx)
+    elif s.sy == 'DECORATOR':
+        if ctx.level not in ('module', 'class', 'c_class', 'property'):
+            s.error('decorator not allowed here')
+        s.level = ctx.level
+        decorators = p_decorators(s)
+        if s.sy != 'def':
+            s.error("Decorators can only be followed by functions ")
+        return p_def_statement(s, decorators)
     else:
         overridable = 0
         if s.sy == 'cdef':
@@ -2103,7 +2111,21 @@ def p_ctypedef_statement(s, ctx):
             declarator = declarator, visibility = visibility,
             in_pxd = ctx.level == 'module_pxd')
 
-def p_def_statement(s):
+def p_decorators(s):
+    decorators = []
+    while s.sy == 'DECORATOR':
+        pos = s.position()
+        s.next()
+        decorator = ExprNodes.NameNode(
+            pos, name = Utils.EncodedString(
+                p_dotted_name(s, as_allowed=0)[2] ))
+        if s.sy == '(':
+            decorator = p_call(s, decorator)
+        decorators.append(Nodes.DecoratorNode(pos, decorator=decorator))
+        s.expect_newline("Expected a newline after decorator")
+    return decorators
+
+def p_def_statement(s, decorators=None):
     # s.sy == 'def'
     pos = s.position()
     s.next()
@@ -2132,7 +2154,7 @@ def p_def_statement(s):
     doc, body = p_suite(s, Ctx(level = 'function'), with_doc = 1)
     return Nodes.DefNode(pos, name = name, args = args, 
         star_arg = star_arg, starstar_arg = starstar_arg,
-        doc = doc, body = body)
+        doc = doc, body = body, decorators = decorators)
 
 def p_py_arg_decl(s):
     pos = s.position()
diff --git a/Cython/Compiler/Tests/TestDecorators.py b/Cython/Compiler/Tests/TestDecorators.py
new file mode 100644 (file)
index 0000000..aabf674
--- /dev/null
@@ -0,0 +1,25 @@
+import unittest\r
+from Cython.TestUtils import TransformTest\r
+from Cython.Compiler.ParseTreeTransforms import DecoratorTransform\r
+\r
+class TestDecorator(TransformTest):\r
+\r
+    def test_decorator(self):\r
+        t = self.run_pipeline([DecoratorTransform(None)], u"""\r
+        def decorator(fun):\r
+            return fun\r
+        @decorator\r
+        def decorated():\r
+            pass\r
+        """)\r
+        \r
+        self.assertCode(u"""\r
+        def decorator(fun):\r
+            return fun\r
+        def decorated():\r
+            pass\r
+        decorated = decorator(decorated)\r
+        """, t)\r
+\r
+if __name__ == '__main__':\r
+    unittest.main()\r
diff --git a/tests/run/decorators.pyx b/tests/run/decorators.pyx
new file mode 100644 (file)
index 0000000..6a5eebe
--- /dev/null
@@ -0,0 +1,49 @@
+__doc__ = u"""
+  >>> f(1,2)
+  4
+  >>> f.HERE
+  1
+
+  >>> g(1,2)
+  5
+  >>> g.HERE
+  5
+
+  >>> h(1,2)
+  6
+  >>> h.HERE
+  1
+"""
+
+class wrap:
+    def __init__(self, func):
+        self.func = func
+        self.HERE = 1
+    def __call__(self, *args, **kwargs):
+        return self.func(*args, **kwargs)
+
+def decorate(func):
+    try:
+        func.HERE += 1
+    except AttributeError:
+        func = wrap(func)
+    return func
+
+def decorate2(a,b):
+    return decorate
+
+@decorate
+def f(a,b):
+    return a+b+1
+
+@decorate
+@decorate
+@decorate
+@decorate
+@decorate
+def g(a,b):
+    return a+b+2
+
+@decorate2(1,2)
+def h(a,b):
+    return a+b+3