Another necessary "optimization."
[cython.git] / Cython / Compiler / Main.py
1 #
2 #   Cython Top Level
3 #
4
5 import os, sys, re
if sys.version_info[:2] < (2, 3):
    # Refuse to run on interpreters older than Python 2.3.
    sys.stderr.write("Sorry, Cython requires Python 2.3 or later\n")
    sys.exit(1)

try:
    # Use the built-in set type when the interpreter provides it ...
    set
except NameError:
    # ... otherwise (Python 2.3) fall back to sets.Set under the same name.
    from sets import Set as set
15
16 from time import time
17 import Code
18 import Errors
19 import Parsing
20 import Version
21 from Scanning import PyrexScanner, FileSourceDescriptor
22 from Errors import PyrexError, CompileError, InternalError, error, warning
23 from Symtab import BuiltinScope, ModuleScope
24 from Cython import Utils
25 from Cython.Utils import open_new_file, replace_suffix
26 import CythonScope
27 import DebugFlags
28
# Matches a whole module name made of ASCII identifiers, optionally dotted
# ("foo", "pkg.sub.mod"); Context.find_module rejects names that fail it.
module_name_pattern = re.compile(r"[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*$")

# Module-level verbosity flag.
# NOTE(review): not read anywhere in the visible code; compile_multiple uses
# its own local `verbose` instead.
verbose = 0
32
def dumptree(t):
    """Debugging pipeline phase: print t.dump() and pass *t* through unchanged.

    Insert it anywhere in a pipeline list to inspect the tree at that point.
    """
    dumped = t.dump()
    print(dumped)
    return t
37
def abort_on_errors(node):
    """Pipeline phase that stops the pipeline once any error has been reported.

    Returns *node* unchanged when Errors.num_errors is zero; otherwise raises
    InternalError("abort").  Context.run_pipeline swallows that InternalError
    because an error has already been reported.
    """
    if Errors.num_errors == 0:
        return node
    raise InternalError("abort")
43
class CompilationData(object):
    #  Bundles the information that is passed from transform to transform.
    #  NOTE(review): the original remark here, "(For now, this is only)",
    #  was cut off mid-sentence.  The class currently defines no members and
    #  is not referenced in the visible part of this file.

    #  While Context contains every pxd ever loaded, path information etc.,
    #  this only contains the data related to a single compilation pass
    #
    #  pyx                   ModuleNode              Main code tree of this compilation.
    #  pxds                  {string : ModuleNode}   Trees for the pxds used in the pyx.
    #  codewriter            CCodeWriter             Where to output final code.
    #  options               CompilationOptions
    #  result                CompilationResult
    pass
