Remove trailing whitespace.
[cython.git] / Cython / Build / Dependencies.py
1 from glob import glob
2 import re, os, sys
3 from cython import set
4
5
6 from distutils.extension import Extension
7
8 from Cython import Utils
9 from Cython.Compiler.Main import Context, CompilationOptions, default_options
10
11 # Unfortunately, Python 2.3 doesn't support decorators.
def cached_method(f):
    """Memoize a method's results per instance, keyed on its positional args.

    Applied by assignment (``meth = cached_method(meth)``) because Python 2.3
    has no decorator syntax.  The cache is a dict stored on the instance
    under the attribute ``'__<name>_cache'`` and created on first call.
    Positional arguments must be hashable; keyword arguments are not
    supported.
    """
    cache_name = '__%s_cache' % f.__name__

    def wrapper(self, *args):
        cache = getattr(self, cache_name, None)
        if cache is None:
            cache = {}
            setattr(self, cache_name, cache)
        try:
            return cache[args]
        except KeyError:
            pass
        # Computed outside the try so a KeyError raised by f itself propagates.
        res = cache[args] = f(self, *args)
        return res

    # Keep the wrapped method's identity for introspection and tracebacks
    # (functools.wraps is not available on the oldest Pythons supported).
    wrapper.__doc__ = f.__doc__
    try:
        wrapper.__name__ = f.__name__
    except (TypeError, AttributeError):
        pass  # function __name__ is read-only on Python 2.3
    return wrapper
24
25
def parse_list(s):
    """Parse a ``# distutils:`` directive value into a list of strings.

    A value wrapped in ``[...]`` is split on commas; any other value is
    split on spaces.  Quoted items (single or double quotes) are returned
    without their quotes and may contain the delimiter.  Empty items (from
    an empty value, repeated spaces or a trailing comma) are dropped.

    e.g. ``parse_list("[a, 'b c']") == ['a', 'b c']``,
    ``parse_list("a  b") == ['a', 'b']``, ``parse_list("") == []``.
    """
    if len(s) >= 2 and s[0] == '[' and s[-1] == ']':
        s = s[1:-1]
        delimiter = ','
    else:
        delimiter = ' '
    # Replace quoted strings by labels first, so a delimiter inside quotes
    # does not split an item.
    s, literals = strip_string_literals(s)

    def unquote(literal):
        literal = literal.strip()
        # strip_string_literals keeps the original quote character around
        # the label, so either quote style can appear here.
        if literal[0] in "'\"":
            return literals[literal[1:-1]]
        else:
            return literal

    return [unquote(item) for item in s.split(delimiter) if item.strip()]
41
# Sentinel "types" for settings that DistutilsInfo.merge propagates from a
# file to the files depending on it: a transitive_str is only inherited when
# unset, a transitive_list is unioned.  Plain str/list settings are parsed
# the same way but never inherited.
transitive_str = object()
transitive_list = object()

# Every recognised "# distutils: key = value" key, mapped to how its value
# is parsed (list types go through parse_list) and merged.
distutils_settings = dict(
    name=str,
    sources=list,
    define_macros=list,
    undef_macros=list,
    libraries=transitive_list,
    library_dirs=transitive_list,
    runtime_library_dirs=transitive_list,
    include_dirs=transitive_list,
    extra_objects=list,
    extra_compile_args=transitive_list,
    extra_link_args=transitive_list,
    export_symbols=list,
    depends=transitive_list,
    language=transitive_str,
)
61
def line_iter(source):
    """Lazily yield the lines of ``source`` without their newlines.

    Splits on '\\n' only, like ``source.split('\\n')``: a trailing newline
    yields a final empty line, and an empty source yields one ''.
    """
    pos = 0
    while 1:
        newline = source.find('\n', pos)
        if newline < 0:
            break
        yield source[pos:newline]
        pos = newline + 1
    yield source[pos:]
