Merge MarkClosureVisitor and MarkGeneratorVisitor
[cython.git] / Cython / Compiler / Main.py
1 #
2 #   Cython Top Level
3 #
4
5 import os, sys, re
6 if sys.version_info[:2] < (2, 3):
7     sys.stderr.write("Sorry, Cython requires Python 2.3 or later\n")
8     sys.exit(1)
9
10 try:
11     set
12 except NameError:
13     # Python 2.3
14     from sets import Set as set
15
16 import itertools
17 from time import time
18
19 import Code
20 import Errors
21 import Parsing
22 import Version
23 from Scanning import PyrexScanner, FileSourceDescriptor
24 from Errors import PyrexError, CompileError, InternalError, AbortError, error, warning
25 from Symtab import BuiltinScope, ModuleScope
26 from Cython import Utils
27 from Cython.Utils import open_new_file, replace_suffix
28 import CythonScope
29 import DebugFlags
30
# Matches a module name made of identifiers separated by single dots,
# e.g. "pkg.sub.mod".  It is used with .match(), which anchors at the start,
# and the trailing '$' anchors the end, so the whole string must match.
module_name_pattern = re.compile(r"[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*$")

# Module-level verbosity flag.  NOTE(review): nothing in the code visible
# here reads it (compile_multiple() uses its own local 'verbose'); other
# modules may rely on it -- verify before removing.
verbose = 0
34
def dumptree(t):
    """Debugging pipeline phase: print t.dump() and hand *t* on unchanged.

    Drop this into a pipeline list to inspect the tree between phases.
    """
    dumped = t.dump()
    print(dumped)
    return t
39
def abort_on_errors(node):
    """Pipeline phase that stops compilation once any error was reported.

    Raises AbortError when Errors.num_errors is non-zero; otherwise
    returns *node* untouched so the next phase receives it.
    """
    if Errors.num_errors:
        raise AbortError("pipeline break")
    return node
45
class CompilationData(object):
    """
    Bundles the information that is passed from transform to transform.

    While Context contains every pxd ever loaded, path information etc.,
    this only holds the data related to a single compilation pass.
    The class defines no attributes itself; the intended ones are:

    pyx          ModuleNode              Main code tree of this compilation.
    pxds         {string : ModuleNode}   Trees for the pxds used in the pyx.
    codewriter   CCodeWriter             Where to output final code.
    options      CompilationOptions
    result       CompilationResult
    """