57
58 class Context(object):
59     #  This class encapsulates the context needed for compiling
60     #  one or more Cython implementation files along with their
61     #  associated and imported declaration files. It includes
62     #  the root of the module import namespace and the list
63     #  of directories to search for include files.
64     #
65     #  modules               {string : ModuleScope}
66     #  include_directories   [string]
67     #  future_directives     [object]
68     #  language_level        int     currently 2 or 3 for Python 2/3
69     
70     def __init__(self, include_directories, compiler_directives, cpp=False, language_level=2):
71         #self.modules = {"__builtin__" : BuiltinScope()}
72         import Builtin, CythonScope
73         self.modules = {"__builtin__" : Builtin.builtin_scope}
74         self.modules["cython"] = CythonScope.create_cython_scope(self)
75         self.include_directories = include_directories
76         self.future_directives = set()
77         self.compiler_directives = compiler_directives
78         self.cpp = cpp
79
80         self.pxds = {} # full name -> node tree
81
82         standard_include_path = os.path.abspath(os.path.normpath(
83             os.path.join(os.path.dirname(__file__), os.path.pardir, 'Includes')))
84         self.include_directories = include_directories + [standard_include_path]
85
86         self.set_language_level(language_level)
87
88     def set_language_level(self, level):
89         self.language_level = level
90         if level >= 3:
91             from Future import print_function, unicode_literals
92             self.future_directives.add(print_function)
93             self.future_directives.add(unicode_literals)
94
95     def create_pipeline(self, pxd, py=False):
96         from Visitor import PrintTree
97         from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse
98         from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
99         from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
100         from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
101         from TypeInference import MarkAssignments, MarkOverflowingArithmetic
102         from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck
103         from AnalysedTreeTransforms import AutoTestDictTransform
104         from AutoDocTransforms import EmbedSignature
105         from Optimize import FlattenInListTransform, SwitchTransform, IterationTransform
106         from Optimize import EarlyReplaceBuiltinCalls, OptimizeBuiltinCalls
107         from Optimize import ConstantFolding, FinalOptimizePhase
108         from Optimize import DropRefcountingTransform
109         from Buffer import IntroduceBufferAuxiliaryVars
110         from ModuleNode import check_c_declarations, check_c_declarations_pxd
111
112         # Temporary hack that can be used to ensure that all result_code's
113         # are generated at code generation time.
114         import Visitor
115         class ClearResultCodes(Visitor.CythonTransform):
116             def visit_ExprNode(self, node):
117                 self.visitchildren(node)
118                 node.result_code = "<cleared>"
119                 return node
120
121         if pxd:
122             _check_c_declarations = check_c_declarations_pxd
123             _specific_post_parse = PxdPostParse(self)
124         else:
125             _check_c_declarations = check_c_declarations
126             _specific_post_parse = None
127             
128         if py and not pxd:
129             _align_function_definitions = AlignFunctionDefinitions(self)
130         else:
131             _align_function_definitions = None
132  
133         return [
134             NormalizeTree(self),
135             PostParse(self),
136             _specific_post_parse,
137             InterpretCompilerDirectives(self, self.compiler_directives),
138             _align_function_definitions,
139             MarkClosureVisitor(self),
140             ConstantFolding(),
141 #            FlattenInListTransform(),
142             WithTransform(self),
143             DecoratorTransform(self),
144             AnalyseDeclarationsTransform(self),
145             CreateClosureClasses(self),
146             AutoTestDictTransform(self),
147             EmbedSignature(self),
148             EarlyReplaceBuiltinCalls(self),  ## Necessary?
149             MarkAssignments(self),
150             MarkOverflowingArithmetic(self),
151             TransformBuiltinMethods(self),
152             IntroduceBufferAuxiliaryVars(self),
153             _check_c_declarations,
154             AnalyseExpressionsTransform(self),
155 #            OptimizeBuiltinCalls(self),
156 #            IterationTransform(),
157             SwitchTransform(),
158 #            DropRefcountingTransform(),
159 #            FinalOptimizePhase(self),
160             GilCheck(),
161             #ClearResultCodes(self),
162             #SpecialFunctions(self),
163             #CreateClosureClasses(context),
164             ]
165
166     def create_pyx_pipeline(self, options, result, py=False):
167         def generate_pyx_code(module_node):
168             module_node.process_implementation(options, result)
169             result.compilation_source = module_node.compilation_source
170             return result
171
172         def inject_pxd_code(module_node):
173             from textwrap import dedent
174             stats = module_node.body.stats
175             for name, (statlistnode, scope) in self.pxds.iteritems():
176                 # Copy over function nodes to the module
177                 # (this seems strange -- I believe the right concept is to split
178                 # ModuleNode into a ModuleNode and a CodeGenerator, and tell that
179                 # CodeGenerator to generate code both from the pyx and pxd ModuleNodes.
180                  stats.append(statlistnode)
181                  # Until utility code is moved to code generation phase everywhere,
182                  # we need to copy it over to the main scope
183                  module_node.scope.utility_code_list.extend(scope.utility_code_list)
184             return module_node
185
186         test_support = []
187         if options.evaluate_tree_assertions:
188             from Cython.TestUtils import TreeAssertVisitor
189             test_support.append(TreeAssertVisitor())
190
191         return ([
192                 create_parse(self),
193             ] + self.create_pipeline(pxd=False, py=py) + test_support + [
194                 inject_pxd_code,
195                 abort_on_errors,
196                 generate_pyx_code,
197             ])
198
199     def create_pxd_pipeline(self, scope, module_name):
200         def parse_pxd(source_desc):
201             tree = self.parse(source_desc, scope, pxd=True,
202                               full_module_name=module_name)
203             tree.scope = scope
204             tree.is_pxd = True
205             return tree
206
207         from CodeGeneration import ExtractPxdCode
208
209         # The pxd pipeline ends up with a CCodeWriter containing the
210         # code of the pxd, as well as a pxd scope.
211         return [parse_pxd] + self.create_pipeline(pxd=True) + [
212             ExtractPxdCode(self),
213             ]
214             
215     def create_py_pipeline(self, options, result):
216         return self.create_pyx_pipeline(options, result, py=True)
217
218
219     def process_pxd(self, source_desc, scope, module_name):
220         pipeline = self.create_pxd_pipeline(scope, module_name)
221         result = self.run_pipeline(pipeline, source_desc)
222         return result
223     
224     def nonfatal_error(self, exc):
225         return Errors.report_error(exc)
226
227     def run_pipeline(self, pipeline, source):
228         error = None
229         data = source
230         try:
231             for phase in pipeline:
232                 if phase is not None:
233                     if DebugFlags.debug_verbose_pipeline:
234                         t = time()
235                         print "Entering pipeline phase %r" % phase
236                     data = phase(data)
237                     if DebugFlags.debug_verbose_pipeline:
238                         print "    %.3f seconds" % (time() - t)
239         except CompileError, err:
240             # err is set
241             Errors.report_error(err)
242             error = err
243         except InternalError, err:
244             # Only raise if there was not an earlier error
245             if Errors.num_errors == 0:
246                 raise
247             error = err
248         return (error, data)
249
250     def find_module(self, module_name, 
251             relative_to = None, pos = None, need_pxd = 1):
252         # Finds and returns the module scope corresponding to
253         # the given relative or absolute module name. If this
254         # is the first time the module has been requested, finds
255         # the corresponding .pxd file and process it.
256         # If relative_to is not None, it must be a module scope,
257         # and the module will first be searched for relative to
258         # that module, provided its name is not a dotted name.
259         debug_find_module = 0
260         if debug_find_module:
261             print("Context.find_module: module_name = %s, relative_to = %s, pos = %s, need_pxd = %s" % (
262                     module_name, relative_to, pos, need_pxd))
263
264         scope = None
265         pxd_pathname = None
266         if not module_name_pattern.match(module_name):
267             if pos is None:
268                 pos = (module_name, 0, 0)
269             raise CompileError(pos,
270                 "'%s' is not a valid module name" % module_name)
271         if "." not in module_name and relative_to:
272             if debug_find_module:
273                 print("...trying relative import")
274             scope = relative_to.lookup_submodule(module_name)
275             if not scope:
276                 qualified_name = relative_to.qualify_name(module_name)
277                 pxd_pathname = self.find_pxd_file(qualified_name, pos)
278                 if pxd_pathname:
279                     scope = relative_to.find_submodule(module_name)
280         if not scope:
281             if debug_find_module:
282                 print("...trying absolute import")
283             scope = self
284             for name in module_name.split("."):
285                 scope = scope.find_submodule(name)
286         if debug_find_module:
287             print("...scope =", scope)
288         if not scope.pxd_file_loaded:
289             if debug_find_module:
290                 print("...pxd not loaded")
291             scope.pxd_file_loaded = 1
292             if not pxd_pathname:
293                 if debug_find_module:
294                     print("...looking for pxd file")
295                 pxd_pathname = self.find_pxd_file(module_name, pos)
296                 if debug_find_module:
297                     print("......found ", pxd_pathname)
298                 if not pxd_pathname and need_pxd:
299                     package_pathname = self.search_include_directories(module_name, ".py", pos)
300                     if package_pathname and package_pathname.endswith('__init__.py'):
301                         pass
302                     else:
303                         error(pos, "'%s.pxd' not found" % module_name)
304             if pxd_pathname:
305                 try:
306                     if debug_find_module:
307                         print("Context.find_module: Parsing %s" % pxd_pathname)
308                     source_desc = FileSourceDescriptor(pxd_pathname)
309                     err, result = self.process_pxd(source_desc, scope, module_name)
310                     if err:
311                         raise err
312                     (pxd_codenodes, pxd_scope) = result
313                     self.pxds[module_name] = (pxd_codenodes, pxd_scope)
314                 except CompileError:
315                     pass
316         return scope
317     
318     def find_pxd_file(self, qualified_name, pos):
319         # Search include path for the .pxd file corresponding to the
320         # given fully-qualified module name.
321         # Will find either a dotted filename or a file in a
322         # package directory. If a source file position is given,
323         # the directory containing the source file is searched first
324         # for a dotted filename, and its containing package root
325         # directory is searched first for a non-dotted filename.
326         pxd = self.search_include_directories(qualified_name, ".pxd", pos)
327         if pxd is None: # XXX Keep this until Includes/Deprecated is removed
328             if (qualified_name.startswith('python') or
329                 qualified_name in ('stdlib', 'stdio', 'stl')):
330                 standard_include_path = os.path.abspath(os.path.normpath(
331                         os.path.join(os.path.dirname(__file__), os.path.pardir, 'Includes')))
332                 deprecated_include_path = os.path.join(standard_include_path, 'Deprecated')
333                 self.include_directories.append(deprecated_include_path)
334                 try:
335                     pxd = self.search_include_directories(qualified_name, ".pxd", pos)
336                 finally:
337                     self.include_directories.pop()
338                 if pxd:
339                     name = qualified_name
340                     if name.startswith('python'):
341                         warning(pos, "'%s' is deprecated, use 'cpython'" % name, 1)
342                     elif name in ('stdlib', 'stdio'):
343                         warning(pos, "'%s' is deprecated, use 'libc.%s'" % (name, name), 1)
344                     elif name in ('stl'):
345                         warning(pos, "'%s' is deprecated, use 'libcpp.*.*'" % name, 1)
346         return pxd
347
348     def find_pyx_file(self, qualified_name, pos):
349         # Search include path for the .pyx file corresponding to the
350         # given fully-qualified module name, as for find_pxd_file().
351         return self.search_include_directories(qualified_name, ".pyx", pos)
352     
353     def find_include_file(self, filename, pos):
354         # Search list of include directories for filename.
355         # Reports an error and returns None if not found.
356         path = self.search_include_directories(filename, "", pos,
357                                                include=True)
358         if not path:
359             error(pos, "'%s' not found" % filename)
360         return path
361     
362     def search_include_directories(self, qualified_name, suffix, pos,
363                                    include=False):
364         # Search the list of include directories for the given
365         # file name. If a source file position is given, first
366         # searches the directory containing that file. Returns
367         # None if not found, but does not report an error.
368         # The 'include' option will disable package dereferencing.
369         dirs = self.include_directories
370         if pos:
371             file_desc = pos[0]
372             if not isinstance(file_desc, FileSourceDescriptor):
373                 raise RuntimeError("Only file sources for code supported")
374             if include:
375                 dirs = [os.path.dirname(file_desc.filename)] + dirs
376             else:
377                 dirs = [self.find_root_package_dir(file_desc.filename)] + dirs
378
379         dotted_filename = qualified_name
380         if suffix:
381             dotted_filename += suffix
382         if not include:
383             names = qualified_name.split('.')
384             package_names = names[:-1]
385             module_name = names[-1]
386             module_filename = module_name + suffix
387             package_filename = "__init__" + suffix
388
389         for dir in dirs:
390             path = os.path.join(dir, dotted_filename)
391             if Utils.path_exists(path):
392                 return path
393             if not include:
394                 package_dir = self.check_package_dir(dir, package_names)
395                 if package_dir is not None:
396                     path = os.path.join(package_dir, module_filename)
397                     if Utils.path_exists(path):
398                         return path
399                     path = os.path.join(dir, package_dir, module_name,
400                                         package_filename)
401                     if Utils.path_exists(path):
402                         return path
403         return None
404
405     def find_root_package_dir(self, file_path):
406         dir = os.path.dirname(file_path)
407         while self.is_package_dir(dir):
408             parent = os.path.dirname(dir)
409             if parent == dir:
410                 break
411             dir = parent
412         return dir
413
414     def check_package_dir(self, dir, package_names):
415         for dirname in package_names:
416             dir = os.path.join(dir, dirname)
417             if not self.is_package_dir(dir):
418                 return None
419         return dir
420
421     def c_file_out_of_date(self, source_path):
422         c_path = Utils.replace_suffix(source_path, ".c")
423         if not os.path.exists(c_path):
424             return 1
425         c_time = Utils.modification_time(c_path)
426         if Utils.file_newer_than(source_path, c_time):
427             return 1
428         pos = [source_path]
429         pxd_path = Utils.replace_suffix(source_path, ".pxd")
430         if os.path.exists(pxd_path) and Utils.file_newer_than(pxd_path, c_time):
431             return 1
432         for kind, name in self.read_dependency_file(source_path):
433             if kind == "cimport":
434                 dep_path = self.find_pxd_file(name, pos)
435             elif kind == "include":
436                 dep_path = self.search_include_directories(name, pos)
437             else:
438                 continue
439             if dep_path and Utils.file_newer_than(dep_path, c_time):
440                 return 1
441         return 0
442     
443     def find_cimported_module_names(self, source_path):
444         return [ name for kind, name in self.read_dependency_file(source_path)
445                  if kind == "cimport" ]
446
447     def is_package_dir(self, dir_path):
448         #  Return true if the given directory is a package directory.
449         for filename in ("__init__.py", 
450                          "__init__.pyx", 
451                          "__init__.pxd"):
452             path = os.path.join(dir_path, filename)
453             if Utils.path_exists(path):
454                 return 1
455
456     def read_dependency_file(self, source_path):
457         dep_path = Utils.replace_suffix(source_path, ".dep")
458         if os.path.exists(dep_path):
459             f = open(dep_path, "rU")
460             chunks = [ line.strip().split(" ", 1)
461                        for line in f.readlines()
462                        if " " in line.strip() ]
463             f.close()
464             return chunks
465         else:
466             return ()
467
468     def lookup_submodule(self, name):
469         # Look up a top-level module. Returns None if not found.
470         return self.modules.get(name, None)
471
472     def find_submodule(self, name):
473         # Find a top-level module, creating a new one if needed.
474         scope = self.lookup_submodule(name)
475         if not scope:
476             scope = ModuleScope(name, 
477                 parent_module = None, context = self)
478             self.modules[name] = scope
479         return scope
480
481     def parse(self, source_desc, scope, pxd, full_module_name):
482         if not isinstance(source_desc, FileSourceDescriptor):
483             raise RuntimeError("Only file sources for code supported")
484         source_filename = source_desc.filename
485         scope.cpp = self.cpp
486         # Parse the given source file and return a parse tree.
487         try:
488             f = Utils.open_source_file(source_filename, "rU")
489             try:
490                 s = PyrexScanner(f, source_desc, source_encoding = f.encoding,
491                                  scope = scope, context = self)
492                 tree = Parsing.p_module(s, pxd, full_module_name)
493             finally:
494                 f.close()
495         except UnicodeDecodeError, msg:
496             #import traceback
497             #traceback.print_exc()
498             error((source_desc, 0, 0), "Decoding error, missing or incorrect coding=<encoding-name> at top of source (%s)" % msg)
499         if Errors.num_errors > 0:
500             raise CompileError
501         return tree
502
503     def extract_module_name(self, path, options):
504         # Find fully_qualified module name from the full pathname
505         # of a source file.
506         dir, filename = os.path.split(path)
507         module_name, _ = os.path.splitext(filename)
508         if "." in module_name:
509             return module_name
510         if module_name == "__init__":
511             dir, module_name = os.path.split(dir)
512         names = [module_name]
513         while self.is_package_dir(dir):
514             parent, package_name = os.path.split(dir)
515             if parent == dir:
516                 break
517             names.append(package_name)
518             dir = parent
519         names.reverse()
520         return ".".join(names)
521
522     def setup_errors(self, options, result):
523         if options.use_listing_file:
524             result.listing_file = Utils.replace_suffix(source, ".lis")
525             path = result.listing_file
526         else:
527             path = None
528         Errors.open_listing_file(path=path,
529                                  echo_to_stderr=options.errors_to_stderr)
530
531     def teardown_errors(self, err, options, result):
532         source_desc = result.compilation_source.source_desc
533         if not isinstance(source_desc, FileSourceDescriptor):
534             raise RuntimeError("Only file sources for code supported")
535         Errors.close_listing_file()
536         result.num_errors = Errors.num_errors
537         if result.num_errors > 0:
538             err = True
539         if err and result.c_file:
540             try:
541                 Utils.castrate_file(result.c_file, os.stat(source_desc.filename))
542             except EnvironmentError:
543                 pass
544             result.c_file = None
545
def create_parse(context):
    """Return the parsing phase that starts a .pyx/.py pipeline.

    The returned callable takes a CompilationSource.  It resolves (or
    creates) the module scope for the source's full module name without
    requiring a .pxd, parses the source into a tree, and tags the tree with
    its compilation source, scope and is_pxd=False.
    """
    def parse(compsrc):
        desc = compsrc.source_desc
        module_name = compsrc.full_module_name
        module_scope = context.find_module(module_name, pos=(desc, 1, 0),
                                           need_pxd=0)
        root = context.parse(desc, module_scope, pxd=0,
                             full_module_name=module_name)
        root.compilation_source = compsrc
        root.scope = module_scope
        root.is_pxd = False
        return root
    return parse