71
class DistutilsInfo(object):
    """Distutils Extension settings attached to one source file.

    ``values`` maps keys of ``distutils_settings`` to their values.  An
    instance is built from the leading ``# distutils: key = value`` comment
    lines of a source text, from the attributes of an Extension object, or
    empty.
    """

    def __init__(self, source=None, exn=None):
        """Collect settings from ``source`` text, else from Extension ``exn``.

        Only the leading run of blank and ``#`` comment lines of ``source``
        is scanned; the first other line stops the scan.  List-typed values
        are parsed with parse_list.  An unknown key raises KeyError and a
        directive without ``=`` raises ValueError.  From ``exn`` every
        setting except name and sources is copied when truthy.
        """
        self.values = {}
        if source is not None:
            for line in line_iter(source):
                line = line.strip()
                if line != '' and line[0] != '#':
                    break
                line = line[1:].strip()
                if line[:10] == 'distutils:':
                    line = line[10:]
                    ix = line.index('=')
                    key = str(line[:ix].strip())
                    value = line[ix+1:].strip()
                    setting_type = distutils_settings[key]
                    if setting_type in (list, transitive_list):
                        value = parse_list(value)
                        if key == 'define_macros':
                            # distutils expects (name, value) pairs: a bare
                            # NAME is defined without a value (None), and
                            # only the first '=' separates name from value.
                            macros = []
                            for macro in value:
                                if '=' in macro:
                                    name, macro_value = macro.split('=', 1)
                                    macros.append((name, macro_value))
                                else:
                                    macros.append((macro, None))
                            value = macros
                    self.values[key] = value
        elif exn is not None:
            for key in distutils_settings:
                if key in ('name', 'sources'):
                    continue
                value = getattr(exn, key, None)
                if value:
                    self.values[key] = value

    def merge(self, other):
        """Fold the transitive settings of ``other`` into this object.

        A transitive_str value is taken from ``other`` only when this object
        has none; transitive_list values are unioned, keeping this object's
        items first.  Non-transitive settings of ``other`` are ignored.
        Updates and returns ``self``; returns ``self`` unchanged when
        ``other`` is None.
        """
        if other is None:
            return self
        for key, value in other.values.items():
            setting_type = distutils_settings[key]
            if setting_type is transitive_str and key not in self.values:
                self.values[key] = value
            elif setting_type is transitive_list:
                # Always store a new list instead of appending in place: the
                # current list may be shared with another DistutilsInfo (a
                # cached per-file result, or an ``other`` stored here by an
                # earlier merge), and mutating it would leak one module's
                # settings into unrelated modules.
                if key in self.values:
                    merged = list(self.values[key])
                    for v in value:
                        if v not in merged:
                            merged.append(v)
                else:
                    merged = list(value)
                self.values[key] = merged
        return self

    def subs(self, aliases):
        """Return a copy with values substituted according to ``aliases``.

        For list settings each item that is a key of ``aliases`` is replaced
        by its alias value, which is spliced in when it is itself a list;
        other settings are replaced as a whole.  Returns ``self`` when
        ``aliases`` is None.
        """
        if aliases is None:
            return self
        resolved = DistutilsInfo()
        for key, value in self.values.items():
            setting_type = distutils_settings[key]
            if setting_type in [list, transitive_list]:
                new_value_list = []
                for v in value:
                    if v in aliases:
                        v = aliases[v]
                    if isinstance(v, list):
                        new_value_list += v
                    else:
                        new_value_list.append(v)
                value = new_value_list
            else:
                if value in aliases:
                    value = aliases[value]
            resolved.values[key] = value
        return resolved
