# Cython/Build/Dependencies.py
from glob import glob
import re, os, sys

try:
    set
except NameError:
    # Python 2.3
    from sets import Set as set

from distutils.extension import Extension

from Cython import Utils
from Cython.Compiler.Main import Context, CompilationOptions, default_options

# Unfortunately, Python 2.3 doesn't support decorators.
def cached_method(f):
    cache_name = '__%s_cache' % f.__name__
    def wrapper(self, *args):
        cache = getattr(self, cache_name, None)
        if cache is None:
            cache = {}
            setattr(self, cache_name, cache)
        if args in cache:
            return cache[args]
        res = cache[args] = f(self, *args)
        return res
    return wrapper

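# cached_method memoizes per instance and per argument tuple by storing a dict
# in a '__<name>_cache' attribute on the instance.  For illustration only
# (hypothetical class, not part of this module):
#
#     class Squares(object):
#         def square(self, x):
#             return x * x
#         square = cached_method(square)   # decorator syntax avoided for 2.3
#
#     s = Squares()
#     s.square(4)   # computes 16 and caches it under the key (4,)
#     s.square(4)   # served from the cache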

def parse_list(s):
    if s[0] == '[' and s[-1] == ']':
        s = s[1:-1]
        delimiter = ','
    else:
        delimiter = ' '
    s, literals = strip_string_literals(s)
    def unquote(literal):
        literal = literal.strip()
        if literal[0] == "'":
            return literals[literal[1:-1]]
        else:
            return literal

    return [unquote(item) for item in s.split(delimiter)]

transitive_str = object()
transitive_list = object()

distutils_settings = {
    'name':                 str,
    'sources':              list,
    'define_macros':        list,
    'undef_macros':         list,
    'libraries':            transitive_list,
    'library_dirs':         transitive_list,
    'runtime_library_dirs': transitive_list,
    'include_dirs':         transitive_list,
    'extra_objects':        list,
    'extra_compile_args':   list,
    'extra_link_args':      list,
    'export_symbols':       list,
    'depends':              transitive_list,
    'language':             transitive_str,
}

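# These settings come from "# distutils:" comments at the top of a source
# file, for example (hypothetical header, illustration only):
#
#     # distutils: language = c++
#     # distutils: libraries = spam eggs
#     # distutils: include_dirs = [/usr/include/spam, ./eggs]
#
# Keys typed transitive_list/transitive_str above also propagate from
# cimported .pxd files via DistutilsInfo.merge().
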
def line_iter(source):
    start = 0
    while True:
        end = source.find('\n', start)
        if end == -1:
            yield source[start:]
            return
        yield source[start:end]
        start = end+1

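# line_iter() lazily yields the lines of a string without materializing the
# list that str.splitlines() would; e.g. list(line_iter("a\nb")) == ['a', 'b'].
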
class DistutilsInfo(object):

    def __init__(self, source=None, exn=None):
        self.values = {}
        if source is not None:
            for line in line_iter(source):
                line = line.strip()
                if line != '' and line[0] != '#':
                    break
                line = line[1:].strip()
                if line[:10] == 'distutils:':
                    line = line[10:]
                    ix = line.index('=')
                    key = str(line[:ix].strip())
                    value = line[ix+1:].strip()
                    type = distutils_settings[key]
                    if type in (list, transitive_list):
                        value = parse_list(value)
                        if key == 'define_macros':
                            value = [tuple(macro.split('=')) for macro in value]
                    self.values[key] = value
        elif exn is not None:
            for key in distutils_settings:
                if key in ('name', 'sources'):
                    # These are handled explicitly when the Extension is built.
                    continue
                value = getattr(exn, key, None)
                if value:
                    self.values[key] = value

    def merge(self, other):
        if other is None:
            return self
        for key, value in other.values.items():
            type = distutils_settings[key]
            if type is transitive_str and key not in self.values:
                self.values[key] = value
            elif type is transitive_list:
                if key in self.values:
                    all = self.values[key]
                    for v in value:
                        if v not in all:
                            all.append(v)
                else:
                    self.values[key] = value
        return self

    def subs(self, aliases):
        if aliases is None:
            return self
        resolved = DistutilsInfo()
        for key, value in self.values.items():
            type = distutils_settings[key]
            if type in [list, transitive_list]:
                new_value_list = []
                for v in value:
                    if v in aliases:
                        v = aliases[v]
                    if isinstance(v, list):
                        new_value_list += v
                    else:
                        new_value_list.append(v)
                value = new_value_list
            else:
                if value in aliases:
                    value = aliases[value]
            resolved.values[key] = value
        return resolved

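# Sketch of how merge() and subs() behave (hypothetical values, illustration
# only): merge() unions transitive_list entries and fills in unset
# transitive_str entries; subs() replaces aliases in list and string values.
#
#     a = DistutilsInfo(); a.values = {'libraries': ['m']}
#     b = DistutilsInfo(); b.values = {'libraries': ['m', 'z'], 'language': 'c++'}
#     a.merge(b).values               # {'libraries': ['m', 'z'], 'language': 'c++'}
#     a.subs({'z': 'zlib'}).values    # {'libraries': ['m', 'zlib'], 'language': 'c++'}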

def strip_string_literals(code, prefix='__Pyx_L'):
    """
    Normalizes every string literal to be of the form '__Pyx_Lxxx',
    returning the normalized code and a mapping of labels to
    string literals.
    """
    new_code = []
    literals = {}
    counter = 0
    start = q = 0
    in_quote = False
    raw = False
    while True:
        hash_mark = code.find('#', q)
        single_q = code.find("'", q)
        double_q = code.find('"', q)
        q = min(single_q, double_q)
        if q == -1: q = max(single_q, double_q)

        # Process comment.
        if -1 < hash_mark and (hash_mark < q or q == -1):
            end = code.find('\n', hash_mark)
            if end == -1:
                end = None
            new_code.append(code[start:hash_mark+1])
            counter += 1
            label = "%s%s" % (prefix, counter)
            literals[label] = code[hash_mark+1:end]
            new_code.append(label)
            if end is None:
                break
            q = end
            start = q

        # We're done.
        elif q == -1:
            new_code.append(code[start:])
            break

        # Try to close the quote.
        elif in_quote:
            if code[q-1] == '\\':
                k = 2
                while q >= k and code[q-k] == '\\':
                    k += 1
                if k % 2 == 0:
                    q += 1
                    continue
            if code[q:q+len(in_quote)] == in_quote:
                counter += 1
                label = "%s%s" % (prefix, counter)
                literals[label] = code[start+len(in_quote):q]
                new_code.append("%s%s%s" % (in_quote, label, in_quote))
                q += len(in_quote)
                in_quote = False
                start = q
            else:
                q += 1

        # Open the quote.
        else:
            raw = False
            if len(code) >= q+3 and (code[q+1] == code[q] == code[q+2]):
                in_quote = code[q]*3
            else:
                in_quote = code[q]
            end = q
            while end > 0 and code[end-1] in 'rRbBuU':
                if code[end-1] in 'rR':
                    raw = True
                end -= 1
            new_code.append(code[start:end])
            start = q
            q += len(in_quote)

    return "".join(new_code), literals

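# For example (illustrative; the exact labels depend on the running counter):
#
#     strip_string_literals("a = 'xyz'  # rest")
#
# returns ("a = '__Pyx_L1'  #__Pyx_L2",
#          {'__Pyx_L1': 'xyz', '__Pyx_L2': ' rest'}).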

def parse_dependencies(source_filename):
    # Actual parsing is way too slow, so we use regular expressions.
    # The only catch is that we must strip comments and string
    # literals ahead of time.
    source = Utils.open_source_file(source_filename, "rU").read()
    distutils_info = DistutilsInfo(source)
    source, literals = strip_string_literals(source)
    source = source.replace('\\\n', ' ')
    if '\t' in source:
        source = source.replace('\t', ' ')
    # TODO: pure mode
    dependency = re.compile(r"(cimport +([0-9a-zA-Z_.]+)\b)|(from +([0-9a-zA-Z_.]+) +cimport)|(include +'([^']+)')|(cdef +extern +from +'([^']+)')")
    cimports = []
    includes = []
    externs  = []
    for m in dependency.finditer(source):
        groups = m.groups()
        if groups[0]:
            cimports.append(groups[1])
        elif groups[2]:
            cimports.append(groups[3])
        elif groups[4]:
            includes.append(literals[groups[5]])
        else:
            externs.append(literals[groups[7]])
    return cimports, includes, externs, distutils_info

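# As an illustration, for a .pyx file that starts with
#
#     cimport numpy
#     from libc.stdlib cimport malloc
#     include 'utils.pxi'
#     cdef extern from 'header.h':
#         pass
#
# parse_dependencies() would return cimports ['numpy', 'libc.stdlib'],
# includes ['utils.pxi'], externs ['header.h'], plus the (here empty)
# DistutilsInfo parsed from any leading "# distutils:" comments.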

class DependencyTree(object):

    def __init__(self, context):
        self.context = context
        self._transitive_cache = {}

    #@cached_method
    def parse_dependencies(self, source_filename):
        return parse_dependencies(source_filename)
    parse_dependencies = cached_method(parse_dependencies)

    #@cached_method
    def cimports_and_externs(self, filename):
        cimports, includes, externs = self.parse_dependencies(filename)[:3]
        cimports = set(cimports)
        externs = set(externs)
        for include in includes:
            a, b = self.cimports_and_externs(os.path.join(os.path.dirname(filename), include))
            cimports.update(a)
            externs.update(b)
        return tuple(cimports), tuple(externs)
    cimports_and_externs = cached_method(cimports_and_externs)

    def cimports(self, filename):
        return self.cimports_and_externs(filename)[0]

    #@cached_method
    def package(self, filename):
        dir = os.path.dirname(filename)
        if os.path.exists(os.path.join(dir, '__init__.py')):
            return self.package(dir) + (os.path.basename(dir),)
        else:
            return ()
    package = cached_method(package)

    #@cached_method
    def fully_qualified_name(self, filename):
        module = os.path.splitext(os.path.basename(filename))[0]
        return '.'.join(self.package(filename) + (module,))
    fully_qualified_name = cached_method(fully_qualified_name)

    def find_pxd(self, module, filename=None):
        if module[0] == '.':
            raise NotImplementedError("New relative imports.")
        if filename is not None:
            relative = '.'.join(self.package(filename) + tuple(module.split('.')))
            pxd = self.context.find_pxd_file(relative, None)
            if pxd:
                return pxd
        return self.context.find_pxd_file(module, None)

    #@cached_method
    def cimported_files(self, filename):
        if filename[-4:] == '.pyx' and os.path.exists(filename[:-4] + '.pxd'):
            self_pxd = [filename[:-4] + '.pxd']
        else:
            self_pxd = []
        a = self.cimports(filename)
        b = filter(None, [self.find_pxd(m, filename) for m in a])
        if len(a) != len(b):
            # Debugging aid: some cimported modules have no locatable .pxd file.
            print(filename)
            print("\n\t".join(a))
            print("\n\t".join(b))
        return tuple(self_pxd + b)
    cimported_files = cached_method(cimported_files)

    def immediate_dependencies(self, filename):
        all = list(self.cimported_files(filename))
        for extern in sum(self.cimports_and_externs(filename), ()):
            all.append(os.path.normpath(os.path.join(os.path.dirname(filename), extern)))
        return tuple(all)

    #@cached_method
    def timestamp(self, filename):
        return os.path.getmtime(filename)
    timestamp = cached_method(timestamp)

    def extract_timestamp(self, filename):
        # TODO: .h files from extern blocks
        return self.timestamp(filename), filename

    def newest_dependency(self, filename):
        return self.transitive_merge(filename, self.extract_timestamp, max)

    def distutils_info0(self, filename):
        return self.parse_dependencies(filename)[3]

    def distutils_info(self, filename, aliases=None, base=None):
        return (self.transitive_merge(filename, self.distutils_info0, DistutilsInfo.merge)
            .subs(aliases)
            .merge(base))

    def transitive_merge(self, node, extract, merge):
        try:
            seen = self._transitive_cache[extract, merge]
        except KeyError:
            seen = self._transitive_cache[extract, merge] = {}
        return self.transitive_merge_helper(
            node, extract, merge, seen, {}, self.cimported_files)[0]

    def transitive_merge_helper(self, node, extract, merge, seen, stack, outgoing):
        if node in seen:
            return seen[node], None
        deps = extract(node)
        if node in stack:
            return deps, node
        try:
            stack[node] = len(stack)
            loop = None
            for next in outgoing(node):
                sub_deps, sub_loop = self.transitive_merge_helper(next, extract, merge, seen, stack, outgoing)
                if sub_loop is not None:
                    if loop is not None and stack[loop] < stack[sub_loop]:
                        pass
                    else:
                        loop = sub_loop
                deps = merge(deps, sub_deps)
            if loop == node:
                loop = None
            if loop is None:
                seen[node] = deps
            return deps, loop
        finally:
            del stack[node]

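# transitive_merge() combines extract(node) over everything reachable through
# cimported_files() using the given merge function: newest_dependency() pairs
# it with extract_timestamp/max, distutils_info() with DistutilsInfo.merge.
# The helper is a depth-first walk that keeps a stack of open nodes so cimport
# cycles terminate; results are only cached in 'seen' once no enclosing cycle
# is still open.
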
_dep_tree = None
def create_dependency_tree(ctx=None):
    global _dep_tree
    if _dep_tree is None:
        if ctx is None:
            ctx = Context(["."], CompilationOptions(default_options))
        _dep_tree = DependencyTree(ctx)
    return _dep_tree

# This may be useful for advanced users?
def create_extension_list(patterns, ctx=None, aliases=None):
    seen = set()
    deps = create_dependency_tree(ctx)
    if not isinstance(patterns, list):
        patterns = [patterns]
    module_list = []
    for pattern in patterns:
        if isinstance(pattern, str):
            filepattern = pattern
            template = None
            name = '*'
            base = None
            exn_type = Extension
        elif isinstance(pattern, Extension):
            filepattern = pattern.sources[0]
            if os.path.splitext(filepattern)[1] not in ('.py', '.pyx'):
                # ignore non-Cython modules
                module_list.append(pattern)
                continue
            template = pattern
            name = template.name
            base = DistutilsInfo(exn=template)
            exn_type = template.__class__
        else:
            raise TypeError(pattern)
        for file in glob(filepattern):
            pkg = deps.package(file)
            if '*' in name:
                module_name = deps.fully_qualified_name(file)
            else:
                module_name = name
            if module_name not in seen:
                module_list.append(exn_type(
                        name=module_name,
                        sources=[file],
                        **deps.distutils_info(file, aliases, base).values))
                seen.add(module_name)
    return module_list

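# create_extension_list() accepts glob strings or Extension templates, e.g.
# (hypothetical paths, illustration only):
#
#     create_extension_list("src/*.pyx")
#     create_extension_list(Extension('*', ['src/*.pyx'], libraries=['m']))
#
# With a template, its settings become the 'base' DistutilsInfo that is merged
# into every matching module.
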
# This is the user-exposed entry point.
def cythonize(module_list, nthreads=0, aliases=None, **options):
    c_options = CompilationOptions(options)
    cpp_options = CompilationOptions(options); cpp_options.cplus = True
    ctx = c_options.create_context()
    module_list = create_extension_list(module_list, ctx=ctx, aliases=aliases)
    deps = create_dependency_tree(ctx)
    to_compile = []
    for m in module_list:
        new_sources = []
        for source in m.sources:
            base, ext = os.path.splitext(source)
            if ext in ('.pyx', '.py'):
                if m.language == 'c++':
                    c_file = base + '.cpp'
                    options = cpp_options
                else:
                    c_file = base + '.c'
                    options = c_options
                if os.path.exists(c_file):
                    c_timestamp = os.path.getmtime(c_file)
                else:
                    c_timestamp = -1
                # Priority goes first to modified files, second to direct
                # dependents, and finally to indirect dependents.
                if c_timestamp < deps.timestamp(source):
                    dep_timestamp, dep = deps.timestamp(source), source
                    priority = 0
                else:
                    dep_timestamp, dep = deps.newest_dependency(source)
                    priority = 2 - (dep in deps.immediate_dependencies(source))
                if c_timestamp < dep_timestamp:
                    print("Compiling %s because it depends on %s" % (source, dep))
                    to_compile.append((priority, source, c_file, options))
                new_sources.append(c_file)
            else:
                new_sources.append(source)
        m.sources = new_sources
    to_compile.sort()
    if nthreads:
        # Requires multiprocessing (or Python >= 2.6)
        try:
            import multiprocessing
            pool = multiprocessing.Pool(nthreads)
            pool.map(cythonize_one_helper, to_compile)
        except ImportError:
            print("multiprocessing required for parallel cythonization")
            nthreads = 0
    if not nthreads:
        for priority, pyx_file, c_file, options in to_compile:
            cythonize_one(pyx_file, c_file, options)
    return module_list

# TODO: Share context? Issue: pyx processing leaks into pxd module
def cythonize_one(pyx_file, c_file, options=None):
    from Cython.Compiler.Main import compile, default_options
    from Cython.Compiler.Errors import CompileError, PyrexError

    if options is None:
        options = CompilationOptions(default_options)
    options.output_file = c_file

    any_failures = 0
    try:
        result = compile([pyx_file], options)
        if result.num_errors > 0:
            any_failures = 1
    except (EnvironmentError, PyrexError), e:
        sys.stderr.write(str(e) + '\n')
        any_failures = 1
    if any_failures:
        raise CompileError(None, pyx_file)

def cythonize_one_helper(m):
    return cythonize_one(*m[1:])
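
# A typical driver would be a setup.py along these lines (a sketch with
# hypothetical names; only cythonize() is assumed from this module):
#
#     from distutils.core import setup
#     from Cython.Build.Dependencies import cythonize
#
#     setup(
#         name='spam',
#         ext_modules=cythonize("spam/*.pyx", nthreads=2),
#     )
#
# cythonize() regenerates the out-of-date .c/.cpp files and returns the
# Extension list for distutils to build.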