558
def create_default_resultobj(compilation_source, options):
    """Create the CompilationResult for *compilation_source*.

    The C output path is options.output_file resolved against the
    compilation's working directory when given.  Otherwise it is the source
    filename with its suffix replaced by ".cpp" (options.cplus) or ".c".
    """
    source_desc = compilation_source.source_desc
    result = CompilationResult()
    result.main_source_file = source_desc.filename
    result.compilation_source = compilation_source
    if not options.output_file:
        suffix = ".c"
        if options.cplus:
            suffix = ".cpp"
        result.c_file = Utils.replace_suffix(source_desc.filename, suffix)
    else:
        result.c_file = os.path.join(compilation_source.cwd, options.output_file)
    return result
573
def run_pipeline(source, options, full_module_name = None):
    """Compile the single file *source* in a fresh Context.

    *source* is resolved against the current working directory.  When
    *full_module_name* is not given, it is derived from the path by
    Context.extract_module_name.  Files ending in ".py" use the .py
    pipeline; everything else uses the .pyx pipeline.  Returns the
    CompilationResult.
    """
    context = Context(options.include_path, options.compiler_directives,
                      options.cplus, options.language_level)

    working_dir = os.getcwd()
    source_desc = FileSourceDescriptor(os.path.join(working_dir, source))
    if not full_module_name:
        full_module_name = context.extract_module_name(source, options)
    compilation_source = CompilationSource(source_desc, full_module_name,
                                           working_dir)

    result = create_default_resultobj(compilation_source, options)

    if source_desc.filename.endswith(".py"):
        make_pipeline = context.create_py_pipeline
    else:
        make_pipeline = context.create_pyx_pipeline
    pipeline = make_pipeline(options, result)

    context.setup_errors(options, result)
    err, enddata = context.run_pipeline(pipeline, compilation_source)
    context.teardown_errors(err, options, result)
    return result