139
140
def strip_string_literals(code, prefix='__Pyx_L'):
    """
    Normalizes every string literal to be of the form '__Pyx_Lxxx_',
    returning the normalized code and a mapping of labels to
    string literals.

    The quote characters (single, double or triple) and any string prefix
    are kept around the label.  Comment text following '#' is likewise
    replaced by a bare label, with its original text stored in the mapping.
    An unterminated string literal is copied to the output unchanged.

    >>> strip_string_literals("x = 'abc' # hi")
    ("x = '__Pyx_L1_' #__Pyx_L2_", {'__Pyx_L1_': 'abc', '__Pyx_L2_': ' hi'})
    """
    new_code = []
    literals = {}
    counter = 0
    start = q = 0
    in_quote = False
    while True:
        hash_mark = code.find('#', q)
        single_q = code.find("'", q)
        double_q = code.find('"', q)
        q = min(single_q, double_q)
        if q == -1: q = max(single_q, double_q)

        # We're done: no quote characters remain and either no comment is
        # left, or we are inside an unterminated string (where '#' is string
        # content, not a comment).  Copy the remainder verbatim.
        if q == -1 and (hash_mark == -1 or in_quote):
            new_code.append(code[start:])
            break

        # Try to close the quote.
        elif in_quote:
            # A quote preceded by an odd number of backslashes is escaped.
            # This holds for raw strings too: r'\'' is a valid literal whose
            # value keeps the backslash.
            if code[q-1] == '\\':
                k = 2
                while q >= k and code[q-k] == '\\':
                    k += 1
                if k % 2 == 0:
                    q += 1
                    continue
            if code[q:q+len(in_quote)] == in_quote:
                counter += 1
                label = "%s%s_" % (prefix, counter)
                literals[label] = code[start+len(in_quote):q]
                new_code.append("%s%s%s" % (in_quote, label, in_quote))
                q += len(in_quote)
                in_quote = False
                start = q
            else:
                # The other quote character, or a lone quote inside a
                # triple-quoted string: keep scanning.
                q += 1

        # Process comment.
        elif -1 != hash_mark and (hash_mark < q or q == -1):
            end = code.find('\n', hash_mark)
            if end == -1:
                end = None
            # Keep the '#' itself; the comment text becomes a bare label.
            new_code.append(code[start:hash_mark+1])
            counter += 1
            label = "%s%s_" % (prefix, counter)
            literals[label] = code[hash_mark+1:end]
            new_code.append(label)
            if end is None:
                break
            q = end
            start = q

        # Open the quote.
        else:
            if len(code) >= q+3 and (code[q+1] == code[q] == code[q+2]):
                in_quote = code[q]*3
            else:
                in_quote = code[q]
            # Everything before the quote, including any r/b/u prefix, is
            # copied as-is.
            new_code.append(code[start:q])
            start = q
            q += len(in_quote)

    return "".join(new_code), literals
217
218
# Dependency-bearing statements in source that has already been through
# strip_string_literals, so every quoted name is a label in single or double
# quotes.  Match groups (1-based):
#   1-2  "cimport a, b.c"          2 = comma-separated module names
#   3-4  "from a.b cimport ..."    4 = module name
#   5-6  "include '...'"           6 = literal label
#   7-8  "cdef extern from '...'"  8 = literal label
_dependency_regex = re.compile(
    r"(cimport +([0-9a-zA-Z_.]+(?: *, *[0-9a-zA-Z_.]+)*)\b)|"
    r"(from +([0-9a-zA-Z_.]+) +cimport)|"
    r"(include +['\"]([^'\"]+)['\"])|"
    r"(cdef +extern +from +['\"]([^'\"]+)['\"])")

def parse_dependencies(source_filename):
    """Scan a Cython source file for its compile-time dependencies.

    Returns ``(cimports, includes, externs, distutils_info)``: the cimported
    module names, the file names given to ``include``, the header names from
    ``cdef extern from``, and the DistutilsInfo built from the file's leading
    ``# distutils:`` comments.
    """
    # Actual parsing is way too slow, so we use regular expressions.
    # The only catch is that we must strip comments and string
    # literals ahead of time.
    source_file = Utils.open_source_file(source_filename, "rU")
    try:
        source = source_file.read()
    finally:
        source_file.close()
    distutils_info = DistutilsInfo(source)
    source, literals = strip_string_literals(source)
    source = source.replace('\\\n', ' ')
    if '\t' in source:
        source = source.replace('\t', ' ')
    # TODO: pure mode
    cimports = []
    includes = []
    externs = []
    for m in _dependency_regex.finditer(source):
        groups = m.groups()
        if groups[0]:
            # "cimport a, b" names several modules in one statement.
            for module in groups[1].split(','):
                cimports.append(module.strip())
        elif groups[2]:
            cimports.append(groups[3])
        elif groups[4]:
            includes.append(literals[groups[5]])
        else:
            externs.append(literals[groups[7]])
    return cimports, includes, externs, distutils_info
