Compile decorator.
authorRobert Bradshaw <robertwb@math.washington.edu>
Sat, 6 Nov 2010 06:34:10 +0000 (23:34 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Sat, 6 Nov 2010 06:34:10 +0000 (23:34 -0700)
Cython/Build/Dependencies.py
Cython/Build/Inline.py
Cython/Shadow.py

index 9d6e5f5eee0650fd73dbcbb3cdf409adef8c346d..7af60c1107a6a89edcd32d41a8f6e3d353c48d03 100644 (file)
@@ -161,7 +161,7 @@ def strip_string_literals(code, prefix='__Pyx_L'):
         if q == -1: q = max(single_q, double_q)
         
         # Process comment.
-        if hash_mark < q or hash_mark > -1 == q:
+        if -1 < hash_mark and (hash_mark < q or q == -1):
             end = code.find('\n', hash_mark)
             if end == -1:
                 end = None
@@ -173,6 +173,7 @@ def strip_string_literals(code, prefix='__Pyx_L'):
             if end is None:
                 break
             q = end
+            start = q
 
         # We're done.
         elif q == -1:
@@ -194,8 +195,8 @@ def strip_string_literals(code, prefix='__Pyx_L'):
                 literals[label] = code[start+len(in_quote):q]
                 new_code.append("%s%s%s" % (in_quote, label, in_quote))
                 q += len(in_quote)
-                start = q
                 in_quote = False
+                start = q
             else:
                 q += 1
 
index 94b79aa7fa9f9ca2688b3b2303a95e94227f66c0..ced513c65e442358ef6e4fb049c2f2a8f5bc214a 100644 (file)
@@ -108,7 +108,12 @@ def cython_inline(code,
     arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
     key = code, arg_sigs, sys.version_info, sys.executable, Cython.__version__
     module_name = "_cython_inline_" + hashlib.md5(str(key)).hexdigest()
+#    # TODO: Does this cover all the platforms?
+#    if (not os.path.exists(os.path.join(lib_dir, module_name + ".so")) and 
+#        not os.path.exists(os.path.join(lib_dir, module_name + ".dll"))):
     try:
+        if not os.path.exists(lib_dir):
+            os.makedirs(lib_dir)
         if lib_dir not in sys.path:
             sys.path.append(lib_dir)
         __import__(module_name)
@@ -134,7 +139,7 @@ def __invoke(%(params)s):
         """ % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body }
         for key, value in literals.items():
             module_code = module_code.replace(key, value)
-        pyx_file = os.path.join(tempfile.mkdtemp(), module_name + '.pyx')
+        pyx_file = os.path.join(lib_dir, module_name + '.pyx')
         open(pyx_file, 'w').write(module_code)
         extension = Extension(
             name = module_name,
@@ -175,7 +180,6 @@ module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimpor
 def extract_func_code(code):
     module = []
     function = []
-    # TODO: string literals, backslash
     current = function
     code = code.replace('\t', ' ')
     lines = code.split('\n')
@@ -187,3 +191,54 @@ def extract_func_code(code):
                 current = function
         current.append(line)
     return '\n'.join(module), '    ' + '\n    '.join(function)
+
+
+
+try:
+    from inspect import getcallargs
+except ImportError:
+    def getcallargs(func, *arg_values, **kwd_values):
+        all = {}
+        args, varargs, kwds, defaults = inspect.getargspec(func)
+        if varargs is not None:
+            all[varargs] = arg_values[len(args):]
+        for name, value in zip(args, arg_values):
+            all[name] = value
+        for name, value in kwd_values.items():
+            if name in args:
+                if name in all:
+                    raise TypeError, "Duplicate argument %s" % name
+                all[name] = kwd_values.pop(name)
+        if kwds is not None:
+            all[kwds] = kwd_values
+        elif kwd_values:
+            raise TypeError, "Unexpected keyword arguments: %s" % kwd_values.keys()
+        if defaults is None:
+            defaults = ()
+        first_default = len(args) - len(defaults)
+        for ix, name in enumerate(args):
+            if name not in all:
+                if ix >= first_default:
+                    all[name] = defaults[ix - first_default]
+                else:
+                    raise TypeError, "Missing argument: %s" % name
+        return all
+
+def get_body(source):
+    ix = source.index(':')
+    if source[:5] == 'lambda':
+        return "return %s" % source[ix+1:]
+    else:
+        return source[ix+1:]
+
+# Lots to be done here... It would be especially cool if compiled functions 
+# could invoke each other quickly.
+class RuntimeCompiledFunction(object):
+
+    def __init__(self, f):
+        self._f = f
+        self._body = get_body(inspect.getsource(f))
+    
+    def __call__(self, *args, **kwds):
+        all = getcallargs(self._f, *args, **kwds)
+        return cython_inline(self._body, locals=self._f.func_globals, globals=self._f.func_globals, **all)
index 48278c0382d7a3bb9e7a56ed80209c3070329b1a..a4bbd20662b7c41511d8379067c11d2371a518b9 100644 (file)
@@ -18,6 +18,10 @@ def inline(f, *args, **kwds):
     assert len(args) == len(kwds) == 0
     return f
 
+def compile(f):
+    from Cython.Build.Inline import RuntimeCompiledFunction
+    return RuntimeCompiledFunction(f)
+
 # Special functions
 
 def cdiv(a, b):