598
599 #------------------------------------------------------------------------
600 #
601 #  Main Python entry points
602 #
603 #------------------------------------------------------------------------
604
class CompilationSource(object):
    """
    Contains the data necessary to start up a compilation pipeline for
    a single compilation unit.

    source_desc        Descriptor of the main source file
    full_module_name   Fully qualified name of the module being compiled
    cwd                Working directory; options.output_file is resolved
                       relative to it
    """
    def __init__(self, source_desc, full_module_name, cwd):
        self.source_desc, self.full_module_name, self.cwd = \
            source_desc, full_module_name, cwd
614
class CompilationOptions(object):
    """
    Options to the Cython compiler:

    show_version      boolean   Display version number
    use_listing_file  boolean   Generate a .lis file
    errors_to_stderr  boolean   Echo errors to stderr when using .lis
    include_path      [string]  Directories to search for include files
    output_file       string    Name of generated .c file
    generate_pxi      boolean   Generate .pxi file for public declarations
    working_path      string    Directory to change into before compiling (main)
    recursive         boolean   Recursively find and compile dependencies
    timestamps        boolean   Only compile changed source files. If None,
                                defaults to true when recursive is true.
    verbose           boolean   Always print source names being compiled
    quiet             boolean   Don't print source names in recursive mode
    compiler_directives  dict      Overrides for pragma options (see Options.py)
    evaluate_tree_assertions boolean  Test support: evaluate parse tree assertions
    language_level    integer   The Python language level: 2 or 3

    cplus             boolean   Compile as c++ code

    Values are taken from *defaults* (another CompilationOptions, a dict, or
    the module's default_options when omitted) and overridden by keyword
    arguments.  List and dict values inherited from *defaults* are copied,
    so mutating one instance's options never changes the defaults.
    """

    def __init__(self, defaults = None, **kw):
        self.include_path = []
        if defaults:
            if isinstance(defaults, CompilationOptions):
                defaults = defaults.__dict__
        else:
            defaults = default_options
        # Copy mutable containers: previously every instance shared e.g.
        # default_options['compiler_directives'] and the source instance's
        # include_path list.
        inherited = {}
        for name, value in defaults.items():
            if isinstance(value, list):
                value = list(value)
            elif isinstance(value, dict):
                value = dict(value)
            inherited[name] = value
        self.__dict__.update(inherited)
        self.__dict__.update(kw)