245
246
class DependencyTree(object):
    """Memoized queries over the dependency graph of Cython source files.

    ``context`` resolves include files and .pxd files (``find_include_file``
    and ``find_pxd_file``).  Per-file results, including file timestamps,
    are cached for the lifetime of the tree.
    """

    def __init__(self, context):
        self.context = context
        # (extract, merge) -> {filename: merged value}; see transitive_merge.
        self._transitive_cache = {}

    #@cached_method
    def parse_dependencies(self, source_filename):
        """Cached wrapper around the module-level parse_dependencies()."""
        return parse_dependencies(source_filename)
    parse_dependencies = cached_method(parse_dependencies)

    #@cached_method
    def cimports_and_externs(self, filename):
        """Return (cimported module names, extern header names) of filename,
        including those of the files it ``include``s, recursively.

        An include is looked up next to filename first, then through the
        context's include path; unresolvable includes are reported and
        skipped.
        """
        cimports, includes, externs = self.parse_dependencies(filename)[:3]
        cimports = set(cimports)
        externs = set(externs)
        for include in includes:
            include_path = os.path.join(os.path.dirname(filename), include)
            if not os.path.exists(include_path):
                include_path = self.context.find_include_file(include, None)
            if include_path:
                a, b = self.cimports_and_externs(include_path)
                cimports.update(a)
                externs.update(b)
            else:
                print("Unable to locate '%s' referenced from '%s'" % (include, filename))
        return tuple(cimports), tuple(externs)
    cimports_and_externs = cached_method(cimports_and_externs)

    def cimports(self, filename):
        """Return the module names cimported (directly or via includes)."""
        return self.cimports_and_externs(filename)[0]

    #@cached_method
    def package(self, filename):
        """Return the package of filename as a tuple of directory names.

        Walks up the directories of filename while each contains an
        __init__.py, e.g. ('pkg', 'sub') for pkg/sub/mod.pyx.
        """
        # Work on absolute paths: for a relative path dirname() eventually
        # yields '', and dirname('') == '', so the recursion never ended when
        # the current directory contained an __init__.py.  The dir != filename
        # test stops at the filesystem root, where dirname(root) == root.
        dir = os.path.dirname(os.path.abspath(filename))
        if dir != filename and os.path.exists(os.path.join(dir, '__init__.py')):
            return self.package(dir) + (os.path.basename(dir),)
        else:
            return ()
    package = cached_method(package)

    #@cached_method
    def fully_qualifeid_name(self, filename):
        """Return the dotted module name of filename, including its package."""
        module = os.path.splitext(os.path.basename(filename))[0]
        return '.'.join(self.package(filename) + (module,))
    fully_qualifeid_name = cached_method(fully_qualifeid_name)
    # Correctly spelled alias; the original name is kept for existing callers.
    fully_qualified_name = fully_qualifeid_name

    def find_pxd(self, module, filename=None):
        """Locate the .pxd file of module.

        When filename is given, module is first looked up relative to
        filename's package, then as an absolute module name.  Explicit
        relative imports (leading '.') raise NotImplementedError.
        """
        if module[0] == '.':
            raise NotImplementedError("New relative imports.")
        if filename is not None:
            relative = '.'.join(self.package(filename) + tuple(module.split('.')))
            pxd = self.context.find_pxd_file(relative, None)
            if pxd:
                return pxd
        return self.context.find_pxd_file(module, None)
    find_pxd = cached_method(find_pxd)

    #@cached_method
    def cimported_files(self, filename):
        """Return the .pxd files filename depends on: the companion .pxd of a
        .pyx (if present) followed by the .pxd of every cimported module that
        could be found."""
        if filename[-4:] == '.pyx' and os.path.exists(filename[:-4] + '.pxd'):
            self_pxd = [filename[:-4] + '.pxd']
        else:
            self_pxd = []
        cimports = self.cimports(filename)
        pxds = []
        for module in cimports:
            pxd = self.find_pxd(module, filename)
            if pxd:
                pxds.append(pxd)
        # 'cython' is cimportable without a .pxd file, so it is not missing.
        if len(cimports) - int('cython' in cimports) != len(pxds):
            print("missing cimport %s" % filename)
            print("\n\t".join(cimports))
            print("\n\t".join(pxds))
        return tuple(self_pxd + pxds)
    cimported_files = cached_method(cimported_files)

    def immediate_dependencies(self, filename):
        """Return filename's direct dependencies: its cimported .pxd files
        plus its extern header names resolved against its directory."""
        deps = list(self.cimported_files(filename))
        # Only the externs (index 1) are file names; index 0 holds cimported
        # *module* names, which must not be turned into paths.
        for extern in self.cimports_and_externs(filename)[1]:
            deps.append(os.path.normpath(os.path.join(os.path.dirname(filename), extern)))
        return tuple(deps)

    #@cached_method
    def timestamp(self, filename):
        """Return the (cached) modification time of filename."""
        return os.path.getmtime(filename)
    timestamp = cached_method(timestamp)

    def extract_timestamp(self, filename):
        """Return (mtime, filename), so max() also reports which file it was."""
        # TODO: .h files from extern blocks
        return self.timestamp(filename), filename

    def newest_dependency(self, filename):
        """Return (mtime, path) of the most recently modified file among
        filename and everything reachable through cimported_files()."""
        return self.transitive_merge(filename, self.extract_timestamp, max)

    def distutils_info0(self, filename):
        """Return the DistutilsInfo parsed from filename's own header."""
        return self.parse_dependencies(filename)[3]

    def distutils_info(self, filename, aliases=None, base=None):
        """Return filename's DistutilsInfo merged with that of every file it
        transitively cimports, with ``aliases`` substituted and ``base``
        merged in.  The returned object is fresh and may be modified."""
        merged = self.transitive_merge(filename, self.distutils_info0, DistutilsInfo.merge)
        # transitive_merge returns cached state and DistutilsInfo.merge
        # updates its receiver in place, so copy (lists included) before
        # merging the per-call ``base``; otherwise it would leak into later
        # queries for this file.
        resolved = DistutilsInfo()
        for key, value in merged.values.items():
            if isinstance(value, list):
                value = list(value)
            resolved.values[key] = value
        return resolved.subs(aliases).merge(base)

    def transitive_merge(self, node, extract, merge):
        """Fold extract(n) with merge() over node and every file reachable
        from it through cimported_files().  Results are cached per
        (extract, merge) pair for the lifetime of the tree."""
        try:
            seen = self._transitive_cache[extract, merge]
        except KeyError:
            seen = self._transitive_cache[extract, merge] = {}
        return self.transitive_merge_helper(
            node, extract, merge, seen, {}, self.cimported_files)[0]

    def transitive_merge_helper(self, node, extract, merge, seen, stack, outgoing):
        """Depth-first merge over the graph given by ``outgoing``.

        ``stack`` maps the nodes on the current DFS path to their depth.
        Returns ``(value, loop)`` where ``loop`` is the shallowest node of
        the current path that this subtree reaches back to (a cycle), or
        None.  A node's value is stored in ``seen`` only when no such
        cycle is still open below it, since values computed inside an open
        cycle may be incomplete.
        """
        if node in seen:
            return seen[node], None
        deps = extract(node)
        if node in stack:
            # Back edge: report the cycle to the caller.
            return deps, node
        try:
            stack[node] = len(stack)
            loop = None
            for next in outgoing(node):
                sub_deps, sub_loop = self.transitive_merge_helper(next, extract, merge, seen, stack, outgoing)
                if sub_loop is not None:
                    # Keep the loop whose head is shallowest on the stack.
                    if loop is not None and stack[loop] < stack[sub_loop]:
                        pass
                    else:
                        loop = sub_loop
                deps = merge(deps, sub_deps)
            if loop == node:
                # This node closes the cycle it started.
                loop = None
            if loop is None:
                seen[node] = deps
            return deps, loop
        finally:
            del stack[node]
