numpy and extension types for runtime cython
authorRobert Bradshaw <robertwb@math.washington.edu>
Sun, 31 Oct 2010 05:21:07 +0000 (22:21 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Sun, 31 Oct 2010 05:21:07 +0000 (22:21 -0700)
Cython/Build/Inline.py

index dfadec6a1c5a1f59783ecea9292907396da2e6ac..4de123e402d5fc2a7f5ee7116b1b1677adb463f2 100644 (file)
@@ -7,45 +7,69 @@ except ImportError:
     import md5 as hashlib
 
 from distutils.dist import Distribution
-from distutils.core import Extension
+from Cython.Distutils.extension import Extension
 from Cython.Distutils import build_ext
-    
+
+from Cython.Compiler.Main import Context, CompilationOptions, default_options
+
 code_cache = {}
 
-def get_type(arg):
+
+def get_type(arg, context=None):
     py_type = type(arg)
-    # TODO: numpy
     # TODO: extension types
     if py_type in [list, tuple, dict, str]:
         return py_type.__name__
     elif py_type is float:
         return 'double'
+    elif py_type is bool:
+        return 'bint'
     elif py_type is int:
         return 'long'
+    elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray):
+        return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim)
     else:
+        for base_type in py_type.mro():
+            if base_type.__module__ == '__builtin__':
+                return 'object'
+            module = context.find_module(base_type.__module__, need_pxd=False)
+            if module:
+                entry = module.lookup(base_type.__name__)
+                if entry.is_type:
+                    return '%s.%s' % (base_type.__module__, base_type.__name__)
         return 'object'
 
 # TODO: use locals/globals for unbound variables
-def cython_inline(code, types='aggressive', lib_dir=os.path.expanduser('~/.cython/inline'), **kwds):
+def cython_inline(code, types='aggressive', lib_dir=os.path.expanduser('~/.cython/inline'), include_dirs=['.'], **kwds):
+    ctx = Context(include_dirs, default_options)
     _, pyx_file = tempfile.mkstemp('.pyx')
     arg_names = kwds.keys()
     arg_names.sort()
-    arg_sigs = tuple((get_type(kwds[arg]), arg) for arg in arg_names)
+    arg_sigs = tuple((get_type(kwds[arg], ctx), arg) for arg in arg_names)
     key = code, arg_sigs
     module = code_cache.get(key)
     if not module:
+        cimports = ''
+        qualified = re.compile(r'([.\w]+)[.]')
+        for type, _ in arg_sigs:
+            m = qualified.match(type)
+            if m:
+                cimports += '\ncimport %s' % m.groups()[0]
         module_body, func_body = extract_func_code(code)
         params = ', '.join('%s %s' % a for a in arg_sigs)
         module_code = """
+%(cimports)s
 %(module_body)s
 def __invoke(%(params)s):
 %(func_body)s
         """ % locals()
+        print module_code
         open(pyx_file, 'w').write(module_code)
         module = "_" + hashlib.md5(code + str(arg_sigs)).hexdigest()
         extension = Extension(
             name = module,
-            sources=[pyx_file])
+            sources = [pyx_file],
+            pyrex_include_dirs = include_dirs)
         build_extension = build_ext(Distribution())
         build_extension.finalize_options()
         build_extension.extensions = [extension]