From 880dee0a5701d0c9d6efc06b19297846ce9d2394 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Sat, 30 Oct 2010 23:55:13 -0700 Subject: [PATCH] Use unbound symbols from local/global scope. --- Cython/Build/Inline.py | 79 ++++++++++++++++++++++++++++----- Cython/Compiler/TreeFragment.py | 3 ++ 2 files changed, 71 insertions(+), 11 deletions(-) diff --git a/Cython/Build/Inline.py b/Cython/Build/Inline.py index 4de123e4..cc15a3a0 100644 --- a/Cython/Build/Inline.py +++ b/Cython/Build/Inline.py @@ -1,5 +1,7 @@ +print "Warning: Using prototype cython.inline code..." + import tempfile -import sys, os, re +import sys, os, re, inspect try: import hashlib @@ -12,12 +14,44 @@ from Cython.Distutils import build_ext from Cython.Compiler.Main import Context, CompilationOptions, default_options -code_cache = {} +from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform +from Cython.Compiler.TreeFragment import parse_from_strings + +_code_cache = {} + +class AllSymbols(CythonTransform, SkipDeclarations): + def __init__(self): + CythonTransform.__init__(self, None) + self.names = set() + def visit_NameNode(self, node): + self.names.add(node.name) + +def unbound_symbols(code, context=None): + if context is None: + context = Context([], default_options) + from Cython.Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform + if isinstance(code, str): + code = code.decode('ascii') + tree = parse_from_strings('(tree fragment)', code) + for phase in context.create_pipeline(pxd=False): + if phase is None: + continue + tree = phase(tree) + if isinstance(phase, AnalyseDeclarationsTransform): + break + symbol_collector = AllSymbols() + symbol_collector(tree) + unbound = [] + import __builtin__ + for name in symbol_collector.names: + if not tree.scope.lookup(name) and not hasattr(__builtin__, name): + unbound.append(name) + return unbound + def get_type(arg, context=None): py_type = type(arg) - # TODO: extension types if py_type in [list, tuple, dict, str]: return py_type.__name__ elif py_type is float: @@ -40,21 +74,43 @@ def get_type(arg, context=None): return 'object' # TODO: use locals/globals for unbound variables -def cython_inline(code, types='aggressive', lib_dir=os.path.expanduser('~/.cython/inline'), include_dirs=['.'], **kwds): +def cython_inline(code, + types='aggressive', + lib_dir=os.path.expanduser('~/.cython/inline'), + include_dirs=['.'], + locals=None, + globals=None, + **kwds): ctx = Context(include_dirs, default_options) - _, pyx_file = tempfile.mkstemp('.pyx') + if locals is None: + locals = inspect.currentframe().f_back.f_back.f_locals + if globals is None: + globals = inspect.currentframe().f_back.f_back.f_globals + try: + for symbol in unbound_symbols(code): + if symbol in kwds: + continue + elif symbol in locals: + kwds[symbol] = locals[symbol] + elif symbol in globals: + kwds[symbol] = globals[symbol] + else: + print "Couldn't find ", symbol + except AssertionError: + # Parsing from strings not fully supported (e.g. cimports). + print "Could not parse code as a string (to extract unbound symbols)." arg_names = kwds.keys() arg_names.sort() arg_sigs = tuple((get_type(kwds[arg], ctx), arg) for arg in arg_names) key = code, arg_sigs - module = code_cache.get(key) + module = _code_cache.get(key) if not module: - cimports = '' + cimports = [] qualified = re.compile(r'([.\w]+)[.]') for type, _ in arg_sigs: m = qualified.match(type) if m: - cimports += '\ncimport %s' % m.groups()[0] + cimports.append('\ncimport %s' % m.groups()[0]) module_body, func_body = extract_func_code(code) params = ', '.join('%s %s' % a for a in arg_sigs) module_code = """ @@ -62,8 +118,9 @@ def cython_inline(code, types='aggressive', lib_dir=os.path.expanduser('~/.cytho %(module_body)s def __invoke(%(params)s): %(func_body)s - """ % locals() - print module_code + """ % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body } +# print module_code + _, pyx_file = tempfile.mkstemp('.pyx') open(pyx_file, 'w').write(module_code) module = "_" + hashlib.md5(code + str(arg_sigs)).hexdigest() extension = Extension( @@ -78,7 +135,7 @@ def __invoke(%(params)s): sys.path.append(lib_dir) build_extension.build_lib = lib_dir build_extension.run() - code_cache[key] = module + _code_cache[key] = module arg_list = [kwds[arg] for arg in arg_names] return __import__(module).__invoke(*arg_list) diff --git a/Cython/Compiler/TreeFragment.py b/Cython/Compiler/TreeFragment.py index 66feaf09..13e0dc11 100644 --- a/Cython/Compiler/TreeFragment.py +++ b/Cython/Compiler/TreeFragment.py @@ -60,6 +60,7 @@ def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None): scope = scope, context = context, initial_pos = initial_pos) if level is None: tree = Parsing.p_module(scanner, 0, module_name) + tree.scope = scope else: tree = Parsing.p_code(scanner, level=level) return tree @@ -201,6 +202,8 @@ class TreeFragment(object): if not isinstance(t, StatListNode): t = StatListNode(pos=mod.pos, stats=[t]) for transform in pipeline: + if transform is None: + continue t = transform(t) self.root = t elif isinstance(code, Node): -- 2.26.2