378
_dep_tree = None
def create_dependency_tree(ctx=None):
    """Return the shared DependencyTree, creating it when needed.

    If ``ctx`` is given and differs from the context of the shared tree, a
    new tree is built for it: the tree's caches (pxd/include lookups and file
    timestamps) belong to one compilation context, and reusing them would
    ignore the new include path and keep stale timestamps.  With ``ctx=None``
    an existing tree is reused, or a default tree rooted at "." is created.
    """
    global _dep_tree
    if _dep_tree is None or (ctx is not None and _dep_tree.context is not ctx):
        if ctx is None:
            ctx = Context(["."], CompilationOptions(default_options))
        _dep_tree = DependencyTree(ctx)
    return _dep_tree
387
388 # This may be useful for advanced users?
def create_extension_list(patterns, exclude=None, ctx=None, aliases=None):
    """Expand ``patterns`` into a list of distutils Extension objects.

    ``patterns`` is a glob pattern string, an Extension, or a list of these.
    For an Extension its first source is used as the glob pattern and its
    settings serve as a template; an Extension whose first source is not a
    .py/.pyx file is passed through unchanged.  ``exclude`` is a glob
    pattern or list of patterns whose matches are skipped.  Each module name
    is emitted at most once.  Settings come from each file's
    ``# distutils:`` comments and its cimported files, with ``aliases``
    substituted and the template's settings as defaults.
    """
    if exclude is None:
        exclude = []
    seen = set()
    deps = create_dependency_tree(ctx)
    to_exclude = set()
    if not isinstance(exclude, list):
        exclude = [exclude]
    for pattern in exclude:
        to_exclude.update(glob(pattern))
    if not isinstance(patterns, list):
        patterns = [patterns]
    module_list = []
    for pattern in patterns:
        if isinstance(pattern, str):
            filepattern = pattern
            name = '*'
            base = None
            exn_type = Extension
        elif isinstance(pattern, Extension):
            filepattern = pattern.sources[0]
            if os.path.splitext(filepattern)[1] not in ('.py', '.pyx'):
                # ignore non-cython modules
                module_list.append(pattern)
                continue
            template = pattern
            name = template.name
            base = DistutilsInfo(exn=template)
            exn_type = template.__class__
        else:
            raise TypeError(pattern)
        for file in glob(filepattern):
            if file in to_exclude:
                continue
            if '*' in name:
                module_name = deps.fully_qualifeid_name(file)
            else:
                module_name = name
            if module_name in seen:
                continue
            # Copy: the returned dict may be shared with cached state.
            kwds = dict(deps.distutils_info(file, aliases, base).values)
            if base is not None:
                # Template settings not set by the directives, including
                # non-transitive ones that DistutilsInfo.merge skips.
                for key, value in base.values.items():
                    if key not in kwds:
                        kwds[key] = value
            module_list.append(exn_type(
                    name=module_name,
                    sources=[file],
                    **kwds))
            # Record the module actually emitted (not the pattern's name,
            # which is '*' for glob patterns) so duplicates are skipped.
            seen.add(module_name)
    return module_list