646
647
class CompilationResult(object):
    """
    Results from the Cython compiler:

    c_file           string or None   The generated C source file
    h_file           string or None   The generated C header file
    i_file           string or None   The generated .pxi file
    api_file         string or None   The generated C API .h file
    listing_file     string or None   File of error messages
    object_file      string or None   Result of compiling the C file
    extension_file   string or None   Result of linking the object file
    main_source_file string or None   Path of the main source file
    num_errors       integer          Number of compilation errors
    compilation_source CompilationSource or None
    """

    def __init__(self):
        self.c_file = None
        self.h_file = None
        self.i_file = None
        self.api_file = None
        self.listing_file = None
        self.object_file = None
        self.extension_file = None
        self.main_source_file = None
        # Documented attributes that were previously left undefined until
        # the pipeline ran; callers such as CompilationResultSet.add read
        # num_errors.
        self.num_errors = 0
        self.compilation_source = None
672
673
class CompilationResultSet(dict):
    """
    Results from compiling multiple Pyrex source files: a dict mapping
    source file paths to CompilationResult instances.  Extra attribute:

    num_errors   integer   Total number of compilation errors
    """

    num_errors = 0

    def add(self, source, result):
        """Store *result* under *source* and add its errors to the total."""
        self[source] = result
        self.num_errors = self.num_errors + result.num_errors
