Fix trac #505: problem with cimport cython
authorCraig Citro <craigcitro@gmail.com>
Tue, 9 Feb 2010 23:56:33 +0000 (15:56 -0800)
committerCraig Citro <craigcitro@gmail.com>
Tue, 9 Feb 2010 23:56:33 +0000 (15:56 -0800)
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py

index e727b3ba7cd0c906b16a73c0ea663a8c3534ab37..ed08e494264d035bd61e6ca44be5a82d592ef853 100644 (file)
@@ -4873,7 +4873,9 @@ class FromImportStatNode(StatNode):
                         break
             else:
                 entry =  env.lookup(target.name)
-                if entry.is_type and entry.type.name == name and entry.type.module_name == self.module.module_name.value:
+                if (entry.is_type and 
+                    entry.type.name == name and
+                    entry.type.module_name == self.module.module_name.value):
                     continue # already cimported
                 target.analyse_target_expression(env, None)
                 if target.type is py_object_type:
index e30117771fe75235a7f5f8a5d1f24b53b85793b5..3dd4847cbfb5a1d194b1b188a1a2f2998db43082 100644 (file)
@@ -337,8 +337,9 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
         'address': AmpersandNode,
     }
     
-    special_methods = set(['declare', 'union', 'struct', 'typedef', 'sizeof', 'cast', 'pointer', 'compiled', 'NULL']
-                            + unop_method_nodes.keys())
+    special_methods = set(['declare', 'union', 'struct', 'typedef', 'sizeof',
+                           'cast', 'pointer', 'compiled', 'NULL']
+                          + unop_method_nodes.keys())
 
     def __init__(self, context, compilation_directive_defaults):
         super(InterpretCompilerDirectives, self).__init__(context)
@@ -373,38 +374,33 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
         node.cython_module_names = self.cython_module_names
         return node
 
-    # Track cimports of the cython module.
+    # The following four functions track imports and cimports that
+    # begin with "cython"
+    def is_cython_directive(self, name):
+        return (name in Options.directive_types or
+                name in self.special_methods or
+                PyrexTypes.parse_basic_type(name))
+
     def visit_CImportStatNode(self, node):
         if node.module_name == u"cython":
-            if node.as_name:
-                modname = node.as_name
-            else:
-                modname = u"cython"
-            self.cython_module_names.add(modname)
+            self.cython_module_names.add(node.as_name or u"cython")
         elif node.module_name.startswith(u"cython."):
             if node.as_name:
                 self.directive_names[node.as_name] = node.module_name[7:]
             else:
                 self.cython_module_names.add(u"cython")
-        else:
-            return node
+            # if this cimport was a compiler directive, we don't
+            # want to leave the cimport node sitting in the tree
+            return None
+        return node
     
     def visit_FromCImportStatNode(self, node):
-        if node.module_name.startswith(u"cython."):
-            is_cython_module = True
-            submodule = node.module_name[7:] + u"."
-        elif node.module_name == u"cython":
-            is_cython_module = True
-            submodule = u""
-        else:
-            is_cython_module = False
-        if is_cython_module:
+        if node.module_name.startswith(u"cython"):
+            submodule = (node.module_name + u".")[7:]
             newimp = []
             for pos, name, as_name, kind in node.imported_names:
                 full_name = submodule + name
-                if (full_name in Options.directive_types or 
-                        full_name in self.special_methods or
-                        PyrexTypes.parse_basic_type(full_name)):
+                if self.is_cython_directive(full_name):
                     if as_name is None:
                         as_name = full_name
                     self.directive_names[as_name] = full_name
@@ -419,21 +415,12 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
         return node
         
     def visit_FromImportStatNode(self, node):
-        if node.module.module_name.value.startswith(u"cython."):
-            is_cython_module = True
-            submodule = node.module.module_name.value[7:] + u"."
-        elif node.module.module_name.value == u"cython":
-            is_cython_module = True
-            submodule = u""
-        else:
-            is_cython_module = False
-        if is_cython_module:
+        if node.module.module_name.value.startswith(u"cython"):
+            submodule = (node.module.module_name.value + u".")[7:]
             newimp = []
             for name, name_node in node.items:
                 full_name = submodule + name
-                if (full_name in Options.directive_types or 
-                        full_name in self.special_methods or
-                        PyrexTypes.parse_basic_type(full_name)):
+                if self.is_cython_directive(full_name):
                     self.directive_names[name_node.name] = full_name
                 else:
                     newimp.append((name, name_node))