440
441 # This is the user-exposed entry point.
def cythonize(module_list, exclude=None, nthreads=0, aliases=None, quiet=False, **options):
    """Cythonize the .pyx/.py sources of ``module_list`` that are out of date.

    ``module_list``, ``exclude`` and ``aliases`` are passed to
    create_extension_list.  A source is compiled when its .c file (.cpp for
    language 'c++') is missing or older than the source or any file it
    transitively depends on.  Every module's .pyx/.py sources are replaced by
    the generated C/C++ file names.  With ``nthreads`` > 0 the compilations
    run in a multiprocessing pool (serially if multiprocessing is
    unavailable).  Unless ``quiet``, the reason for each compilation is
    printed.  Remaining keyword ``options`` go to CompilationOptions.
    Returns the updated module list.
    """
    if exclude is None:
        exclude = []
    if 'include_path' not in options:
        options['include_path'] = ['.']
    c_options = CompilationOptions(**options)
    cpp_options = CompilationOptions(**options)
    cpp_options.cplus = True
    ctx = c_options.create_context()
    module_list = create_extension_list(
        module_list,
        exclude=exclude,
        ctx=ctx,
        aliases=aliases)
    deps = create_dependency_tree(ctx)
    to_compile = []
    for m in module_list:
        new_sources = []
        for source in m.sources:
            base, ext = os.path.splitext(source)
            if ext not in ('.pyx', '.py'):
                new_sources.append(source)
                continue
            # Named compile_options so it does not shadow the **options dict.
            if m.language == 'c++':
                c_file = base + '.cpp'
                compile_options = cpp_options
            else:
                c_file = base + '.c'
                compile_options = c_options
            if os.path.exists(c_file):
                c_timestamp = os.path.getmtime(c_file)
            else:
                c_timestamp = -1
            # Priority goes first to modified files, second to direct
            # dependents, and finally to indirect dependents.
            if c_timestamp < deps.timestamp(source):
                dep_timestamp, dep = deps.timestamp(source), source
                priority = 0
            else:
                dep_timestamp, dep = deps.newest_dependency(source)
                priority = 2 - (dep in deps.immediate_dependencies(source))
            if c_timestamp < dep_timestamp:
                if not quiet:
                    if source == dep:
                        print("Compiling %s because it changed." % source)
                    else:
                        print("Compiling %s because it depends on %s." % (source, dep))
                to_compile.append((priority, source, c_file, compile_options))
            new_sources.append(c_file)
        m.sources = new_sources
    # Tuples sort by priority first, so changed files are compiled first.
    to_compile.sort()
    if nthreads:
        # Requires multiprocessing (or Python >= 2.6)
        try:
            import multiprocessing
            # Pool creation can also raise ImportError on platforms without
            # working semaphores.
            pool = multiprocessing.Pool(nthreads)
        except ImportError:
            print("multiprocessing required for parallel cythonization")
            nthreads = 0
        else:
            try:
                pool.map(cythonize_one_helper, to_compile)
            finally:
                # Release the worker processes even if a compilation failed.
                pool.close()
                pool.join()
    if not nthreads:
        for priority, pyx_file, c_file, compile_options in to_compile:
            cythonize_one(pyx_file, c_file, compile_options)
    return module_list
