Buffer parsing complete; small transform factorizations and renaming of PostParse
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Tue, 1 Jul 2008 20:22:10 +0000 (22:22 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Tue, 1 Jul 2008 20:22:10 +0000 (22:22 +0200)
Cython/Compiler/Errors.py
Cython/Compiler/Main.py
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Parsing.py
Cython/Compiler/Tests/TestParseTreeTransforms.py
Cython/Compiler/Visitor.py
Cython/TestUtils.py

index 814146274d5031b6151c88cfaa427c5adbb488c5..bc105fafa97d33cf3ae218b457e0771e94e61d4e 100644 (file)
@@ -32,6 +32,7 @@ class CompileError(PyrexError):
     
     def __init__(self, position = None, message = ""):
         self.position = position
+        self.message_only = message
     # Deprecated and withdrawn in 2.6:
     #   self.message = message
         if position:
index 5c51e632bf47caf408c2cd64be78d246c7508818..e3fbcb313a0471b3176065a03ac7edcb0aa8e531 100644 (file)
@@ -334,17 +334,18 @@ def create_generate_code(context, options, result):
     return generate_code
 
 def create_default_pipeline(context, options, result):
-    from ParseTreeTransforms import WithTransform, PostParse
+    from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
     from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
     from ModuleNode import check_c_classes
     
     return [
         create_parse(context),
-        PostParse(),
-        WithTransform(),
-        AnalyseDeclarationsTransform(),
+        NormalizeTree(context),
+        PostParse(context),
+        WithTransform(context),
+        AnalyseDeclarationsTransform(context),
         check_c_classes,
-        AnalyseExpressionsTransform(),
+        AnalyseExpressionsTransform(context),
         create_generate_code(context, options, result)
     ]
 
index 75db65e55a3cd07b85e769770107f5a78dbdcbb5..698d57c692593adbdcbc93cf87d967ad1f6b433b 100644 (file)
@@ -565,7 +565,6 @@ class CBaseTypeNode(Node):
     
     pass
 
-
 class CSimpleBaseTypeNode(CBaseTypeNode):
     # name             string
     # module_path      [string]     Qualifying name components
index dbb30d246eee8cdbeb80a91072b44a4502fdbe28..b31642dd2137c5e5fd4e09179eeba64ed7ff750e 100644 (file)
@@ -1,10 +1,13 @@
-from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle
+from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
 from Cython.Compiler.Nodes import *
 from Cython.Compiler.ExprNodes import *
 from Cython.Compiler.TreeFragment import TreeFragment
+from Cython.Utils import EncodedString
+from Cython.Compiler.Errors import CompileError
+from sets import Set as set
 
 
-class PostParse(VisitorTransform):
+class NormalizeTree(CythonTransform):
     """
     This transform fixes up a few things after parsing
     in order to make the parse tree more suitable for
@@ -25,15 +28,11 @@ class PostParse(VisitorTransform):
     StatListNode has no children to see if the block is empty).
     """
 
-    def __init__(self):
-        super(PostParse, self).__init__()
+    def __init__(self, context):
+        super(NormalizeTree, self).__init__(context)
         self.is_in_statlist = False
         self.is_in_expr = False
 
-    def visit_Node(self, node):
-        self.visitchildren(node)
-        return node
-
     def visit_ExprNode(self, node):
         stacktmp = self.is_in_expr
         self.is_in_expr = True
@@ -73,7 +72,80 @@ class PostParse(VisitorTransform):
         return self.visit_StatNode(node, True)
 
 
-class WithTransform(VisitorTransform):
+class PostParseError(CompileError): pass
+
+# error strings checked by unit tests, so define them
+ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
+ERR_BUF_TOO_MANY = 'Too many buffer options'
+ERR_BUF_DUP = '"%s" buffer option already supplied'
+ERR_BUF_MISSING = '"%s" missing'
+ERR_BUF_INT = '"%s" must be an integer'
+ERR_BUF_NONNEG = '"%s" must be non-negative'
+
+class PostParse(CythonTransform):
+    """
+    Basic interpretation of the parse tree, as well as validity
+    checking that can be done on a very basic level on the parse
+    tree (while still not being a problem with the basic syntax,
+    as such).
+
+    Specifically:
+    - CBufferAccessTypeNode has its options interpreted:
+    Any first positional argument goes into the "dtype" attribute,
+    any "ndim" keyword argument goes into the "ndim" attribute and
+    so on. Also it is checked that the option combination is valid.
+
+    Note: Currently Parsing.py does a lot of interpretation and
+    reorganization that can be refactored into this transform
+    if a more pure Abstract Syntax Tree is wanted.
+    """
+
+    buffer_options = ("dtype", "ndim") # ordered!
+    def visit_CBufferAccessTypeNode(self, node):
+        options = {}
+        # Fetch positional arguments
+        if len(node.positional_args) > len(self.buffer_options):
+            self.context.error(ERR_BUF_TOO_MANY)
+        for arg, unicode_name in zip(node.positional_args, self.buffer_options):
+            name = str(unicode_name)
+            options[name] = arg
+        # Fetch named arguments
+        for item in node.keyword_args.key_value_pairs:
+            name = str(item.key.value)
+            if not name in self.buffer_options:
+                raise PostParseError(item.key.pos,
+                                     ERR_BUF_UNKNOWN % name)
+            if name in options.keys():
+                raise PostParseError(item.key.pos,
+                                     ERR_BUF_DUP % key)
+            options[name] = item.value
+
+        provided = options.keys()
+        # get dtype
+        dtype = options.get("dtype")
+        if dtype is None: raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype')
+        node.dtype = dtype
+
+        # get ndim
+        if "ndim" in provided:
+            ndimnode = options["ndim"]
+            if not isinstance(ndimnode, IntNode):
+                # Compile-time values (DEF) are currently resolved by the parser,
+                # so nothing more to do here
+                raise PostParseError(ndimnode.pos, ERR_BUF_INT % 'ndim')
+            ndim_value = int(ndimnode.value)
+            if ndim_value < 0:
+                raise PostParseError(ndimnode.pos, ERR_BUF_NONNEG % 'ndim')
+            node.ndim = int(ndimnode.value)
+        
+        # We're done with the parse tree args
+        node.positional_args = None
+        node.keyword_args = None
+        return node
+
+
+
+class WithTransform(CythonTransform):
 
     # EXCINFO is manually set to a variable that contains
     # the exc_info() tuple that can be generated by the enclosing except
@@ -94,7 +166,7 @@ class WithTransform(VisitorTransform):
             if EXC:
                 EXIT(None, None, None)
     """, temps=[u'MGR', u'EXC', u"EXIT", u"SYS)"],
-    pipeline=[PostParse()])
+    pipeline=[NormalizeTree(None)])
 
     template_with_target = TreeFragment(u"""
         MGR = EXPR
@@ -113,11 +185,7 @@ class WithTransform(VisitorTransform):
             if EXC:
                 EXIT(None, None, None)
     """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"],
-    pipeline=[PostParse()])
-
-    def visit_Node(self, node):
-       self.visitchildren(node)
-        return node
+    pipeline=[NormalizeTree(None)])
 
     def visit_WithStatNode(self, node):
         excinfo_name = temp_name_handle('EXCINFO')
@@ -143,7 +211,7 @@ class WithTransform(VisitorTransform):
         
         return result.stats
 
-class AnalyseDeclarationsTransform(VisitorTransform):
+class AnalyseDeclarationsTransform(CythonTransform):
 
     def __call__(self, root):
         self.env_stack = [root.scope]
@@ -164,12 +232,7 @@ class AnalyseDeclarationsTransform(VisitorTransform):
         self.env_stack.pop()
         return node
         
-    def visit_Node(self, node):
-        self.visitchildren(node)
-        return node
-
-
-class AnalyseExpressionsTransform(VisitorTransform):
+class AnalyseExpressionsTransform(CythonTransform):
 
     def visit_ModuleNode(self, node):
         node.body.analyse_expressions(node.scope)
@@ -181,7 +244,4 @@ class AnalyseExpressionsTransform(VisitorTransform):
         self.visitchildren(node)
         return node
         
-    def visit_Node(self, node):
-        self.visitchildren(node)
-        return node
 
index 2308553af6c98d480daac322b54428733bc3b175..de9c7810a745f25e589b22c4c6ab222ef5e003d3 100644 (file)
@@ -281,20 +281,18 @@ def p_trailer(s, node1):
         return ExprNodes.AttributeNode(pos, 
             obj = node1, attribute = name)
 
-def p_positional_and_keyword_callargs(s, end_sy_set):
-    """
-    Parses positional and keyword call arguments. end_sy_set
-    should contain any s.sy that terminate the argument chain
-    (this is ('*', '**', ')') for a normal function call,
-    and (']',) for buffers declarators).
+# arglist:  argument (',' argument)* [',']
+# argument: [test '='] test       # Really [keyword '='] test
 
-    Returns: (positional_args, keyword_args)
-    """
+def p_call(s, function):
+    # s.sy == '('
+    pos = s.position()
+    s.next()
     positional_args = []
     keyword_args = []
-    while s.sy not in end_sy_set:
-        if s.sy == '*' or s.sy == '**':
-            s.error('Argument expansion not allowed here.')
+    star_arg = None
+    starstar_arg = None
+    while s.sy not in ('*', '**', ')'):
         arg = p_simple_expr(s)
         if s.sy == '=':
             s.next()
@@ -314,20 +312,6 @@ def p_positional_and_keyword_callargs(s, end_sy_set):
         if s.sy != ',':
             break
         s.next()
-    return positional_args, keyword_args
-
-
-# arglist:  argument (',' argument)* [',']
-# argument: [test '='] test       # Really [keyword '='] test
-
-def p_call(s, function):
-    # s.sy == '('
-    pos = s.position()
-    s.next()
-    star_arg = None
-    starstar_arg = None
-    positional_args, keyword_args = (
-        p_positional_and_keyword_callargs(s,('*', '**', ')')))
 
     if s.sy == '*':
         s.next()
@@ -1473,6 +1457,71 @@ def p_suite(s, ctx = Ctx(), with_doc = 0, with_pseudo_doc = 0):
     else:
         return body
 
+def p_positional_and_keyword_args(s, end_sy_set, type_positions=(), type_keywords=()):
+    """
+    Parses positional and keyword arguments. end_sy_set
+    should contain any s.sy that terminate the argument list.
+    Argument expansion (* and **) are not allowed.
+
+    type_positions and type_keywords specifies which argument
+    positions and/or names which should be interpreted as
+    types. Other arguments will be treated as expressions.
+
+    Returns: (positional_args, keyword_args)
+    """
+    positional_args = []
+    keyword_args = []
+    pos_idx = 0
+
+    while s.sy not in end_sy_set:
+        if s.sy == '*' or s.sy == '**':
+            s.error('Argument expansion not allowed here.')
+
+        was_keyword = False
+        parsed_type = False
+        if s.sy == 'IDENT':
+            # Since we can have either types or expressions as positional args,
+            # we use a strategy of looking an extra step forward for a '=' and
+            # if it is a positional arg we backtrack.
+            ident = s.systring
+            s.next()
+            if s.sy == '=':
+                s.next()
+                # Is keyword arg
+                if ident in type_keywords:
+                    arg = p_c_base_type(s)
+                    parsed_type = True
+                else:
+                    arg = p_simple_expr(s)
+                keyword_node = ExprNodes.IdentifierStringNode(arg.pos,
+                                value = Utils.EncodedString(ident))
+                keyword_args.append((keyword_node, arg))
+                was_keyword = True
+            else:
+                s.put_back('IDENT', ident)
+                
+        if not was_keyword:
+            if pos_idx in type_positions:
+                arg = p_c_base_type(s)
+                parsed_type = True
+            else:
+                arg = p_simple_expr(s)
+            positional_args.append(arg)
+            pos_idx += 1
+            if len(keyword_args) > 0:
+                s.error("Non-keyword arg following keyword arg",
+                        pos = arg.pos)
+
+        if s.sy != ',':
+            if s.sy not in end_sy_set:
+                if parsed_type:
+                    s.error("Expected: type")
+                else:
+                    s.error("Expected: expression")
+            break
+        s.next()
+    return positional_args, keyword_args
+
 def p_c_base_type(s, self_flag = 0, nonempty = 0):
     # If self_flag is true, this is the base type for the
     # self argument of a C method of an extension type.
@@ -1556,24 +1605,32 @@ def p_c_simple_base_type(s, self_flag, nonempty):
     if s.sy == '[':
         if is_basic:
             p.error("Basic C types do not support buffer access")
-        s.next()
-        positional_args, keyword_args = (
-            p_positional_and_keyword_callargs(s, ('[]',)))
-        s.expect(']')
-
-        keyword_dict = ExprNodes.DictNode(pos,
-            key_value_pairs = [
-                ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)
-                for key, value in keyword_args
-            ])
-        
-        return Nodes.CBufferAccessTypeNode(pos,
-            positional_args = positional_args,
-            keyword_args = keyword_dict,
-            base_type_node = type_node)
+        return p_buffer_access(s, type_node)
     else:
         return type_node
 
+def p_buffer_access(s, type_node):
+    # s.sy == '['
+    pos = s.position()
+    s.next()
+    positional_args, keyword_args = (
+        p_positional_and_keyword_args(s, (']',), (0,), ('dtype',))
+    )
+    s.expect(']')
+
+    keyword_dict = ExprNodes.DictNode(pos,
+        key_value_pairs = [
+            ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)
+            for key, value in keyword_args
+        ])
+
+    result = Nodes.CBufferAccessTypeNode(pos,
+        positional_args = positional_args,
+        keyword_args = keyword_dict,
+        base_type_node = type_node)
+    return result
+    
+
 def looking_at_type(s):
     return looking_at_base_type(s) or s.looking_at_type_name()
 
index 3455dbc483e86330dbba6f20565c3f0b1515c15a..28accbbf219dbf41b0bacdf3a283d8c52cbeb93f 100644 (file)
@@ -2,7 +2,7 @@ from Cython.TestUtils import TransformTest
 from Cython.Compiler.ParseTreeTransforms import *
 from Cython.Compiler.Nodes import *
 
-class TestPostParse(TransformTest):
+class TestNormalizeTree(TransformTest):
     def test_parserbehaviour_is_what_we_coded_for(self):
         t = self.fragment(u"if x: y").root
         self.assertLines(u"""
@@ -15,7 +15,7 @@ class TestPostParse(TransformTest):
 """, self.treetypes(t))
         
     def test_wrap_singlestat(self):
-       t = self.run_pipeline([PostParse()], u"if x: y")
+       t = self.run_pipeline([NormalizeTree(None)], u"if x: y")
         self.assertLines(u"""
 (root): StatListNode
   stats[0]: IfStatNode
@@ -27,7 +27,7 @@ class TestPostParse(TransformTest):
 """, self.treetypes(t))
 
     def test_wrap_multistat(self):
-        t = self.run_pipeline([PostParse()], u"""
+        t = self.run_pipeline([NormalizeTree(None)], u"""
             if z:
                 x
                 y
@@ -45,7 +45,7 @@ class TestPostParse(TransformTest):
 """, self.treetypes(t))
 
     def test_statinexpr(self):
-        t = self.run_pipeline([PostParse()], u"""
+        t = self.run_pipeline([NormalizeTree(None)], u"""
             a, b = x, y
         """)
         self.assertLines(u"""
@@ -60,7 +60,7 @@ class TestPostParse(TransformTest):
 """, self.treetypes(t))
 
     def test_wrap_offagain(self):
-        t = self.run_pipeline([PostParse()], u"""
+        t = self.run_pipeline([NormalizeTree(None)], u"""
             x
             y
             if z:
@@ -82,13 +82,13 @@ class TestPostParse(TransformTest):
         
 
     def test_pass_eliminated(self):
-        t = self.run_pipeline([PostParse()], u"pass")
+        t = self.run_pipeline([NormalizeTree(None)], u"pass")
         self.assert_(len(t.stats) == 0)
 
 class TestWithTransform(TransformTest):
 
     def test_simplified(self):
-        t = self.run_pipeline([WithTransform()], u"""
+        t = self.run_pipeline([WithTransform(None)], u"""
         with x:
             y = z ** 3
         """)
@@ -113,7 +113,7 @@ class TestWithTransform(TransformTest):
         """, t)
 
     def test_basic(self):
-        t = self.run_pipeline([WithTransform()], u"""
+        t = self.run_pipeline([WithTransform(None)], u"""
         with x as y:
             y = z ** 3
         """)
index 0f9623d235a1c13921ef2cbc4c2a9e37fa53185e..31b6386b1cf0f1245a18236df3a4cab5cee2ad61 100644 (file)
@@ -131,7 +131,6 @@ class VisitorTransform(TreeVisitor):
     was not, an exception will be raised. (Typically you want to ensure that you
     are within a StatListNode or similar before doing this.)
     """
-    
     def visitchildren(self, parent, attrs=None):
         result = super(VisitorTransform, self).visitchildren(parent, attrs)
         for attr, newnode in result.iteritems():
@@ -152,6 +151,19 @@ class VisitorTransform(TreeVisitor):
     def __call__(self, root):
         return self.visit(root)
 
+class CythonTransform(VisitorTransform):
+    """
+    Certain common conventions and utilitues for Cython transforms.
+    """
+    def __init__(self, context):
+        super(CythonTransform, self).__init__()
+        self.context = context
+
+    def visit_Node(self, node):
+        self.visitchildren(node)
+        return node
+
+
 # Utils
 def ensure_statlist(node):
     if not isinstance(node, Nodes.StatListNode):
index 87f072f811466f1f66834b9614ebb0280d5b96b3..0588ee4acec4fc22e945a5387742b47d1d4f0a28 100644 (file)
@@ -50,12 +50,12 @@ class CythonTest(unittest.TestCase):
         self.assertEqual(len(result_lines), len(expected_lines),
             "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))
 
-    def fragment(self, code, pxds={}):
+    def fragment(self, code, pxds={}, pipeline=[]):
         "Simply create a tree fragment using the name of the test-case in parse errors."
         name = self.id()
         if name.startswith("__main__."): name = name[len("__main__."):]
         name = name.replace(".", "_")
-        return TreeFragment(code, name, pxds)
+        return TreeFragment(code, name, pxds, pipeline=pipeline)
 
     def treetypes(self, root):
         """Returns a string representing the tree by class names.
@@ -66,6 +66,26 @@ class CythonTest(unittest.TestCase):
         w.visit(root)
         return u"\n".join([u""] + w.result + [u""])
 
+    def should_fail(self, func, exc_type=Exception):
+        """Calls "func" and fails if it doesn't raise the right exception
+        (any exception by default). Also returns the exception in question.
+        """
+        try:
+            func()
+            self.fail("Expected an exception of type %r" % exc_type)
+        except exc_type, e:
+            self.assert_(isinstance(e, exc_type))
+            return e
+
+    def should_not_fail(self, func):
+        """Calls func and succeeds if and only if no exception is raised
+        (i.e. converts exception raising into a failed testcase). Returns
+        the return value of func."""
+        try:
+            return func()
+        except:
+            self.fail()
+
 class TransformTest(CythonTest):
     """
     Utility base class for transform unit tests. It is based around constructing