From 2d2fba93da93924700a0195a5b0c6fb7bb992fc2 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Sat, 30 Oct 2010 22:21:07 -0700 Subject: [PATCH] numpy and extension types for runtime cython --- Cython/Build/Inline.py | 38 +++++++++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/Cython/Build/Inline.py b/Cython/Build/Inline.py index dfadec6a..4de123e4 100644 --- a/Cython/Build/Inline.py +++ b/Cython/Build/Inline.py @@ -7,45 +7,69 @@ except ImportError: import md5 as hashlib from distutils.dist import Distribution -from distutils.core import Extension +from Cython.Distutils.extension import Extension from Cython.Distutils import build_ext - + +from Cython.Compiler.Main import Context, CompilationOptions, default_options + code_cache = {} -def get_type(arg): + +def get_type(arg, context=None): py_type = type(arg) - # TODO: numpy # TODO: extension types if py_type in [list, tuple, dict, str]: return py_type.__name__ elif py_type is float: return 'double' + elif py_type is bool: + return 'bint' elif py_type is int: return 'long' + elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray): + return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim) else: + for base_type in py_type.mro(): + if base_type.__module__ == '__builtin__': + return 'object' + module = context.find_module(base_type.__module__, need_pxd=False) + if module: + entry = module.lookup(base_type.__name__) + if entry.is_type: + return '%s.%s' % (base_type.__module__, base_type.__name__) return 'object' # TODO: use locals/globals for unbound variables -def cython_inline(code, types='aggressive', lib_dir=os.path.expanduser('~/.cython/inline'), **kwds): +def cython_inline(code, types='aggressive', lib_dir=os.path.expanduser('~/.cython/inline'), include_dirs=['.'], **kwds): + ctx = Context(include_dirs, default_options) _, pyx_file = tempfile.mkstemp('.pyx') arg_names = kwds.keys() arg_names.sort() - arg_sigs = tuple((get_type(kwds[arg]), arg) for arg in arg_names) + arg_sigs = tuple((get_type(kwds[arg], ctx), arg) for arg in arg_names) key = code, arg_sigs module = code_cache.get(key) if not module: + cimports = '' + qualified = re.compile(r'([.\w]+)[.]') + for type, _ in arg_sigs: + m = qualified.match(type) + if m: + cimports += '\ncimport %s' % m.groups()[0] module_body, func_body = extract_func_code(code) params = ', '.join('%s %s' % a for a in arg_sigs) module_code = """ +%(cimports)s %(module_body)s def __invoke(%(params)s): %(func_body)s """ % locals() + print module_code open(pyx_file, 'w').write(module_code) module = "_" + hashlib.md5(code + str(arg_sigs)).hexdigest() extension = Extension( name = module, - sources=[pyx_file]) + sources = [pyx_file], + pyrex_include_dirs = include_dirs) build_extension = build_ext(Distribution()) build_extension.finalize_options() build_extension.extensions = [extension] -- 2.26.2