59
60 class Context(object):
61     #  This class encapsulates the context needed for compiling
62     #  one or more Cython implementation files along with their
63     #  associated and imported declaration files. It includes
64     #  the root of the module import namespace and the list
65     #  of directories to search for include files.
66     #
67     #  modules               {string : ModuleScope}
68     #  include_directories   [string]
69     #  future_directives     [object]
70     #  language_level        int     currently 2 or 3 for Python 2/3
71     
72     def __init__(self, include_directories, compiler_directives, cpp=False, language_level=2):
73         #self.modules = {"__builtin__" : BuiltinScope()}
74         import Builtin, CythonScope
75         self.modules = {"__builtin__" : Builtin.builtin_scope}
76         self.modules["cython"] = CythonScope.create_cython_scope(self)
77         self.include_directories = include_directories
78         self.future_directives = set()
79         self.compiler_directives = compiler_directives
80         self.cpp = cpp
81
82         self.pxds = {} # full name -> node tree
83
84         standard_include_path = os.path.abspath(os.path.normpath(
85             os.path.join(os.path.dirname(__file__), os.path.pardir, 'Includes')))
86         self.include_directories = include_directories + [standard_include_path]
87
88         self.set_language_level(language_level)
89         
90         self.gdb_debug_outputwriter = None
91
92     def set_language_level(self, level):
93         self.language_level = level
94         if level >= 3:
95             from Future import print_function, unicode_literals
96             self.future_directives.add(print_function)
97             self.future_directives.add(unicode_literals)
98
99     def create_pipeline(self, pxd, py=False):
100         from Visitor import PrintTree
101         from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse
102         from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
103         from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
104         from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
105         from ParseTreeTransforms import ExpandInplaceOperators
106         from TypeInference import MarkAssignments, MarkOverflowingArithmetic
107         from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck
108         from AnalysedTreeTransforms import AutoTestDictTransform
109         from AutoDocTransforms import EmbedSignature
110         from Optimize import FlattenInListTransform, SwitchTransform, IterationTransform
111         from Optimize import EarlyReplaceBuiltinCalls, OptimizeBuiltinCalls
112         from Optimize import ConstantFolding, FinalOptimizePhase
113         from Optimize import DropRefcountingTransform
114         from Buffer import IntroduceBufferAuxiliaryVars
115         from ModuleNode import check_c_declarations, check_c_declarations_pxd
116
117         if pxd:
118             _check_c_declarations = check_c_declarations_pxd
119             _specific_post_parse = PxdPostParse(self)
120         else:
121             _check_c_declarations = check_c_declarations
122             _specific_post_parse = None
123             
124         if py and not pxd:
125             _align_function_definitions = AlignFunctionDefinitions(self)
126         else:
127             _align_function_definitions = None
128  
129         return [
130             NormalizeTree(self),
131             PostParse(self),
132             _specific_post_parse,
133             InterpretCompilerDirectives(self, self.compiler_directives),
134             _align_function_definitions,
135             MarkClosureVisitor(self),
136             ConstantFolding(),
137             FlattenInListTransform(),
138             WithTransform(self),
139             DecoratorTransform(self),
140             AnalyseDeclarationsTransform(self),
141             AutoTestDictTransform(self),
142             EmbedSignature(self),
143             EarlyReplaceBuiltinCalls(self),  ## Necessary?
144             MarkAssignments(self),
145             MarkOverflowingArithmetic(self),
146             TransformBuiltinMethods(self),  ## Necessary?
147             IntroduceBufferAuxiliaryVars(self),
148             _check_c_declarations,
149             AnalyseExpressionsTransform(self),
150             CreateClosureClasses(self),  ## After all lookups and type inference
151             ExpandInplaceOperators(self),
152             OptimizeBuiltinCalls(self),  ## Necessary?
153             IterationTransform(),
154             SwitchTransform(),
155             DropRefcountingTransform(),
156             FinalOptimizePhase(self),
157             GilCheck(),
158             ]
159
160     def create_pyx_pipeline(self, options, result, py=False):
161         def generate_pyx_code(module_node):
162             module_node.process_implementation(options, result)
163             result.compilation_source = module_node.compilation_source
164             return result
165
166         def inject_pxd_code(module_node):
167             from textwrap import dedent
168             stats = module_node.body.stats
169             for name, (statlistnode, scope) in self.pxds.iteritems():
170                 # Copy over function nodes to the module
171                 # (this seems strange -- I believe the right concept is to split
172                 # ModuleNode into a ModuleNode and a CodeGenerator, and tell that
173                 # CodeGenerator to generate code both from the pyx and pxd ModuleNodes.
174                  stats.append(statlistnode)
175                  # Until utility code is moved to code generation phase everywhere,
176                  # we need to copy it over to the main scope
177                  module_node.scope.utility_code_list.extend(scope.utility_code_list)
178             return module_node
179
180         test_support = []
181         if options.evaluate_tree_assertions:
182             from Cython.TestUtils import TreeAssertVisitor
183             test_support.append(TreeAssertVisitor())
184
185         if options.gdb_debug:
186             from Cython.Debugger import DebugWriter
187             from ParseTreeTransforms import DebugTransform
188             self.gdb_debug_outputwriter = DebugWriter.CythonDebugWriter(
189                 options.output_dir)
190             debug_transform = [DebugTransform(self, options, result)]
191         else:
192             debug_transform = []
193             
194         return list(itertools.chain(
195             [create_parse(self)],
196             self.create_pipeline(pxd=False, py=py),
197             test_support,
198             [inject_pxd_code, abort_on_errors],
199             debug_transform,
200             [generate_pyx_code]))
201
202     def create_pxd_pipeline(self, scope, module_name):
203         def parse_pxd(source_desc):
204             tree = self.parse(source_desc, scope, pxd=True,
205                               full_module_name=module_name)
206             tree.scope = scope
207             tree.is_pxd = True
208             return tree
209
210         from CodeGeneration import ExtractPxdCode
211
212         # The pxd pipeline ends up with a CCodeWriter containing the
213         # code of the pxd, as well as a pxd scope.
214         return [parse_pxd] + self.create_pipeline(pxd=True) + [
215             ExtractPxdCode(self),
216             ]
217             
218     def create_py_pipeline(self, options, result):
219         return self.create_pyx_pipeline(options, result, py=True)
220
221
222     def process_pxd(self, source_desc, scope, module_name):
223         pipeline = self.create_pxd_pipeline(scope, module_name)
224         result = self.run_pipeline(pipeline, source_desc)
225         return result
226     
227     def nonfatal_error(self, exc):
228         return Errors.report_error(exc)
229
230     def run_pipeline(self, pipeline, source):
231         error = None
232         data = source
233         try:
234             try:
235                 for phase in pipeline:
236                     if phase is not None:
237                         if DebugFlags.debug_verbose_pipeline:
238                             t = time()
239                             print "Entering pipeline phase %r" % phase
240                         data = phase(data)
241                         if DebugFlags.debug_verbose_pipeline:
242                             print "    %.3f seconds" % (time() - t)
243             except CompileError, err:
244                 # err is set
245                 Errors.report_error(err)
246                 error = err
247         except InternalError, err:
248             # Only raise if there was not an earlier error
249             if Errors.num_errors == 0:
250                 raise
251             error = err
252         except AbortError, err:
253             error = err
254         return (error, data)
255
256     def find_module(self, module_name, 
257             relative_to = None, pos = None, need_pxd = 1):
258         # Finds and returns the module scope corresponding to
259         # the given relative or absolute module name. If this
260         # is the first time the module has been requested, finds
261         # the corresponding .pxd file and process it.
262         # If relative_to is not None, it must be a module scope,
263         # and the module will first be searched for relative to
264         # that module, provided its name is not a dotted name.
265         debug_find_module = 0
266         if debug_find_module:
267             print("Context.find_module: module_name = %s, relative_to = %s, pos = %s, need_pxd = %s" % (
268                     module_name, relative_to, pos, need_pxd))
269
270         scope = None
271         pxd_pathname = None
272         if not module_name_pattern.match(module_name):
273             if pos is None:
274                 pos = (module_name, 0, 0)
275             raise CompileError(pos,
276                 "'%s' is not a valid module name" % module_name)
277         if "." not in module_name and relative_to:
278             if debug_find_module:
279                 print("...trying relative import")
280             scope = relative_to.lookup_submodule(module_name)
281             if not scope:
282                 qualified_name = relative_to.qualify_name(module_name)
283                 pxd_pathname = self.find_pxd_file(qualified_name, pos)
284                 if pxd_pathname:
285                     scope = relative_to.find_submodule(module_name)
286         if not scope:
287             if debug_find_module:
288                 print("...trying absolute import")
289             scope = self
290             for name in module_name.split("."):
291                 scope = scope.find_submodule(name)
292         if debug_find_module:
293             print("...scope =", scope)
294         if not scope.pxd_file_loaded:
295             if debug_find_module:
296                 print("...pxd not loaded")
297             scope.pxd_file_loaded = 1
298             if not pxd_pathname:
299                 if debug_find_module:
300                     print("...looking for pxd file")
301                 pxd_pathname = self.find_pxd_file(module_name, pos)
302                 if debug_find_module:
303                     print("......found ", pxd_pathname)
304                 if not pxd_pathname and need_pxd:
305                     package_pathname = self.search_include_directories(module_name, ".py", pos)
306                     if package_pathname and package_pathname.endswith('__init__.py'):
307                         pass
308                     else:
309                         error(pos, "'%s.pxd' not found" % module_name)
310             if pxd_pathname:
311                 try:
312                     if debug_find_module:
313                         print("Context.find_module: Parsing %s" % pxd_pathname)
314                     source_desc = FileSourceDescriptor(pxd_pathname)
315                     err, result = self.process_pxd(source_desc, scope, module_name)
316                     if err:
317                         raise err
318                     (pxd_codenodes, pxd_scope) = result
319                     self.pxds[module_name] = (pxd_codenodes, pxd_scope)
320                 except CompileError:
321                     pass
322         return scope
323     
324     def find_pxd_file(self, qualified_name, pos):
325         # Search include path for the .pxd file corresponding to the
326         # given fully-qualified module name.
327         # Will find either a dotted filename or a file in a
328         # package directory. If a source file position is given,
329         # the directory containing the source file is searched first
330         # for a dotted filename, and its containing package root
331         # directory is searched first for a non-dotted filename.
332         pxd = self.search_include_directories(qualified_name, ".pxd", pos)
333         if pxd is None: # XXX Keep this until Includes/Deprecated is removed
334             if (qualified_name.startswith('python') or
335                 qualified_name in ('stdlib', 'stdio', 'stl')):
336                 standard_include_path = os.path.abspath(os.path.normpath(
337                         os.path.join(os.path.dirname(__file__), os.path.pardir, 'Includes')))
338                 deprecated_include_path = os.path.join(standard_include_path, 'Deprecated')
339                 self.include_directories.append(deprecated_include_path)
340                 try:
341                     pxd = self.search_include_directories(qualified_name, ".pxd", pos)
342                 finally:
343                     self.include_directories.pop()
344                 if pxd:
345                     name = qualified_name
346                     if name.startswith('python'):
347                         warning(pos, "'%s' is deprecated, use 'cpython'" % name, 1)
348                     elif name in ('stdlib', 'stdio'):
349                         warning(pos, "'%s' is deprecated, use 'libc.%s'" % (name, name), 1)
350                     elif name in ('stl'):
351                         warning(pos, "'%s' is deprecated, use 'libcpp.*.*'" % name, 1)
352         return pxd
353
354     def find_pyx_file(self, qualified_name, pos):
355         # Search include path for the .pyx file corresponding to the
356         # given fully-qualified module name, as for find_pxd_file().
357         return self.search_include_directories(qualified_name, ".pyx", pos)
358     
359     def find_include_file(self, filename, pos):
360         # Search list of include directories for filename.
361         # Reports an error and returns None if not found.
362         path = self.search_include_directories(filename, "", pos,
363                                                include=True)
364         if not path:
365             error(pos, "'%s' not found" % filename)
366         return path
367     
368     def search_include_directories(self, qualified_name, suffix, pos,
369                                    include=False):
370         # Search the list of include directories for the given
371         # file name. If a source file position is given, first
372         # searches the directory containing that file. Returns
373         # None if not found, but does not report an error.
374         # The 'include' option will disable package dereferencing.
375         dirs = self.include_directories
376         if pos:
377             file_desc = pos[0]
378             if not isinstance(file_desc, FileSourceDescriptor):
379                 raise RuntimeError("Only file sources for code supported")
380             if include:
381                 dirs = [os.path.dirname(file_desc.filename)] + dirs
382             else:
383                 dirs = [self.find_root_package_dir(file_desc.filename)] + dirs
384
385         dotted_filename = qualified_name
386         if suffix:
387             dotted_filename += suffix
388         if not include:
389             names = qualified_name.split('.')
390             package_names = names[:-1]
391             module_name = names[-1]
392             module_filename = module_name + suffix
393             package_filename = "__init__" + suffix
394
395         for dir in dirs:
396             path = os.path.join(dir, dotted_filename)
397             if Utils.path_exists(path):
398                 return path
399             if not include:
400                 package_dir = self.check_package_dir(dir, package_names)
401                 if package_dir is not None:
402                     path = os.path.join(package_dir, module_filename)
403                     if Utils.path_exists(path):
404                         return path
405                     path = os.path.join(dir, package_dir, module_name,
406                                         package_filename)
407                     if Utils.path_exists(path):
408                         return path
409         return None
410
411     def find_root_package_dir(self, file_path):
412         dir = os.path.dirname(file_path)
413         while self.is_package_dir(dir):
414             parent = os.path.dirname(dir)
415             if parent == dir:
416                 break
417             dir = parent
418         return dir
419
420     def check_package_dir(self, dir, package_names):
421         for dirname in package_names:
422             dir = os.path.join(dir, dirname)
423             if not self.is_package_dir(dir):
424                 return None
425         return dir
426
427     def c_file_out_of_date(self, source_path):
428         c_path = Utils.replace_suffix(source_path, ".c")
429         if not os.path.exists(c_path):
430             return 1
431         c_time = Utils.modification_time(c_path)
432         if Utils.file_newer_than(source_path, c_time):
433             return 1
434         pos = [source_path]
435         pxd_path = Utils.replace_suffix(source_path, ".pxd")
436         if os.path.exists(pxd_path) and Utils.file_newer_than(pxd_path, c_time):
437             return 1
438         for kind, name in self.read_dependency_file(source_path):
439             if kind == "cimport":
440                 dep_path = self.find_pxd_file(name, pos)
441             elif kind == "include":
442                 dep_path = self.search_include_directories(name, pos)
443             else:
444                 continue
445             if dep_path and Utils.file_newer_than(dep_path, c_time):
446                 return 1
447         return 0
448     
449     def find_cimported_module_names(self, source_path):
450         return [ name for kind, name in self.read_dependency_file(source_path)
451                  if kind == "cimport" ]
452
453     def is_package_dir(self, dir_path):
454         #  Return true if the given directory is a package directory.
455         for filename in ("__init__.py", 
456                          "__init__.pyx", 
457                          "__init__.pxd"):
458             path = os.path.join(dir_path, filename)
459             if Utils.path_exists(path):
460                 return 1
461
462     def read_dependency_file(self, source_path):
463         dep_path = Utils.replace_suffix(source_path, ".dep")
464         if os.path.exists(dep_path):
465             f = open(dep_path, "rU")
466             chunks = [ line.strip().split(" ", 1)
467                        for line in f.readlines()
468                        if " " in line.strip() ]
469             f.close()
470             return chunks
471         else:
472             return ()
473
474     def lookup_submodule(self, name):
475         # Look up a top-level module. Returns None if not found.
476         return self.modules.get(name, None)
477
478     def find_submodule(self, name):
479         # Find a top-level module, creating a new one if needed.
480         scope = self.lookup_submodule(name)
481         if not scope:
482             scope = ModuleScope(name, 
483                 parent_module = None, context = self)
484             self.modules[name] = scope
485         return scope
486
487     def parse(self, source_desc, scope, pxd, full_module_name):
488         if not isinstance(source_desc, FileSourceDescriptor):
489             raise RuntimeError("Only file sources for code supported")
490         source_filename = source_desc.filename
491         scope.cpp = self.cpp
492         # Parse the given source file and return a parse tree.
493         try:
494             f = Utils.open_source_file(source_filename, "rU")
495             try:
496                 s = PyrexScanner(f, source_desc, source_encoding = f.encoding,
497                                  scope = scope, context = self)
498                 tree = Parsing.p_module(s, pxd, full_module_name)
499             finally:
500                 f.close()
501         except UnicodeDecodeError, msg:
502             #import traceback
503             #traceback.print_exc()
504             error((source_desc, 0, 0), "Decoding error, missing or incorrect coding=<encoding-name> at top of source (%s)" % msg)
505         if Errors.num_errors > 0:
506             raise CompileError
507         return tree
508
509     def extract_module_name(self, path, options):
510         # Find fully_qualified module name from the full pathname
511         # of a source file.
512         dir, filename = os.path.split(path)
513         module_name, _ = os.path.splitext(filename)
514         if "." in module_name:
515             return module_name
516         if module_name == "__init__":
517             dir, module_name = os.path.split(dir)
518         names = [module_name]
519         while self.is_package_dir(dir):
520             parent, package_name = os.path.split(dir)
521             if parent == dir:
522                 break
523             names.append(package_name)
524             dir = parent
525         names.reverse()
526         return ".".join(names)
527
528     def setup_errors(self, options, result):
529         Errors.reset() # clear any remaining error state
530         if options.use_listing_file:
531             result.listing_file = Utils.replace_suffix(source, ".lis")
532             path = result.listing_file
533         else:
534             path = None
535         Errors.open_listing_file(path=path,
536                                  echo_to_stderr=options.errors_to_stderr)
537
538     def teardown_errors(self, err, options, result):
539         source_desc = result.compilation_source.source_desc
540         if not isinstance(source_desc, FileSourceDescriptor):
541             raise RuntimeError("Only file sources for code supported")
542         Errors.close_listing_file()
543         result.num_errors = Errors.num_errors
544         if result.num_errors > 0:
545             err = True
546         if err and result.c_file:
547             try:
548                 Utils.castrate_file(result.c_file, os.stat(source_desc.filename))
549             except EnvironmentError:
550                 pass
551             result.c_file = None
552
def create_parse(context):
    """Return the first pipeline phase: turn a CompilationSource into a tree.

    The returned callable looks up (or creates) the module scope, parses
    the source with it, and tags the resulting tree with the compilation
    source, the scope, and is_pxd = False.
    """
    def parse(compsrc):
        desc = compsrc.source_desc
        module_name = compsrc.full_module_name
        scope = context.find_module(module_name, pos = (desc, 1, 0), need_pxd = 0)
        tree = context.parse(desc, scope, pxd = 0, full_module_name = module_name)
        tree.compilation_source = compsrc
        tree.scope = scope
        tree.is_pxd = False
        return tree
    return parse
