Simplify and enhance struct/union wrapping.
authorRobert Bradshaw <robertwb@math.washington.edu>
Sun, 27 Feb 2011 07:55:14 +0000 (23:55 -0800)
committerRobert Bradshaw <robertwb@math.washington.edu>
Sun, 27 Feb 2011 07:55:14 +0000 (23:55 -0800)
Cython/Compiler/ParseTreeTransforms.py

index b57d374f2eaf10e1b52f22aa53d6a1b1ca717f51..cc367a795b40db0f0cdfea74da76b9b5c185267d 100644 (file)
@@ -1018,16 +1018,27 @@ property NAME:
         return ATTR
     """, level='c_class')
 
-    repr_tree = TreeFragment(u"""
-def NAME(self):
-    return FORMAT % ATTRS
-    """, level='c_class')
+    struct_or_union_wrapper = TreeFragment(u"""
+cdef class NAME:
+    cdef TYPE value
+    def __init__(self, MEMBER=None):
+        cdef int count
+        count = 0
+        INIT_ASSIGNMENTS
+        if IS_UNION and count > 1:
+            raise ValueError, "At most one union member should be specified."
+    def __str__(self):
+        return STR_FORMAT % MEMBER_TUPLE
+    def __repr__(self):
+        return REPR_FORMAT % MEMBER_TUPLE
+    """)
+
+    init_assignment = TreeFragment(u"""
+if VALUE is not None:
+    ATTR = VALUE
+    count += 1
+    """)
 
-    init_assign = TreeFragment(u"""
-if VAR is not None:
-    ATTR = VAR
-    """, level='c_class')
-    
     def __call__(self, root):
         self.env_stack = [root.scope]
         # needed to determine if a cdef var is declared after it's used.
@@ -1111,22 +1122,15 @@ if VAR is not None:
         return node
 
     def visit_CStructOrUnionDefNode(self, node):
-        # Create a shadow node if needed.
+        # Create a wrapper node if needed.
         # We want to use the struct type information (so it can't happen
         # before this phase) but also create new objects to be declared
         # (so it can't happen later).
-        # Note that we don't need to return the original node, as it is
+        # Note that we don't return the original node, as it is
         # never used after this phase.
         if True: # private (default)
             return None
-        # cdef struct_type value
-        class_body = [Nodes.CVarDefNode(
-            node.pos,
-            base_type = Nodes.CSimpleBaseTypeNode(node.pos, name=node.name),
-            declarators = [Nodes.CNameDeclaratorNode(node.pos, name='value', cname='value')],
-            visibility = 'private',
-            in_pxd = False)]
-        # setters/getters
+
         self_value = ExprNodes.AttributeNode(
             pos = node.pos,
             obj = ExprNodes.NameNode(pos=node.pos, name=u"self"),
@@ -1137,6 +1141,45 @@ if VAR is not None:
             attributes.append(ExprNodes.AttributeNode(pos = entry.pos,
                                                       obj = self_value,
                                                       attribute = entry.name))
+        # __init__ assignments
+        init_assignments = []
+        for entry, attr in zip(var_entries, attributes):
+            # TODO: branch on visibility
+            init_assignments.append(self.init_assignment.substitute({
+                    u"VALUE": ExprNodes.NameNode(entry.pos, name = entry.name),
+                    u"ATTR": attr,
+                }, pos = entry.pos))
+
+        # create the class
+        str_format = u"%s(%s)" % (node.entry.type.name, ("%s, " * len(attributes))[:-2])
+        wrapper_class = self.struct_or_union_wrapper.substitute({
+            u"INIT_ASSIGNMENTS": Nodes.StatListNode(node.pos, stats = init_assignments),
+            u"IS_UNION": ExprNodes.BoolNode(node.pos, value = not node.entry.type.is_struct),
+            u"MEMBER_TUPLE": ExprNodes.TupleNode(node.pos, args=attributes),
+            u"STR_FORMAT": ExprNodes.StringNode(node.pos, value = EncodedString(str_format)),
+            u"REPR_FORMAT": ExprNodes.StringNode(node.pos, value = EncodedString(str_format.replace("%s", "%r"))),
+        }, pos = node.pos).stats[0]
+        wrapper_class.class_name = node.name
+        wrapper_class.shadow = True
+        class_body = wrapper_class.body.stats
+
+        # fix value type
+        assert isinstance(class_body[0].base_type, Nodes.CSimpleBaseTypeNode)
+        class_body[0].base_type.name = node.name
+
+        # fix __init__ arguments
+        init_method = class_body[1]
+        assert isinstance(init_method, Nodes.DefNode) and init_method.name == '__init__'
+        arg_template = init_method.args[1]
+        if not node.entry.type.is_struct:
+            arg_template.kw_only = True
+        del init_method.args[1]
+        for entry, attr in zip(var_entries, attributes):
+            arg = copy.deepcopy(arg_template)
+            arg.declarator.name = entry.name
+            init_method.args.append(arg)
+            
+        # setters/getters
         for entry, attr in zip(var_entries, attributes):
             # TODO: branch on visibility
             if entry.type.is_pyobject:
@@ -1147,78 +1190,10 @@ if VAR is not None:
                     u"ATTR": attr,
                 }, pos = entry.pos).stats[0]
             property.name = entry.name
-            class_body.append(property)
-        
-        # __init__
-        self_base_type = Nodes.CSimpleBaseTypeNode(node.pos,
-            name = None, module_path = [],
-            is_basic_c_type = 0, signed = 0,
-            complex = 0, longness = 0,
-            is_self_arg = 1, templates = None)
-        init_args = [Nodes.CArgDeclNode(
-            node.pos,
-            base_type = self_base_type,
-            declarator = Nodes.CNameDeclaratorNode(node.pos, name = u"self", cname = None),
-            default = None)]
-        empty_base_type = Nodes.CSimpleBaseTypeNode(node.pos,
-            name = None, module_path = [],
-            is_basic_c_type = 0, signed = 0,
-            complex = 0, longness = 0,
-            is_self_arg = 0, templates = None)
-        init_body = []
-        for entry, attr in zip(var_entries, attributes):
-            # TODO: branch on visibility
-            init_args.append(Nodes.CArgDeclNode(
-                entry.pos,
-                base_type = empty_base_type,
-                declarator = Nodes.CNameDeclaratorNode(entry.pos, name = entry.name, cname = None),
-                default = ExprNodes.NoneNode(entry.pos)))
-            init_body.append(self.init_assign.substitute({
-                    u"VAR": ExprNodes.NameNode(entry.pos, name = entry.name),
-                    u"ATTR": attr,
-                }, pos = entry.pos))
-        init_method = Nodes.DefNode(
-            pos = node.pos,
-            args = init_args,
-            name = u"__init__",
-            decorators = [],
-            body = Nodes.StatListNode(node.pos, stats=init_body))
-        class_body.append(init_method)
-
-        # __str__
-        attr_tuple = ExprNodes.TupleNode(node.pos, args=attributes)
-        format = u"%s(%s)" % (node.entry.type.name, ("%s, " * len(attributes))[:-2])
-        repr_method = self.repr_tree.substitute({
-                u"FORMAT": ExprNodes.StringNode(node.pos, value = EncodedString(format)),
-                u"ATTRS": attr_tuple,
-            }, pos = node.pos).stats[0]
-        repr_method.name = "__str__"
-        class_body.append(repr_method)
-
-        # __repr__
-        format = u"%s(%s)" % (node.entry.type.name, ("%r, " * len(attributes))[:-2])
-        repr_method = self.repr_tree.substitute({
-                u"FORMAT": ExprNodes.StringNode(node.pos, value = EncodedString(format)),
-                u"ATTRS": attr_tuple,
-            }, pos = node.pos).stats[0]
-        repr_method.name = "__repr__"
-        class_body.append(repr_method)
-
-        # Now create the class.
-        shadow = Nodes.CClassDefNode(node.pos,
-            visibility = 'public',
-            module_name = None,
-            class_name = node.name,
-            base_class_module = None,
-            base_class_name = None,
-            decorators = None,
-            body = Nodes.StatListNode(node.pos, stats=class_body),
-            in_pxd = False,
-            doc = None,
-            shadow = True)
-
-        shadow.analyse_declarations(self.env_stack[-1])
-        return self.visit_CClassDefNode(shadow)
+            wrapper_class.body.stats.append(property)
+            
+        wrapper_class.analyse_declarations(self.env_stack[-1])
+        return self.visit_CClassDefNode(wrapper_class)
 
     # Some nodes are no longer needed after declaration
     # analysis and can be dropped. The analysis was performed