Remove trailing whitespace.
[cython.git] / Cython / Build / Inline.py
1 import tempfile
2 import sys, os, re, inspect
3 from cython import set
4
5 try:
6     import hashlib
7 except ImportError:
8     import md5 as hashlib
9
10 from distutils.core import Distribution, Extension
11 from distutils.command.build_ext import build_ext
12
13 import Cython
14 from Cython.Compiler.Main import Context, CompilationOptions, default_options
15
16 from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform
17 from Cython.Compiler.TreeFragment import parse_from_strings
18 from Cython.Build.Dependencies import strip_string_literals, cythonize
19
20 # A utility function to convert user-supplied ASCII strings to unicode.
21 if sys.version_info[0] < 3:
22     def to_unicode(s):
23         if not isinstance(s, unicode):
24             return s.decode('ascii')
25         else:
26             return s
27 else:
28     to_unicode = lambda x: x
29
30 _code_cache = {}
31
32
33 class AllSymbols(CythonTransform, SkipDeclarations):
34     def __init__(self):
35         CythonTransform.__init__(self, None)
36         self.names = set()
37     def visit_NameNode(self, node):
38         self.names.add(node.name)
39
40 def unbound_symbols(code, context=None):
41     code = to_unicode(code)
42     if context is None:
43         context = Context([], default_options)
44     from Cython.Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform
45     tree = parse_from_strings('(tree fragment)', code)
46     for phase in context.create_pipeline(pxd=False):
47         if phase is None:
48             continue
49         tree = phase(tree)
50         if isinstance(phase, AnalyseDeclarationsTransform):
51             break
52     symbol_collector = AllSymbols()
53     symbol_collector(tree)
54     unbound = []
55     import __builtin__
56     for name in symbol_collector.names:
57         if not tree.scope.lookup(name) and not hasattr(__builtin__, name):
58             unbound.append(name)
59     return unbound
60
61 def unsafe_type(arg, context=None):
62     py_type = type(arg)
63     if py_type is int:
64         return 'long'
65     else:
66         return safe_type(arg, context)
67
68 def safe_type(arg, context=None):
69     py_type = type(arg)
70     if py_type in [list, tuple, dict, str]:
71         return py_type.__name__
72     elif py_type is complex:
73         return 'double complex'
74     elif py_type is float:
75         return 'double'
76     elif py_type is bool:
77         return 'bint'
78     elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray):
79         return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim)
80     else:
81         for base_type in py_type.mro():
82             if base_type.__module__ == '__builtin__':
83                 return 'object'
84             module = context.find_module(base_type.__module__, need_pxd=False)
85             if module:
86                 entry = module.lookup(base_type.__name__)
87                 if entry.is_type:
88                     return '%s.%s' % (base_type.__module__, base_type.__name__)
89         return 'object'
90
91 def cython_inline(code,
92                   get_type=unsafe_type,
93                   lib_dir=os.path.expanduser('~/.cython/inline'),
94                   cython_include_dirs=['.'],
95                   force=False,
96                   quiet=False,
97                   locals=None,
98                   globals=None,
99                   **kwds):
100     if get_type is None:
101         get_type = lambda x: 'object'
102     code = to_unicode(code)
103     code, literals = strip_string_literals(code)
104     code = strip_common_indent(code)
105     ctx = Context(cython_include_dirs, default_options)
106     if locals is None:
107         locals = inspect.currentframe().f_back.f_back.f_locals
108     if globals is None:
109         globals = inspect.currentframe().f_back.f_back.f_globals
110     try:
111         for symbol in unbound_symbols(code):
112             if symbol in kwds:
113                 continue
114             elif symbol in locals:
115                 kwds[symbol] = locals[symbol]
116             elif symbol in globals:
117                 kwds[symbol] = globals[symbol]
118             else:
119                 print("Couldn't find ", symbol)
120     except AssertionError:
121         if not quiet:
122             # Parsing from strings not fully supported (e.g. cimports).
123             print("Could not parse code as a string (to extract unbound symbols).")
124     arg_names = kwds.keys()
125     arg_names.sort()
126     arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
127     key = code, arg_sigs, sys.version_info, sys.executable, Cython.__version__
128     module_name = "_cython_inline_" + hashlib.md5(str(key)).hexdigest()
129     try:
130         if not os.path.exists(lib_dir):
131             os.makedirs(lib_dir)
132         if lib_dir not in sys.path:
133             sys.path.append(lib_dir)
134         if force:
135             raise ImportError
136         else:
137             __import__(module_name)
138     except ImportError:
139         cflags = []
140         c_include_dirs = []
141         cimports = []
142         qualified = re.compile(r'([.\w]+)[.]')
143         for type, _ in arg_sigs:
144             m = qualified.match(type)
145             if m:
146                 cimports.append('\ncimport %s' % m.groups()[0])
147                 # one special case
148                 if m.groups()[0] == 'numpy':
149                     import numpy
150                     c_include_dirs.append(numpy.get_include())
151                     cflags.append('-Wno-unused')
152         module_body, func_body = extract_func_code(code)
153         params = ', '.join(['%s %s' % a for a in arg_sigs])
154         module_code = """
155 %(module_body)s
156 %(cimports)s
157 def __invoke(%(params)s):
158 %(func_body)s
159         """ % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body }
160         for key, value in literals.items():
161             module_code = module_code.replace(key, value)
162         pyx_file = os.path.join(lib_dir, module_name + '.pyx')
163         open(pyx_file, 'w').write(module_code)
164         extension = Extension(
165             name = module_name,
166             sources = [pyx_file],
167             include_dirs = c_include_dirs,
168             extra_compile_args = cflags)
169         build_extension = build_ext(Distribution())
170         build_extension.finalize_options()
171         build_extension.extensions = cythonize([extension], ctx=ctx, quiet=quiet)
172         build_extension.build_temp = os.path.dirname(pyx_file)
173         build_extension.build_lib  = lib_dir
174         build_extension.run()
175         _code_cache[key] = module_name
176     arg_list = [kwds[arg] for arg in arg_names]
177     return __import__(module_name).__invoke(*arg_list)
178
179 non_space = re.compile('[^ ]')
180 def strip_common_indent(code):
181     min_indent = None
182     lines = code.split('\n')
183     for line in lines:
184         match = non_space.search(line)
185         if not match:
186             continue # blank
187         indent = match.start()
188         if line[indent] == '#':
189             continue # comment
190         elif min_indent is None or min_indent > indent:
191             min_indent = indent
192     for ix, line in enumerate(lines):
193         match = non_space.search(line)
194         if not match or line[indent] == '#':
195             continue
196         else:
197             lines[ix] = line[min_indent:]
198     return '\n'.join(lines)
199
200 module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimport)|(from .+ import +[*]))')
201 def extract_func_code(code):
202     module = []
203     function = []
204     current = function
205     code = code.replace('\t', ' ')
206     lines = code.split('\n')
207     for line in lines:
208         if not line.startswith(' '):
209             if module_statement.match(line):
210                 current = module
211             else:
212                 current = function
213         current.append(line)
214     return '\n'.join(module), '    ' + '\n    '.join(function)
215
216
217
218 try:
219     from inspect import getcallargs
220 except ImportError:
221     def getcallargs(func, *arg_values, **kwd_values):
222         all = {}
223         args, varargs, kwds, defaults = inspect.getargspec(func)
224         if varargs is not None:
225             all[varargs] = arg_values[len(args):]
226         for name, value in zip(args, arg_values):
227             all[name] = value
228         for name, value in kwd_values.items():
229             if name in args:
230                 if name in all:
231                     raise TypeError, "Duplicate argument %s" % name
232                 all[name] = kwd_values.pop(name)
233         if kwds is not None:
234             all[kwds] = kwd_values
235         elif kwd_values:
236             raise TypeError, "Unexpected keyword arguments: %s" % kwd_values.keys()
237         if defaults is None:
238             defaults = ()
239         first_default = len(args) - len(defaults)
240         for ix, name in enumerate(args):
241             if name not in all:
242                 if ix >= first_default:
243                     all[name] = defaults[ix - first_default]
244                 else:
245                     raise TypeError, "Missing argument: %s" % name
246         return all
247
248 def get_body(source):
249     ix = source.index(':')
250     if source[:5] == 'lambda':
251         return "return %s" % source[ix+1:]
252     else:
253         return source[ix+1:]
254
255 # Lots to be done here... It would be especially cool if compiled functions
256 # could invoke each other quickly.
257 class RuntimeCompiledFunction(object):
258
259     def __init__(self, f):
260         self._f = f
261         self._body = get_body(inspect.getsource(f))
262
263     def __call__(self, *args, **kwds):
264         all = getcallargs(self._f, *args, **kwds)
265         return cython_inline(self._body, locals=self._f.func_globals, globals=self._f.func_globals, **all)