565
def create_default_resultobj(compilation_source, options):
    """Create a CompilationResult pre-filled for *compilation_source*.

    The C output file is options.output_file (relative to the compilation's
    cwd) when given, otherwise the source file name with its suffix
    replaced by ".cpp" (options.cplus) or ".c".
    """
    result = CompilationResult()
    source_desc = compilation_source.source_desc
    result.main_source_file = source_desc.filename
    result.compilation_source = compilation_source
    if not options.output_file:
        if options.cplus:
            suffix = ".cpp"
        else:
            suffix = ".c"
        c_file = Utils.replace_suffix(source_desc.filename, suffix)
    else:
        c_file = os.path.join(compilation_source.cwd, options.output_file)
    result.c_file = c_file
    return result
580
def run_pipeline(source, options, full_module_name = None):
    """Compile the file *source* with a fresh Context and return its result.

    .py files go through the py pipeline, everything else through the pyx
    pipeline.  Error state is set up before and torn down after the run.
    """
    context = options.create_context()

    working_dir = os.getcwd()
    descriptor = FileSourceDescriptor(os.path.join(working_dir, source))
    module_name = full_module_name or context.extract_module_name(source, options)
    compilation_source = CompilationSource(descriptor, module_name, working_dir)

    result = create_default_resultobj(compilation_source, options)

    if descriptor.filename.endswith(".py"):
        build_pipeline = context.create_py_pipeline
    else:
        build_pipeline = context.create_pyx_pipeline
    pipeline = build_pipeline(options, result)

    context.setup_errors(options, result)
    err, _ = context.run_pipeline(pipeline, compilation_source)
    context.teardown_errors(err, options, result)
    return result
