Parse distutils directives.
authorRobert Bradshaw <robertwb@math.washington.edu>
Sun, 12 Sep 2010 08:37:51 +0000 (01:37 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Sun, 12 Sep 2010 08:37:51 +0000 (01:37 -0700)
Cython/Compiler/Dependencies.py
tests/build/inline_distutils.srctree [new file with mode: 0644]

index 30f627d483b19f93b8fc64d67b0d7ad5e2723c47..4772bda2179f5a8ccd1377d471d7f6bbdf207992 100644 (file)
@@ -19,6 +19,90 @@ def cached_method(f):
         return res
     return wrapper
 
+
+def parse_list(s):
+    if s[0] == '[' and s[-1] == ']':
+        s = s[1:-1]
+        delimiter = ','
+    else:
+        delimiter = ' '
+    s, literals = strip_string_literals(s)
+    def unquote(literal):
+        literal = literal.strip()
+        if literal[0] == "'":
+            return literals[literal[1:-1]]
+        else:
+            return literal
+            
+    return [unquote(item) for item in s.split(delimiter)]
+
+transitive_str = object()
+transitive_list = object()
+
+distutils_settings = {
+    'name':                 str,
+    'sources':              list,
+    'define_macros':        list,
+    'undef_macros':         list,
+    'libraries':            transitive_list,
+    'library_dirs':         transitive_list,
+    'runtime_library_dirs': transitive_list,
+    'include_dirs':         transitive_list,
+    'extra_objects':        list,
+    'extra_compile_args':   list,
+    'extra_link_args':      list,
+    'export_symbols':       list,
+    'depends':              transitive_list,
+    'language':             transitive_str,
+}
+
+def line_iter(source):
+    start = 0
+    while True:
+        end = source.find('\n', start)
+        if end == -1:
+            yield source[start:]
+            return
+        yield source[start:end]
+        start = end+1
+
+class DistutilsInfo(object):
+    
+    def __init__(self, source):
+        self.values = {}
+        for line in line_iter(source):
+            line = line.strip()
+            if line != '' and line[0] != '#':
+                break
+            line = line[1:].strip()
+            if line[:10] == 'distutils:':
+                line = line[10:]
+                ix = line.index('=')
+                key = str(line[:ix].strip())
+                value = line[ix+1:].strip()
+                type = distutils_settings[key]
+                if type in (list, transitive_list):
+                    value = parse_list(value)
+                    if key == 'define_macros':
+                        value = [tuple(macro.split('=')) for macro in value]
+                self.values[key] = value
+    
+    def merge(self, other):
+        for key, value in other.values.items():
+            type = distutils_settings[key]
+            if type is transitive_str and key not in self.values:
+                self.values[key] = value
+            elif type is transitive_list:
+                if key in self.values:
+                    all = self.values[key]
+                    for v in value:
+                        if v not in all:
+                            all.append(v)
+                else:
+                    self.values[key] = value
+        return self
+
+
 def strip_string_literals(code, prefix='__Pyx_L'):
     """
     Normalizes every string literal to be of the form '__Pyx_Lxxx', 
@@ -85,6 +169,7 @@ def parse_dependencies(source_filename):
     # The only catch is that we must strip comments and string
     # literals ahead of time.
     source = Utils.open_source_file(source_filename, "rU").read()
+    distutils_info = DistutilsInfo(source)
     source = re.sub('#.*', '', source)
     source, literals = strip_string_literals(source)
     source = source.replace('\\\n', ' ')
@@ -105,7 +190,7 @@ def parse_dependencies(source_filename):
             includes.append(literals[groups[5]])
         else:
             externs.append(literals[groups[7]])
-    return cimports, includes, externs
+    return cimports, includes, externs, distutils_info
 
 
 class DependencyTree(object):
@@ -120,7 +205,7 @@ class DependencyTree(object):
     
     @cached_method
     def cimports_and_externs(self, filename):
-        cimports, includes, externs = self.parse_dependencies(filename)
+        cimports, includes, externs = self.parse_dependencies(filename)[:3]
         cimports = set(cimports)
         externs = set(externs)
         for include in includes:
@@ -149,7 +234,7 @@ class DependencyTree(object):
         if module[0] == '.':
             raise NotImplementedError, "New relative imports."
         if filename is not None:
-            relative = '.'.join(self.package(filename) + module.split('.'))
+            relative = '.'.join(self.package(filename) + tuple(module.split('.')))
             pxd = self.context.find_pxd_file(relative, None)
             if pxd:
                 return pxd
@@ -158,9 +243,9 @@ class DependencyTree(object):
     @cached_method
     def cimported_files(self, filename):
         if filename[-4:] == '.pyx' and os.path.exists(filename[:-4] + '.pxd'):
-            self_pxd = (filename[:-4] + '.pxd',)
+            self_pxd = [filename[:-4] + '.pxd']
         else:
-            self_pxd = ()
+            self_pxd = []
         a = self.cimports(filename)
         b = filter(None, [self.find_pxd(m, filename) for m in self.cimports(filename)])
         if len(a) != len(b):
@@ -186,6 +271,12 @@ class DependencyTree(object):
     def newest_dependency(self, filename):
         return self.transitive_merge(filename, self.extract_timestamp, max)
     
+    def distutils_info0(self, filename):
+        return self.parse_dependencies(filename)[3]
+    
+    def distutils_info(self, filename):
+        return self.transitive_merge(filename, self.distutils_info0, DistutilsInfo.merge)
+    
     def transitive_merge(self, node, extract, merge):
         try:
             seen = self._transitive_cache[extract, merge]
@@ -229,6 +320,8 @@ def create_dependency_tree(ctx=None):
         _dep_tree = DependencyTree(ctx)
     return _dep_tree
 
+# TODO: Take common options.
+# TODO: Symbolic names (e.g. for numpy.include_dirs()
 def create_extension_list(filepatterns, ctx=None):
     deps = create_dependency_tree(ctx)
     if isinstance(filepatterns, str):
@@ -238,7 +331,7 @@ def create_extension_list(filepatterns, ctx=None):
         for file in glob(pattern):
             pkg = deps.package(file)
             name = deps.fully_qualifeid_name(file)
-            module_list.append(Extension(name=name, sources=[file]))
+            module_list.append(Extension(name=name, sources=[file], **deps.distutils_info(file).values))
     return module_list
 
 def cythonize(module_list, ctx=None):
diff --git a/tests/build/inline_distutils.srctree b/tests/build/inline_distutils.srctree
new file mode 100644 (file)
index 0000000..436676d
--- /dev/null
@@ -0,0 +1,33 @@
+PYTHON setup.py build_ext --inplace
+PYTHON -c "import a"
+
+######## setup.py ########
+
+
+# TODO: Better interface...
+from Cython.Compiler.Dependencies import create_extension_list, cythonize
+
+from distutils.core import setup
+
+setup(
+  ext_modules = cythonize(create_extension_list("*.pyx")),
+)
+
+######## my_lib.pxd ########
+
+# distutils: language=c++
+
+cdef extern from "my_lib_helper.cpp" namespace "A":
+    int x
+
+######## my_lib_helper.cpp #######
+
+namespace A {
+    int x = 100;
+};
+
+######## a.pyx ########
+
+from my_lib cimport x
+
+print x