503
504 # TODO: Share context? Issue: pyx processing leaks into pxd module
def cythonize_one(pyx_file, c_file, options=None):
    """Compile the Cython source ``pyx_file`` into ``c_file``.

    ``options`` defaults to CompilationOptions(default_options); its
    ``output_file`` attribute is set to ``c_file`` (the object is modified).
    Environment and compiler errors are written to stderr, and any failure
    is reported by raising CompileError.
    """
    from Cython.Compiler.Main import compile, default_options
    from Cython.Compiler.Errors import CompileError, PyrexError

    if options is None:
        options = CompilationOptions(default_options)
    options.output_file = c_file

    any_failures = False
    try:
        result = compile([pyx_file], options)
        if result.num_errors > 0:
            any_failures = True
    except (EnvironmentError, PyrexError):
        # sys.exc_info() instead of "except X, e" (Python 2 only) or
        # "except X as e" (Python 2.6+) keeps this valid on every version.
        e = sys.exc_info()[1]
        sys.stderr.write(str(e) + '\n')
        any_failures = True
    if any_failures:
        raise CompileError(None, pyx_file)
523
def cythonize_one_helper(m):
    """Run cythonize_one for one work item built by cythonize().

    ``m`` is a ``(priority, pyx_file, c_file, options)`` tuple; this is the
    function mapped over the work items by the multiprocessing pool.
    """
    args = m[1:]  # drop the sort priority
    return cythonize_one(*args)