Use unbound symbols from local/global scope.
authorRobert Bradshaw <robertwb@math.washington.edu>
Sun, 31 Oct 2010 06:55:13 +0000 (23:55 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Sun, 31 Oct 2010 06:55:13 +0000 (23:55 -0700)
Cython/Build/Inline.py
Cython/Compiler/TreeFragment.py

index 4de123e402d5fc2a7f5ee7116b1b1677adb463f2..cc15a3a00b8e4b73e463c153b9b3b9268cd920b6 100644 (file)
@@ -1,5 +1,7 @@
+print "Warning: Using prototype cython.inline code..."
+
 import tempfile
-import sys, os, re
+import sys, os, re, inspect
 
 try:
     import hashlib
@@ -12,12 +14,44 @@ from Cython.Distutils import build_ext
 
 from Cython.Compiler.Main import Context, CompilationOptions, default_options
 
-code_cache = {}
+from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform
+from Cython.Compiler.TreeFragment import parse_from_strings
+
+_code_cache = {}
+
 
+class AllSymbols(CythonTransform, SkipDeclarations):
+    def __init__(self):
+        CythonTransform.__init__(self, None)
+        self.names = set()
+    def visit_NameNode(self, node):
+        self.names.add(node.name)
+
+def unbound_symbols(code, context=None):
+    if context is None:
+        context = Context([], default_options)
+    from Cython.Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform
+    if isinstance(code, str):
+        code = code.decode('ascii')
+    tree = parse_from_strings('(tree fragment)', code)
+    for phase in context.create_pipeline(pxd=False):
+        if phase is None:
+            continue
+        tree = phase(tree)
+        if isinstance(phase, AnalyseDeclarationsTransform):
+            break
+    symbol_collector = AllSymbols()
+    symbol_collector(tree)
+    unbound = []
+    import __builtin__
+    for name in symbol_collector.names:
+        if not tree.scope.lookup(name) and not hasattr(__builtin__, name):
+            unbound.append(name)
+    return unbound
+        
 
 def get_type(arg, context=None):
     py_type = type(arg)
-    # TODO: extension types
     if py_type in [list, tuple, dict, str]:
         return py_type.__name__
     elif py_type is float:
@@ -40,21 +74,43 @@ def get_type(arg, context=None):
         return 'object'
 
 # TODO: use locals/globals for unbound variables
-def cython_inline(code, types='aggressive', lib_dir=os.path.expanduser('~/.cython/inline'), include_dirs=['.'], **kwds):
+def cython_inline(code, 
+                  types='aggressive',
+                  lib_dir=os.path.expanduser('~/.cython/inline'),
+                  include_dirs=['.'],
+                  locals=None,
+                  globals=None,
+                  **kwds):
     ctx = Context(include_dirs, default_options)
-    _, pyx_file = tempfile.mkstemp('.pyx')
+    if locals is None:
+        locals = inspect.currentframe().f_back.f_back.f_locals
+    if globals is None:
+        globals = inspect.currentframe().f_back.f_back.f_globals
+    try:
+        for symbol in unbound_symbols(code):
+            if symbol in kwds:
+                continue
+            elif symbol in locals:
+                kwds[symbol] = locals[symbol]
+            elif symbol in globals:
+                kwds[symbol] = globals[symbol]
+            else:
+                print "Couldn't find ", symbol
+    except AssertionError:
+        # Parsing from strings not fully supported (e.g. cimports).
+        print "Could not parse code as a string (to extract unbound symbols)."
     arg_names = kwds.keys()
     arg_names.sort()
     arg_sigs = tuple((get_type(kwds[arg], ctx), arg) for arg in arg_names)
     key = code, arg_sigs
-    module = code_cache.get(key)
+    module = _code_cache.get(key)
     if not module:
-        cimports = ''
+        cimports = []
         qualified = re.compile(r'([.\w]+)[.]')
         for type, _ in arg_sigs:
             m = qualified.match(type)
             if m:
-                cimports += '\ncimport %s' % m.groups()[0]
+                cimports.append('\ncimport %s' % m.groups()[0])
         module_body, func_body = extract_func_code(code)
         params = ', '.join('%s %s' % a for a in arg_sigs)
         module_code = """
@@ -62,8 +118,9 @@ def cython_inline(code, types='aggressive', lib_dir=os.path.expanduser('~/.cytho
 %(module_body)s
 def __invoke(%(params)s):
 %(func_body)s
-        """ % locals()
-        print module_code
+        """ % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body }
+#        print module_code
+        _, pyx_file = tempfile.mkstemp('.pyx')
         open(pyx_file, 'w').write(module_code)
         module = "_" + hashlib.md5(code + str(arg_sigs)).hexdigest()
         extension = Extension(
@@ -78,7 +135,7 @@ def __invoke(%(params)s):
             sys.path.append(lib_dir)
         build_extension.build_lib  = lib_dir
         build_extension.run()
-        code_cache[key] = module
+        _code_cache[key] = module
     arg_list = [kwds[arg] for arg in arg_names]
     return __import__(module).__invoke(*arg_list)
 
index 66feaf09e611a34aebd241afba0280bdcda8bd99..13e0dc111542942a5e5f8418785dc242d6375c91 100644 (file)
@@ -60,6 +60,7 @@ def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None):
                      scope = scope, context = context, initial_pos = initial_pos)
     if level is None:
         tree = Parsing.p_module(scanner, 0, module_name)
+        tree.scope = scope
     else:
         tree = Parsing.p_code(scanner, level=level)
     return tree
@@ -201,6 +202,8 @@ class TreeFragment(object):
             if not isinstance(t, StatListNode):
                 t = StatListNode(pos=mod.pos, stats=[t])
             for transform in pipeline:
+                if transform is None:
+                    continue
                 t = transform(t)
             self.root = t
         elif isinstance(code, Node):