604     
605
606 #------------------------------------------------------------------------
607 #
608 #  Main Python entry points
609 #
610 #------------------------------------------------------------------------
611
class CompilationSource(object):
    """
    Contains the data necessary to start up a compilation pipeline for
    a single compilation unit: the source descriptor, the fully qualified
    module name and the working directory the compilation started from.
    """
    def __init__(self, source_desc, full_module_name, cwd):
        self.cwd = cwd
        self.full_module_name = full_module_name
        self.source_desc = source_desc
621
class CompilationOptions(object):
    """
    Options to the Cython compiler:

    show_version      boolean   Display version number
    use_listing_file  boolean   Generate a .lis file
    errors_to_stderr  boolean   Echo errors to stderr when using .lis
    include_path      [string]  Directories to search for include files
    output_file       string    Name of generated .c file
    generate_pxi      boolean   Generate .pxi file for public declarations
    recursive         boolean   Recursively find and compile dependencies
    timestamps        boolean   Only compile changed source files. If None,
                                defaults to true when recursive is true.
    verbose           boolean   Always print source names being compiled
    quiet             boolean   Don't print source names in recursive mode
    compiler_directives  dict   Overrides for pragma options (see Options.py)
    evaluate_tree_assertions boolean  Test support: evaluate parse tree assertions
    language_level    integer   The Python language level: 2 or 3

    cplus             boolean   Compile as c++ code

    'defaults' may be another CompilationOptions (its attributes are
    copied), a dict, or empty/None to use the module's default_options.
    Keyword arguments override the defaults.
    """

    def __init__(self, defaults = None, **kw):
        self.include_path = []
        if not defaults:
            defaults = default_options
        elif isinstance(defaults, CompilationOptions):
            defaults = defaults.__dict__
        self.__dict__.update(defaults)
        self.__dict__.update(kw)

    def create_context(self):
        # Build a Context configured from these options.
        return Context(self.include_path, self.compiler_directives,
                       self.cplus, self.language_level)
