+print "Warning: Using prototype cython.inline code..."
+
import tempfile
-import sys, os, re
+import sys, os, re, inspect
try:
import hashlib
from Cython.Compiler.Main import Context, CompilationOptions, default_options
-code_cache = {}
+from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform
+from Cython.Compiler.TreeFragment import parse_from_strings
+
+_code_cache = {}
+
+class AllSymbols(CythonTransform, SkipDeclarations):
+ def __init__(self):
+ CythonTransform.__init__(self, None)
+ self.names = set()
+ def visit_NameNode(self, node):
+ self.names.add(node.name)
+
+def unbound_symbols(code, context=None):
+ if context is None:
+ context = Context([], default_options)
+ from Cython.Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform
+ if isinstance(code, str):
+ code = code.decode('ascii')
+ tree = parse_from_strings('(tree fragment)', code)
+ for phase in context.create_pipeline(pxd=False):
+ if phase is None:
+ continue
+ tree = phase(tree)
+ if isinstance(phase, AnalyseDeclarationsTransform):
+ break
+ symbol_collector = AllSymbols()
+ symbol_collector(tree)
+ unbound = []
+ import __builtin__
+ for name in symbol_collector.names:
+ if not tree.scope.lookup(name) and not hasattr(__builtin__, name):
+ unbound.append(name)
+ return unbound
+
def get_type(arg, context=None):
py_type = type(arg)
- # TODO: extension types
if py_type in [list, tuple, dict, str]:
return py_type.__name__
elif py_type is float:
return 'object'
# TODO: use locals/globals for unbound variables
-def cython_inline(code, types='aggressive', lib_dir=os.path.expanduser('~/.cython/inline'), include_dirs=['.'], **kwds):
+def cython_inline(code,
+ types='aggressive',
+ lib_dir=os.path.expanduser('~/.cython/inline'),
+ include_dirs=['.'],
+ locals=None,
+ globals=None,
+ **kwds):
ctx = Context(include_dirs, default_options)
- _, pyx_file = tempfile.mkstemp('.pyx')
+ if locals is None:
+ locals = inspect.currentframe().f_back.f_back.f_locals
+ if globals is None:
+ globals = inspect.currentframe().f_back.f_back.f_globals
+ try:
+ for symbol in unbound_symbols(code):
+ if symbol in kwds:
+ continue
+ elif symbol in locals:
+ kwds[symbol] = locals[symbol]
+ elif symbol in globals:
+ kwds[symbol] = globals[symbol]
+ else:
+ print "Couldn't find ", symbol
+ except AssertionError:
+ # Parsing from strings not fully supported (e.g. cimports).
+ print "Could not parse code as a string (to extract unbound symbols)."
arg_names = kwds.keys()
arg_names.sort()
arg_sigs = tuple((get_type(kwds[arg], ctx), arg) for arg in arg_names)
key = code, arg_sigs
- module = code_cache.get(key)
+ module = _code_cache.get(key)
if not module:
- cimports = ''
+ cimports = []
qualified = re.compile(r'([.\w]+)[.]')
for type, _ in arg_sigs:
m = qualified.match(type)
if m:
- cimports += '\ncimport %s' % m.groups()[0]
+ cimports.append('\ncimport %s' % m.groups()[0])
module_body, func_body = extract_func_code(code)
params = ', '.join('%s %s' % a for a in arg_sigs)
module_code = """
%(module_body)s
def __invoke(%(params)s):
%(func_body)s
- """ % locals()
- print module_code
+ """ % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body }
+# print module_code
+ _, pyx_file = tempfile.mkstemp('.pyx')
open(pyx_file, 'w').write(module_code)
module = "_" + hashlib.md5(code + str(arg_sigs)).hexdigest()
extension = Extension(
sys.path.append(lib_dir)
build_extension.build_lib = lib_dir
build_extension.run()
- code_cache[key] = module
+ _code_cache[key] = module
arg_list = [kwds[arg] for arg in arg_names]
return __import__(module).__invoke(*arg_list)