688
689
def compile_single(source, options, full_module_name = None):
    """
    compile_single(source, options, full_module_name)

    Compile exactly one Pyrex implementation file and return its
    CompilationResult.  No timestamp checking or dependency recursion is
    performed; this just delegates to run_pipeline().
    """
    result = run_pipeline(source, options, full_module_name)
    return result
699
700
def compile_multiple(sources, options):
    """
    compile_multiple(sources, options)

    Compiles the given sequence of Pyrex implementation files and returns
    a CompilationResultSet. Performs timestamp checking and/or recursion
    if these are specified in the options.  With recursion, .pyx files of
    cimported modules (listed in each source's .dep file) are appended to
    the work list.
    """
    sources = [os.path.abspath(source) for source in sources]
    processed = set()
    results = CompilationResultSet()
    recursive = options.recursive
    timestamps = options.timestamps
    if timestamps is None:
        timestamps = recursive
    verbose = options.verbose or ((recursive or timestamps) and not options.quiet)
    # Dependency lookups need a Context; previously `context` was an
    # undefined name here, so timestamps/recursion raised NameError.
    # Each compilation below still runs in its own fresh Context.
    context = None
    if timestamps or recursive:
        context = Context(options.include_path, options.compiler_directives,
                          options.cplus, options.language_level)
    # `sources` may grow while iterating (recursive mode); the list
    # iterator picks up appended entries.
    for source in sources:
        if source in processed:
            continue
        # Compiling multiple sources in one context doesn't quite
        # work properly yet.
        if not timestamps or context.c_file_out_of_date(source):
            if verbose:
                sys.stderr.write("Compiling %s\n" % source)

            result = run_pipeline(source, options)
            results.add(source, result)
        processed.add(source)
        if recursive:
            # search_include_directories() needs a FileSourceDescriptor in
            # pos[0]; a bare [source] list made it raise RuntimeError.
            pos = (FileSourceDescriptor(source), 1, 0)
            for module_name in context.find_cimported_module_names(source):
                path = context.find_pyx_file(module_name, pos)
                if path:
                    # Normalise so the `processed` check matches.
                    sources.append(os.path.abspath(path))
                else:
                    sys.stderr.write(
                        "Cannot find .pyx file for cimported module '%s'\n" % module_name)
    return results