657
658
class CompilationResult(object):
    """
    Results from the Cython compiler:

    c_file           string or None   The generated C source file
    h_file           string or None   The generated C header file
    i_file           string or None   The generated .pxi file
    api_file         string or None   The generated C API .h file
    listing_file     string or None   File of error messages
    object_file      string or None   Result of compiling the C file
    extension_file   string or None   Result of linking the object file
    main_source_file string or None   The source file being compiled
    num_errors       integer          Number of compilation errors
                                      (not set by __init__)
    compilation_source CompilationSource
    """

    def __init__(self):
        # Every output slot starts out empty.
        for attr_name in ("c_file", "h_file", "i_file", "api_file",
                          "listing_file", "object_file", "extension_file",
                          "main_source_file"):
            setattr(self, attr_name, None)
683
684
class CompilationResultSet(dict):
    """
    Results from compiling multiple Pyrex source files: a mapping from
    source file paths to CompilationResult instances.  Additionally:

    num_errors   integer   Total number of compilation errors
    """

    num_errors = 0

    def add(self, source, result):
        # Store the result, then fold its error count into the running total.
        self.update({source: result})
        self.num_errors = self.num_errors + result.num_errors
699
700
def compile_single(source, options, full_module_name = None):
    """
    compile_single(source, options, full_module_name)

    Compile exactly one Pyrex implementation file and return its
    CompilationResult.  No timestamp checking and no recursion into
    dependencies are performed.
    """
    result = run_pipeline(source, options, full_module_name)
    return result
