2 import sys, os, re, inspect
10 from distutils.core import Distribution, Extension
11 from distutils.command.build_ext import build_ext
14 from Cython.Compiler.Main import Context, CompilationOptions, default_options
16 from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform
17 from Cython.Compiler.TreeFragment import parse_from_strings
18 from Cython.Build.Dependencies import strip_string_literals, cythonize
# A utility function to convert user-supplied ASCII strings to unicode.
if sys.version_info[0] < 3:
    def to_unicode(s):
        """Return *s* as a ``unicode`` string, decoding byte strings as ASCII."""
        if isinstance(s, unicode):
            return s
        return s.decode('ascii')
else:
    def to_unicode(s):
        """Return *s* as a ``str``, decoding ``bytes`` as ASCII.

        Text strings are returned unchanged, so existing callers see no
        difference; ``bytes`` input is decoded instead of leaking through.
        """
        if isinstance(s, bytes):
            return s.decode('ascii')
        return s
class AllSymbols(CythonTransform, SkipDeclarations):
    """Collect the name of every ``NameNode`` in a Cython parse tree.

    Call an instance on a tree; afterwards ``self.names`` is the set of
    referenced names.  Mixes in ``SkipDeclarations`` (presumably so that
    declaration subtrees are not walked -- confirm against Cython's
    ParseTreeTransforms).
    """

    def __init__(self):
        # Passed None as the compilation context: this visitor only records names.
        CythonTransform.__init__(self, None)
        self.names = set()

    def visit_NameNode(self, node):
        self.names.add(node.name)
        # CythonTransform is a tree *transform*: the visit method's return value
        # presumably replaces the visited node (None meaning "remove it"), so
        # hand the node back to leave the tree intact while collecting.
        return node
def unbound_symbols(code, context=None):
    """Return the names used in *code* that it neither declares nor gets from builtins.

    *code* is passed through ``to_unicode``, parsed as a tree fragment and run
    through the compilation pipeline up to and including
    ``AnalyseDeclarationsTransform`` so that ``tree.scope`` holds the
    fragment's own declarations.  Every ``NameNode`` name that the scope
    cannot resolve and that is not a builtin is reported.

    :param code: Cython source text.
    :param context: compilation ``Context``; a default one is built when None.
    :returns: list of unbound names (order follows set iteration).
    """
    if sys.version_info[0] < 3:
        import __builtin__ as builtins
    else:
        import builtins
    code = to_unicode(code)
    if context is None:
        context = Context([], default_options)
    tree = parse_from_strings('(tree fragment)', code)
    for phase in context.create_pipeline(pxd=False):
        if phase is None:
            continue
        tree = phase(tree)
        # Declarations are now analysed, which is all the scope lookup below needs.
        if isinstance(phase, AnalyseDeclarationsTransform):
            break
    symbol_collector = AllSymbols()
    symbol_collector(tree)
    unbound = []
    for name in symbol_collector.names:
        if not tree.scope.lookup(name) and not hasattr(builtins, name):
            unbound.append(name)
    return unbound
def unsafe_type(arg, context=None):
    """Return a Cython type name for *arg*, typing exact ``int`` values as ``long``.

    Everything that is not exactly an ``int`` is delegated to ``safe_type``.
    NOTE(review): presumably "unsafe" because a C ``long`` cannot hold every
    Python integer -- confirm.
    """
    if type(arg) is not int:
        return safe_type(arg, context)
    return 'long'
def safe_type(arg, context=None):
    """Return the Cython type declaration to use for the value *arg*.

    - ``list``/``tuple``/``dict``/``str`` -> their type name
    - ``complex`` -> ``'double complex'``, ``float`` -> ``'double'``,
      ``bool`` -> ``'bint'``
    - a NumPy ``ndarray`` (only if numpy is already imported) -> a typed
      ``numpy.ndarray[...]`` buffer declaration
    - otherwise the first class in the type's MRO that *context* resolves to
      a type entry, as ``'module.Class'``; ``'object'`` if a builtin base is
      reached first, nothing resolves, or no *context* is given.
    """
    py_type = type(arg)
    if py_type in [list, tuple, dict, str]:
        return py_type.__name__
    elif py_type is complex:
        return 'double complex'
    elif py_type is float:
        return 'double'
    elif py_type is bool:
        return 'bint'
    elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray):
        return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim)
    if context is None:
        # No compilation context to look cdef classes up in.
        return 'object'
    for base_type in py_type.mro():
        # '__builtin__' on Python 2, 'builtins' on Python 3.
        if base_type.__module__ in ('__builtin__', 'builtins'):
            return 'object'
        module = context.find_module(base_type.__module__, need_pxd=False)
        if module:
            entry = module.lookup(base_type.__name__)
            if entry is not None and entry.is_type:
                return '%s.%s' % (base_type.__module__, base_type.__name__)
    return 'object'