737
def compile(source, options = None, full_module_name = None, **kwds):
    """
    compile(source [, options], [, <option> = <value>]...)

    Compile one or more Pyrex implementation files, with optional timestamp
    checking and recursing on dependencies. The source argument may be a
    string or a sequence of strings. If it is a string and no recursion or
    timestamp checking is requested, a CompilationResult is returned,
    otherwise a CompilationResultSet is returned.  full_module_name is only
    used for the single-file case.
    """
    options = CompilationOptions(defaults = options, **kwds)
    if isinstance(source, basestring):
        if not options.timestamps and not options.recursive:
            return compile_single(source, options, full_module_name)
        # compile_multiple iterates over its argument; a bare string would
        # be treated as a sequence of one-character paths.
        source = [source]
    return compile_multiple(source, options)
754
755 #------------------------------------------------------------------------
756 #
757 #  Main command-line entry point
758 #
759 #------------------------------------------------------------------------
def setuptools_main():
    """Console-script entry point: main() with command-line option parsing."""
    return main(command_line=1)
762
763 def main(command_line = 0):
764     args = sys.argv[1:]
765     any_failures = 0
766     if command_line:
767         from CmdLine import parse_command_line
768         options, sources = parse_command_line(args)
769     else:
770         options = CompilationOptions(default_options)
771         sources = args
772
773     if options.show_version:
774         sys.stderr.write("Cython version %s\n" % Version.version)
775     if options.working_path!="":
776         os.chdir(options.working_path)
777     try:
778         result = compile(sources, options)
779         if result.num_errors > 0:
780             any_failures = 1
781     except (EnvironmentError, PyrexError), e:
782         sys.stderr.write(str(e) + '\n')
783         any_failures = 1
784     if any_failures:
785         sys.exit(1)
786
787
788
789 #------------------------------------------------------------------------
790 #
791 #  Set the default options depending on the platform
792 #
793 #------------------------------------------------------------------------
794
# Baseline values used by CompilationOptions when no defaults are supplied.
default_options = {
    'show_version': 0,
    'use_listing_file': 0,
    'errors_to_stderr': 1,
    'cplus': 0,
    'output_file': None,
    'annotate': False,
    'generate_pxi': 0,
    'working_path': "",
    'recursive': 0,
    'timestamps': None,
    'verbose': 0,
    'quiet': 0,
    'compiler_directives': {},
    'evaluate_tree_assertions': False,
    'emit_linenums': False,
    'language_level': 2,
}