710
711
def compile_multiple(sources, options):
    """
    compile_multiple(sources, options)

    Compiles the given sequence of Pyrex implementation files and returns
    a CompilationResultSet. Performs timestamp checking and/or recursion
    if these are specified in the options.

    In recursive mode the .pyx files of cimported modules are appended to
    the work list and compiled as well; each path is processed once.
    """
    context = options.create_context()
    sources = [os.path.abspath(source) for source in sources]
    processed = set()
    results = CompilationResultSet()
    recursive = options.recursive
    timestamps = options.timestamps
    if timestamps is None:
        timestamps = recursive
    verbose = options.verbose or ((recursive or timestamps) and not options.quiet)
    # 'sources' may grow while it is being iterated (recursive mode appends
    # cimported modules); the for-loop picks up the appended entries.
    for source in sources:
        if source in processed:
            continue
        # Compiling multiple sources in one context doesn't quite
        # work properly yet.
        if not timestamps or context.c_file_out_of_date(source):
            if verbose:
                sys.stderr.write("Compiling %s\n" % source)

            result = run_pipeline(source, options)
            results.add(source, result)
        processed.add(source)
        if recursive:
            # Context.search_include_directories() requires a
            # FileSourceDescriptor as the first element of a position;
            # a bare path string makes it raise RuntimeError.
            pos = (FileSourceDescriptor(source), 1, 0)
            for module_name in context.find_cimported_module_names(source):
                path = context.find_pyx_file(module_name, pos)
                if path:
                    # Normalise like the initial sources so 'processed'
                    # recognises the same file reached by another route.
                    sources.append(os.path.abspath(path))
                else:
                    sys.stderr.write(
                        "Cannot find .pyx file for cimported module '%s'\n" % module_name)
    return results