# Maps each compiled signature key to the name of its generated extension module.
_code_cache = {}

def cython_inline(code,
                  get_type=unsafe_type,
                  lib_dir=os.path.expanduser('~/.cython/inline'),
                  cython_include_dirs=['.'],
                  force=False,
                  quiet=False,
                  locals=None,
                  globals=None,
                  **kwds):
    """Compile *code* as the body of a Cython function and call it.

    Names used by *code* but not defined in it become parameters of the
    generated function.  Their values come from *kwds* first, then *locals*,
    then *globals*.  The generated module is cached on disk under *lib_dir*,
    keyed by the code, the argument types, the interpreter and the Cython
    version.

    :param code: Cython source for the function body.
    :param get_type: ``get_type(value, ctx)`` -> Cython type string used to
        declare each argument; None types everything as ``object``.
    :param lib_dir: directory for the generated .pyx and built extension;
        created if missing and appended to ``sys.path``.
    :param cython_include_dirs: include path for the compilation ``Context``.
    :param force: rebuild even if the module is already importable.
    :param quiet: suppress the "could not parse" warning; passed to ``cythonize``.
    :param locals, globals: namespaces used to resolve unbound names.  They
        default to the frame two levels up the stack, i.e. the caller of
        whatever called this function.  NOTE(review): this appears to assume
        a wrapper such as ``cython.inline``; confirm for direct calls.
    :returns: the return value of the generated ``__invoke`` function.
    """
    import io
    if get_type is None:
        get_type = lambda x: 'object'
    code = to_unicode(code)
    code, literals = strip_string_literals(code)
    code = strip_common_indent(code)
    ctx = Context(cython_include_dirs, default_options)
    if locals is None:
        locals = inspect.currentframe().f_back.f_back.f_locals
    if globals is None:
        globals = inspect.currentframe().f_back.f_back.f_globals
    try:
        for symbol in unbound_symbols(code):
            if symbol in kwds:
                continue
            elif symbol in locals:
                kwds[symbol] = locals[symbol]
            elif symbol in globals:
                kwds[symbol] = globals[symbol]
            else:
                print("Couldn't find ", symbol)
    except AssertionError:
        if not quiet:
            # Parsing from strings not fully supported (e.g. cimports).
            print("Could not parse code as a string (to extract unbound symbols).")
    # Sorted so the generated signature, and therefore the cache key, is stable.
    arg_names = sorted(kwds)
    arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
    key = code, arg_sigs, sys.version_info, sys.executable, Cython.__version__
    # md5 requires bytes on Python 3.
    module_name = "_cython_inline_" + hashlib.md5(str(key).encode('utf-8')).hexdigest()
    try:
        if not os.path.exists(lib_dir):
            os.makedirs(lib_dir)
        if lib_dir not in sys.path:
            sys.path.append(lib_dir)
        if force:
            raise ImportError
        else:
            __import__(module_name)
    except ImportError:
        cflags = []
        c_include_dirs = []
        cimports = []
        qualified = re.compile(r'([.\w]+)[.]')
        for arg_type, _ in arg_sigs:
            m = qualified.match(arg_type)
            if m:
                cimports.append('\ncimport %s' % m.groups()[0])
                # one special case
                if m.groups()[0] == 'numpy':
                    import numpy
                    c_include_dirs.append(numpy.get_include())
                    cflags.append('-Wno-unused')
        module_body, func_body = extract_func_code(code)
        params = ', '.join(['%s %s' % a for a in arg_sigs])
        module_code = """
%(module_body)s
%(cimports)s
def __invoke(%(params)s):
%(func_body)s
        """ % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body }
        # Put back the string literals stripped out earlier.  Distinct loop
        # names keep the cache ``key`` computed above from being overwritten.
        for label, literal in literals.items():
            module_code = module_code.replace(label, literal)
        pyx_file = os.path.join(lib_dir, module_name + '.pyx')
        # Explicit UTF-8 so non-ASCII literals survive regardless of locale;
        # the with-block closes the file before Cython reads it.
        with io.open(pyx_file, 'w', encoding='utf-8') as pyx:
            pyx.write(module_code)
        extension = Extension(
            name = module_name,
            sources = [pyx_file],
            include_dirs = c_include_dirs,
            extra_compile_args = cflags)
        build_extension = build_ext(Distribution())
        build_extension.finalize_options()
        build_extension.extensions = cythonize([extension], ctx=ctx, quiet=quiet)
        build_extension.build_temp = os.path.dirname(pyx_file)
        build_extension.build_lib = lib_dir
        build_extension.run()
        _code_cache[key] = module_name
    arg_list = [kwds[arg] for arg in arg_names]
    return __import__(module_name).__invoke(*arg_list)
