Parsing.py parses [] buffer access; fixed a unit test; Node.dump implemented
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Tue, 1 Jul 2008 13:56:19 +0000 (15:56 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Tue, 1 Jul 2008 13:56:19 +0000 (15:56 +0200)
Cython/Compiler/Errors.py
Cython/Compiler/Nodes.py
Cython/Compiler/Parsing.py
Cython/Compiler/Tests/TestTreeFragment.py

index 42e139d3ad4ca0d2983f44e1d8d92bb85c5f2f44..814146274d5031b6151c88cfaa427c5adbb488c5 100644 (file)
@@ -91,6 +91,7 @@ def error(position, message):
     #print "Errors.error:", repr(position), repr(message) ###
     global num_errors
     err = CompileError(position, message)
+#    if position is not None: raise Exception(err) # debug
     line = "%s\n" % err
     if listing_file:
         listing_file.write(line)
index faf16dc96dbd447751277284fe8e883165b35944..75db65e55a3cd07b85e769770107f5a78dbdcbb5 100644 (file)
@@ -172,7 +172,27 @@ class Node(object):
             self._end_pos = pos
             return pos
 
-
+    def dump(self, level=0, filter_out=("pos",)):
+        def dump_child(x, level):
+            if isinstance(x, Node):
+                return x.dump(level)
+            elif isinstance(x, list):
+                return "[%s]" % ", ".join(dump_child(item, level) for item in x)
+            else:
+                return repr(x)
+            
+        
+        attrs = [(key, value) for key, value in self.__dict__.iteritems() if key not in filter_out]
+        if len(attrs) == 0:
+            return "<%s>" % self.__class__.__name__
+        else:
+            indent = "  " * level
+            res = "<%s\n" % (self.__class__.__name__)
+            for key, value in attrs:
+                res += "%s  %s: %s\n" % (indent, key, dump_child(value, level + 1))
+            res += "%s>" % indent
+            return res
+        
 class BlockNode:
     #  Mixin class for nodes representing a declaration block.
 
@@ -587,6 +607,16 @@ class CSimpleBaseTypeNode(CBaseTypeNode):
         else:
             return PyrexTypes.error_type
 
+class CBufferAccessTypeNode(Node):
+    #  base_type_node   CBaseTypeNode
+    #  positional_args  [ExprNode]        List of positional arguments
+    #  keyword_args     DictNode          Keyword arguments
+
+    child_attrs = ["base_type_node", "positional_args", "keyword_args"]
+    
+    def analyse(self, env):
+        
+        return self.base_type_node.analyse(env)
 
 class CComplexBaseTypeNode(CBaseTypeNode):
     # base_type   CBaseTypeNode
index 4d5de6cf04c9e5e99da2075664142e9547a663e1..2308553af6c98d480daac322b54428733bc3b175 100644 (file)
@@ -281,18 +281,20 @@ def p_trailer(s, node1):
         return ExprNodes.AttributeNode(pos, 
             obj = node1, attribute = name)
 
-# arglist:  argument (',' argument)* [',']
-# argument: [test '='] test       # Really [keyword '='] test
-
-def p_call(s, function):
-    # s.sy == '('
-    pos = s.position()
-    s.next()
+def p_positional_and_keyword_callargs(s, end_sy_set):
+    """
+    Parses positional and keyword call arguments. end_sy_set
+    should contain any s.sy that terminate the argument chain
+    (this is ('*', '**', ')') for a normal function call,
+    and (']',) for buffers declarators).
+
+    Returns: (positional_args, keyword_args)
+    """
     positional_args = []
     keyword_args = []
-    star_arg = None
-    starstar_arg = None
-    while s.sy not in ('*', '**', ')'):
+    while s.sy not in end_sy_set:
+        if s.sy == '*' or s.sy == '**':
+            s.error('Argument expansion not allowed here.')
         arg = p_simple_expr(s)
         if s.sy == '=':
             s.next()
@@ -312,6 +314,21 @@ def p_call(s, function):
         if s.sy != ',':
             break
         s.next()
+    return positional_args, keyword_args
+
+
+# arglist:  argument (',' argument)* [',']
+# argument: [test '='] test       # Really [keyword '='] test
+
+def p_call(s, function):
+    # s.sy == '('
+    pos = s.position()
+    s.next()
+    star_arg = None
+    starstar_arg = None
+    positional_args, keyword_args = (
+        p_positional_and_keyword_callargs(s,('*', '**', ')')))
+
     if s.sy == '*':
         s.next()
         star_arg = p_simple_expr(s)
@@ -1528,11 +1545,35 @@ def p_c_simple_base_type(s, self_flag, nonempty):
     else:
         #print "p_c_simple_base_type: not looking at type at", s.position()
         name = None
-    return Nodes.CSimpleBaseTypeNode(pos, 
+
+    type_node = Nodes.CSimpleBaseTypeNode(pos, 
         name = name, module_path = module_path,
         is_basic_c_type = is_basic, signed = signed,
         longness = longness, is_self_arg = self_flag)
 
+
+    # Treat trailing [] on type as buffer access
+    if s.sy == '[':
+        if is_basic:
+            p.error("Basic C types do not support buffer access")
+        s.next()
+        positional_args, keyword_args = (
+            p_positional_and_keyword_callargs(s, ('[]',)))
+        s.expect(']')
+
+        keyword_dict = ExprNodes.DictNode(pos,
+            key_value_pairs = [
+                ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)
+                for key, value in keyword_args
+            ])
+        
+        return Nodes.CBufferAccessTypeNode(pos,
+            positional_args = positional_args,
+            keyword_args = keyword_dict,
+            base_type_node = type_node)
+    else:
+        return type_node
+
 def looking_at_type(s):
     return looking_at_base_type(s) or s.looking_at_type_name()
 
index ef953826f4dbd8f87c0e9088ef999dc0b57cc51b..2214cd14bde3898032a00b01cf4eea4aea6b4fd2 100644 (file)
@@ -4,6 +4,7 @@ from Cython.Compiler.Nodes import *
 import Cython.Compiler.Naming as Naming
 
 class TestTreeFragments(CythonTest):
+    
     def test_basic(self):
         F = self.fragment(u"x = 4")
         T = F.copy()
@@ -46,13 +47,14 @@ class TestTreeFragments(CythonTest):
         self.assertEquals(v.pos, a.pos)
         
     def test_temps(self):
+        import Cython.Compiler.Visitor as v
+        v.tmpnamectr = 0
         F = self.fragment(u"""
             TMP
             x = TMP
         """)
         T = F.substitute(temps=[u"TMP"])
         s = T.stats
-        print s[0].expr.name
         self.assert_(s[0].expr.name == Naming.temp_prefix +  u"1_TMP", s[0].expr.name)
         self.assert_(s[1].rhs.name == Naming.temp_prefix + u"1_TMP")
         self.assert_(s[0].expr.name !=  u"TMP")