749
def compile(source, options = None, full_module_name = None, **kwds):
    """
    compile(source [, options], [, <option> = <value>]...)

    Compile one or more Pyrex implementation files, with optional timestamp
    checking and recursing on dependencies. The source argument may be a
    string or a sequence of strings. If it is a string and no recursion or
    timestamp checking is requested, a CompilationResult is returned;
    otherwise a CompilationResultSet is returned.

    full_module_name is only passed on in the single-file case.
    """
    options = CompilationOptions(defaults = options, **kwds)
    if isinstance(source, basestring):
        if not options.timestamps and not options.recursive:
            return compile_single(source, options, full_module_name)
        # Wrap a lone path so compile_multiple() receives a sequence of
        # paths rather than iterating over the string's characters.
        source = [source]
    return compile_multiple(source, options)
766
767 #------------------------------------------------------------------------
768 #
769 #  Main command-line entry point
770 #
771 #------------------------------------------------------------------------
def setuptools_main():
    """
    Run main() with command-line option parsing enabled.

    Presumably the hook for a setuptools-generated console script
    (inferred from the name only; verify against setup.py).
    """
    command_line_mode = 1
    return main(command_line_mode)
774
def main(command_line = 0):
    """
    Compile the files named in sys.argv[1:].

    If command_line is true, the arguments are parsed with
    CmdLine.parse_command_line. Otherwise default_options are used and every
    argument is treated as a source file. The process exits with status 1
    if any compilation error, EnvironmentError or PyrexError occurs.
    """
    argv_tail = sys.argv[1:]
    if not command_line:
        options = CompilationOptions(default_options)
        sources = argv_tail
    else:
        from CmdLine import parse_command_line
        options, sources = parse_command_line(argv_tail)

    if options.show_version:
        sys.stderr.write("Cython version %s\n" % Version.version)
    if options.working_path != "":
        os.chdir(options.working_path)

    failed = False
    try:
        outcome = compile(sources, options)
        failed = outcome.num_errors > 0
    except (EnvironmentError, PyrexError):
        # sys.exc_info() keeps this valid on Python 2.3 (no "as" syntax).
        problem = sys.exc_info()[1]
        sys.stderr.write(str(problem) + '\n')
        failed = True
    if failed:
        sys.exit(1)
798
799
800
801 #------------------------------------------------------------------------
802 #
#  Default compilation options
804 #
805 #------------------------------------------------------------------------
806
# Baseline option values; main() builds CompilationOptions from these when
# it is not parsing a command line.
default_options = {
    'show_version': 0,              # main() prints the version when set
    'use_listing_file': 0,
    'errors_to_stderr': 1,
    'cplus': 0,
    'output_file': None,
    'annotate': False,
    'generate_pxi': 0,
    'working_path': "",             # "" means main() does not chdir
    'recursive': 0,
    'timestamps': None,             # None means "follow 'recursive'" in compile_multiple()
    'verbose': 0,
    'quiet': 0,
    'compiler_directives': {},
    'evaluate_tree_assertions': False,
    'emit_linenums': False,
    'language_level': 2,
    'gdb_debug': False,
}