non_space = re.compile('[^ ]')
def strip_common_indent(code):
    """Remove the smallest common indentation from the lines of *code*.

    Only spaces count as indentation.  Blank lines and comment lines are
    ignored when computing the common indent and are left untouched; every
    other line loses that many leading characters.
    """
    min_indent = None
    lines = code.split('\n')
    for line in lines:
        match = non_space.search(line)
        if not match:
            continue  # blank
        indent = match.start()
        if line[indent] == '#':
            continue  # comment
        if min_indent is None or min_indent > indent:
            min_indent = indent
    for ix, line in enumerate(lines):
        match = non_space.search(line)
        # Check this line's own first non-space character, not the stale
        # ``indent`` left over from the loop above (which misclassified
        # comments and could index past the end of short lines).
        if not match or line[match.start()] == '#':
            continue
        lines[ix] = line[min_indent:]
    return '\n'.join(lines)
module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimport)|(from .+ import +[*]))')
def extract_func_code(code):
    """Split inline *code* into ``(module_source, function_body_source)``.

    Tabs are first replaced by a space.  An unindented line that matches
    ``module_statement`` (cimports, ``cdef extern``/``cdef class``, star
    imports) starts a module-level section.  Any other unindented line starts
    a function-body section.  Indented lines follow the section of the last
    unindented line.  The function body is returned re-indented by one space.
    """
    module_lines, body_lines = [], []
    sink = body_lines
    for text in code.replace('\t', ' ').split('\n'):
        if not text.startswith(' '):
            sink = module_lines if module_statement.match(text) else body_lines
        sink.append(text)
    return '\n'.join(module_lines), ' ' + '\n '.join(body_lines)
try:
    from inspect import getcallargs
except ImportError:
    # Fallback for interpreters whose inspect module lacks getcallargs.
    def getcallargs(func, *arg_values, **kwd_values):
        """Map positional and keyword arguments onto *func*'s parameter names.

        Returns a dict from parameter name to bound value.  It includes the
        ``*args`` tuple and ``**kwargs`` dict when *func* declares them, and
        fills omitted parameters from their defaults.

        Raises TypeError for too many positional arguments, or for
        duplicate, unexpected or missing arguments.
        """
        bound = {}
        args, varargs, kwds, defaults = inspect.getargspec(func)
        if varargs is not None:
            bound[varargs] = arg_values[len(args):]
        elif len(arg_values) > len(args):
            raise TypeError("Too many positional arguments: expected at most %d, got %d"
                            % (len(args), len(arg_values)))
        for name, value in zip(args, arg_values):
            bound[name] = value
        # Iterate over a snapshot: matched names are popped from kwd_values.
        for name, value in list(kwd_values.items()):
            if name in args:
                if name in bound:
                    raise TypeError("Duplicate argument %s" % name)
                bound[name] = kwd_values.pop(name)
        if kwds is not None:
            bound[kwds] = kwd_values
        elif kwd_values:
            raise TypeError("Unexpected keyword arguments: %s" % list(kwd_values))
        if defaults is None:
            defaults = ()
        first_default = len(args) - len(defaults)
        for ix, name in enumerate(args):
            if name not in bound:
                if ix >= first_default:
                    bound[name] = defaults[ix - first_default]
                else:
                    raise TypeError("Missing argument: %s" % name)
        return bound
def get_body(source):
    """Return the body of a function from its source text.

    Everything after the first ``':'`` is returned.  For a lambda, the
    expression after the colon is turned into a ``return`` statement.
    """
    ix = source.index(':')
    # Compare the full 6-character keyword: a 5-character slice can never
    # equal 'lambda'.  Leading whitespace (indented source) is ignored.
    if source.lstrip().startswith('lambda'):
        return "return %s" % source[ix+1:]
    else:
        return source[ix+1:]
# Lots to be done here... It would be especially cool if compiled functions
# could invoke each other quickly.
class RuntimeCompiledFunction(object):
    """Wrap a Python function so that each call runs its body via ``cython_inline``.

    The body is extracted once from the function's source.  Each call binds
    the arguments to parameter names and passes them to ``cython_inline`` as
    keyword arguments.
    """

    def __init__(self, f):
        self._f = f
        self._body = get_body(inspect.getsource(f))

    def __call__(self, *args, **kwds):
        bound = getcallargs(self._f, *args, **kwds)
        # ``__globals__`` exists on Python 2.6+ and 3 (``func_globals`` is 2-only).
        # NOTE(review): locals is also the function's globals, so closure
        # variables of f are not visible to the compiled body -- confirm intended.
        namespace = self._f.__globals__
        return cython_inline(self._body, locals=namespace, globals=namespace, **bound)