Take first Cython step into function before reading variable in test
[cython.git] / Cython / Compiler / Main.py
1 #
2 #   Cython Top Level
3 #
4
5 import os, sys, re
# Refuse to run on interpreters older than Python 2.3; the rest of the
# module (and the `sets` fallback below) assumes at least that version.
if sys.version_info[:2] < (2, 3):
    sys.stderr.write("Sorry, Cython requires Python 2.3 or later\n")
    sys.exit(1)
9
10 try:
11     set
12 except NameError:
13     # Python 2.3
14     from sets import Set as set
15
16 import itertools
17 from time import time
18
19 import Code
20 import Errors
21 # Do not import Parsing here, import it when needed, because Parsing imports
22 # Nodes, which globally needs debug command line options initialized to set a
23 # conditional metaclass. These options are processed by CmdLine called from 
24 # main() in this file.
25 # import Parsing
26 import Version
27 from Scanning import PyrexScanner, FileSourceDescriptor
28 from Errors import PyrexError, CompileError, InternalError, AbortError, error, warning
29 from Symtab import BuiltinScope, ModuleScope
30 from Cython import Utils
31 from Cython.Utils import open_new_file, replace_suffix
32 import CythonScope
33 import DebugFlags
34
# A valid module name: one identifier, optionally followed by further
# ".identifier" parts.  Used with .match(), and the trailing "$" makes the
# whole string have to match.
module_name_pattern = re.compile(r"[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*$")

# Module-level verbosity flag.  NOTE(review): not read by the code visible
# here (compile_multiple computes its own local `verbose`) -- confirm use.
verbose = 0
38
def dumptree(t):
    """Print ``t.dump()`` and hand ``t`` back unchanged.

    Intended to be spliced into a pipeline list for quick debugging.
    """
    dumped = t.dump()
    print(dumped)
    return t
43
def abort_on_errors(node):
    """Pipeline phase: stop processing once any error has been reported.

    Returns ``node`` unchanged while ``Errors.num_errors`` is zero and
    raises ``AbortError`` otherwise.
    """
    if not Errors.num_errors:
        return node
    raise AbortError("pipeline break")
49
class CompilationData(object):
    """Information handed from transform to transform.

    Context holds every pxd ever loaded, path information and so on; this
    class is meant to hold only the data belonging to a single compilation
    pass.  The class currently defines no behaviour; the intended
    attributes are:

      pyx          ModuleNode             main code tree of this compilation
      pxds         {string : ModuleNode}  trees for the pxds used in the pyx
      codewriter   CCodeWriter            where the final code is written
      options      CompilationOptions
      result       CompilationResult
    """
63
class Context(object):
    """Shared state for compiling one or more Cython implementation files.

    Holds the root of the module import namespace, the include search
    path, the active future/compiler directives and every pxd loaded so
    far, together with the pipeline factories that drive a compilation.
    """
    #  modules               {string : ModuleScope}
    #  include_directories   [string]
    #  future_directives     [object]
    #  language_level        int     currently 2 or 3 for Python 2/3

    def __init__(self, include_directories, compiler_directives, cpp=False, language_level=2):
        #self.modules = {"__builtin__" : BuiltinScope()}
        import Builtin, CythonScope
        self.modules = {"__builtin__" : Builtin.builtin_scope}
        self.modules["cython"] = CythonScope.create_cython_scope(self)
        self.future_directives = set()
        self.compiler_directives = compiler_directives
        self.cpp = cpp

        # full module name -> (pxd code nodes, pxd scope); filled by find_module()
        self.pxds = {}

        # The bundled Includes directory is always searched last.  A new
        # list is built so the caller's include_directories is not mutated.
        standard_include_path = os.path.abspath(os.path.normpath(
            os.path.join(os.path.dirname(__file__), os.path.pardir, 'Includes')))
        self.include_directories = include_directories + [standard_include_path]

        self.set_language_level(language_level)

        self.gdb_debug_outputwriter = None

    def set_language_level(self, level):
        # Language level 3 implies the print_function and unicode_literals
        # future directives.
        self.language_level = level
        if level >= 3:
            from Future import print_function, unicode_literals
            self.future_directives.add(print_function)
            self.future_directives.add(unicode_literals)

    def create_pipeline(self, pxd, py=False):
        # Build the list of tree transforms shared by the pyx and pxd
        # pipelines.  None entries are placeholders skipped by run_pipeline().
        # The order of the transforms is significant.
        from Visitor import PrintTree
        from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse
        from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
        from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
        from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
        from ParseTreeTransforms import ExpandInplaceOperators
        from TypeInference import MarkAssignments, MarkOverflowingArithmetic
        from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck
        from AnalysedTreeTransforms import AutoTestDictTransform
        from AutoDocTransforms import EmbedSignature
        from Optimize import FlattenInListTransform, SwitchTransform, IterationTransform
        from Optimize import EarlyReplaceBuiltinCalls, OptimizeBuiltinCalls
        from Optimize import ConstantFolding, FinalOptimizePhase
        from Optimize import DropRefcountingTransform
        from Buffer import IntroduceBufferAuxiliaryVars
        from ModuleNode import check_c_declarations, check_c_declarations_pxd

        if pxd:
            _check_c_declarations = check_c_declarations_pxd
            _specific_post_parse = PxdPostParse(self)
        else:
            _check_c_declarations = check_c_declarations
            _specific_post_parse = None

        if py and not pxd:
            _align_function_definitions = AlignFunctionDefinitions(self)
        else:
            _align_function_definitions = None

        return [
            NormalizeTree(self),
            PostParse(self),
            _specific_post_parse,
            InterpretCompilerDirectives(self, self.compiler_directives),
            _align_function_definitions,
            MarkClosureVisitor(self),
            ConstantFolding(),
            FlattenInListTransform(),
            WithTransform(self),
            DecoratorTransform(self),
            AnalyseDeclarationsTransform(self),
            AutoTestDictTransform(self),
            EmbedSignature(self),
            EarlyReplaceBuiltinCalls(self),  ## Necessary?
            MarkAssignments(self),
            MarkOverflowingArithmetic(self),
            TransformBuiltinMethods(self),  ## Necessary?
            IntroduceBufferAuxiliaryVars(self),
            _check_c_declarations,
            AnalyseExpressionsTransform(self),
            CreateClosureClasses(self),  ## After all lookups and type inference
            ExpandInplaceOperators(self),
            OptimizeBuiltinCalls(self),  ## Necessary?
            IterationTransform(),
            SwitchTransform(),
            DropRefcountingTransform(),
            FinalOptimizePhase(self),
            GilCheck(),
            ]

    def create_pyx_pipeline(self, options, result, py=False):
        # Full pipeline for a .pyx (or .py when py=True) main module:
        # parse, transform, optional test/debug phases, then code generation.
        def generate_pyx_code(module_node):
            module_node.process_implementation(options, result)
            result.compilation_source = module_node.compilation_source
            return result

        def inject_pxd_code(module_node):
            # Copy the code nodes of every loaded pxd into the main module.
            # (This seems strange -- the right concept is probably to split
            # ModuleNode into a ModuleNode and a CodeGenerator, and tell that
            # CodeGenerator to generate code both from the pyx and pxd
            # ModuleNodes.)
            stats = module_node.body.stats
            for statlistnode, scope in self.pxds.itervalues():
                stats.append(statlistnode)
                # Until utility code is moved to the code generation phase
                # everywhere, it has to be copied over to the main scope.
                module_node.scope.utility_code_list.extend(scope.utility_code_list)
            return module_node

        test_support = []
        if options.evaluate_tree_assertions:
            from Cython.TestUtils import TreeAssertVisitor
            test_support.append(TreeAssertVisitor())

        if options.gdb_debug:
            from Cython.Debugger import DebugWriter
            from ParseTreeTransforms import DebugTransform
            self.gdb_debug_outputwriter = DebugWriter.CythonDebugWriter(
                options.output_dir)
            debug_transform = [DebugTransform(self, options, result)]
        else:
            debug_transform = []

        return list(itertools.chain(
            [create_parse(self)],
            self.create_pipeline(pxd=False, py=py),
            test_support,
            [inject_pxd_code, abort_on_errors],
            debug_transform,
            [generate_pyx_code]))

    def create_pxd_pipeline(self, scope, module_name):
        # The pxd pipeline ends up with a CCodeWriter containing the
        # code of the pxd, as well as a pxd scope.
        def parse_pxd(source_desc):
            tree = self.parse(source_desc, scope, pxd=True,
                              full_module_name=module_name)
            tree.scope = scope
            tree.is_pxd = True
            return tree

        from CodeGeneration import ExtractPxdCode

        return [parse_pxd] + self.create_pipeline(pxd=True) + [
            ExtractPxdCode(self),
            ]

    def create_py_pipeline(self, options, result):
        # Same as the pyx pipeline, with py=True enabling
        # AlignFunctionDefinitions.
        return self.create_pyx_pipeline(options, result, py=True)

    def process_pxd(self, source_desc, scope, module_name):
        # Run the pxd pipeline; returns run_pipeline()'s (error, data) pair.
        pipeline = self.create_pxd_pipeline(scope, module_name)
        result = self.run_pipeline(pipeline, source_desc)
        return result

    def nonfatal_error(self, exc):
        return Errors.report_error(exc)

    def run_pipeline(self, pipeline, source):
        # Feed `source` through each non-None phase in turn.  Returns
        # (error, data): error is the CompileError/InternalError/AbortError
        # that stopped the pipeline (or None), data the last phase output.
        # The nested try is kept for Python 2.3/2.4 compatibility, and
        # sys.exc_info() is used instead of "except X, e" so the code also
        # parses on Python 3.
        failure = None
        data = source
        try:
            try:
                for phase in pipeline:
                    if phase is None:
                        continue
                    if DebugFlags.debug_verbose_pipeline:
                        t = time()
                        print("Entering pipeline phase %r" % phase)
                    data = phase(data)
                    if DebugFlags.debug_verbose_pipeline:
                        print("    %.3f seconds" % (time() - t))
            except CompileError:
                failure = sys.exc_info()[1]
                Errors.report_error(failure)
        except InternalError:
            # Only propagate if no earlier error could explain it.
            if Errors.num_errors == 0:
                raise
            failure = sys.exc_info()[1]
        except AbortError:
            failure = sys.exc_info()[1]
        return (failure, data)

    def find_module(self, module_name,
            relative_to = None, pos = None, need_pxd = 1):
        # Finds and returns the module scope corresponding to
        # the given relative or absolute module name. If this
        # is the first time the module has been requested, finds
        # the corresponding .pxd file and process it.
        # If relative_to is not None, it must be a module scope,
        # and the module will first be searched for relative to
        # that module, provided its name is not a dotted name.
        # Raises CompileError for an invalid module name; errors while
        # processing the pxd itself are reported and otherwise swallowed.
        debug_find_module = 0
        if debug_find_module:
            print("Context.find_module: module_name = %s, relative_to = %s, pos = %s, need_pxd = %s" % (
                    module_name, relative_to, pos, need_pxd))

        scope = None
        pxd_pathname = None
        if not module_name_pattern.match(module_name):
            if pos is None:
                pos = (module_name, 0, 0)
            raise CompileError(pos,
                "'%s' is not a valid module name" % module_name)
        if "." not in module_name and relative_to:
            if debug_find_module:
                print("...trying relative import")
            scope = relative_to.lookup_submodule(module_name)
            if not scope:
                qualified_name = relative_to.qualify_name(module_name)
                pxd_pathname = self.find_pxd_file(qualified_name, pos)
                if pxd_pathname:
                    scope = relative_to.find_submodule(module_name)
        if not scope:
            if debug_find_module:
                print("...trying absolute import")
            scope = self
            for name in module_name.split("."):
                scope = scope.find_submodule(name)
        if debug_find_module:
            print("...scope =", scope)
        if not scope.pxd_file_loaded:
            if debug_find_module:
                print("...pxd not loaded")
            scope.pxd_file_loaded = 1
            if not pxd_pathname:
                if debug_find_module:
                    print("...looking for pxd file")
                pxd_pathname = self.find_pxd_file(module_name, pos)
                if debug_find_module:
                    print("......found ", pxd_pathname)
                if not pxd_pathname and need_pxd:
                    # A package whose __init__.py exists needs no pxd.
                    package_pathname = self.search_include_directories(module_name, ".py", pos)
                    if package_pathname and package_pathname.endswith('__init__.py'):
                        pass
                    else:
                        error(pos, "'%s.pxd' not found" % module_name)
            if pxd_pathname:
                try:
                    if debug_find_module:
                        print("Context.find_module: Parsing %s" % pxd_pathname)
                    source_desc = FileSourceDescriptor(pxd_pathname)
                    err, result = self.process_pxd(source_desc, scope, module_name)
                    if err:
                        raise err
                    (pxd_codenodes, pxd_scope) = result
                    self.pxds[module_name] = (pxd_codenodes, pxd_scope)
                except CompileError:
                    # Already reported by run_pipeline(); keep the scope.
                    pass
        return scope

    def find_pxd_file(self, qualified_name, pos):
        # Search include path for the .pxd file corresponding to the
        # given fully-qualified module name.
        # Will find either a dotted filename or a file in a
        # package directory. If a source file position is given,
        # the directory containing the source file is searched first
        # for a dotted filename, and its containing package root
        # directory is searched first for a non-dotted filename.
        pxd = self.search_include_directories(qualified_name, ".pxd", pos)
        if pxd is None: # XXX Keep this until Includes/Deprecated is removed
            if (qualified_name.startswith('python') or
                qualified_name in ('stdlib', 'stdio', 'stl')):
                standard_include_path = os.path.abspath(os.path.normpath(
                        os.path.join(os.path.dirname(__file__), os.path.pardir, 'Includes')))
                deprecated_include_path = os.path.join(standard_include_path, 'Deprecated')
                # Temporarily extend the search path; always restored.
                self.include_directories.append(deprecated_include_path)
                try:
                    pxd = self.search_include_directories(qualified_name, ".pxd", pos)
                finally:
                    self.include_directories.pop()
                if pxd:
                    name = qualified_name
                    if name.startswith('python'):
                        warning(pos, "'%s' is deprecated, use 'cpython'" % name, 1)
                    elif name in ('stdlib', 'stdio'):
                        warning(pos, "'%s' is deprecated, use 'libc.%s'" % (name, name), 1)
                    elif name in ('stl',):
                        # One-element tuple: ('stl') would be a substring test.
                        warning(pos, "'%s' is deprecated, use 'libcpp.*.*'" % name, 1)
        return pxd

    def find_pyx_file(self, qualified_name, pos):
        # Search include path for the .pyx file corresponding to the
        # given fully-qualified module name, as for find_pxd_file().
        return self.search_include_directories(qualified_name, ".pyx", pos)

    def find_include_file(self, filename, pos):
        # Search list of include directories for filename.
        # Reports an error and returns None if not found.
        path = self.search_include_directories(filename, "", pos,
                                               include=True)
        if not path:
            error(pos, "'%s' not found" % filename)
        return path

    def search_include_directories(self, qualified_name, suffix, pos,
                                   include=False):
        # Search the list of include directories for the given
        # file name. If a source file position is given, first
        # searches the directory containing that file. Returns
        # None if not found, but does not report an error.
        # The 'include' option will disable package dereferencing.
        # Raises RuntimeError if pos[0] is not a FileSourceDescriptor.
        dirs = self.include_directories
        if pos:
            file_desc = pos[0]
            if not isinstance(file_desc, FileSourceDescriptor):
                raise RuntimeError("Only file sources for code supported")
            if include:
                dirs = [os.path.dirname(file_desc.filename)] + dirs
            else:
                dirs = [self.find_root_package_dir(file_desc.filename)] + dirs

        dotted_filename = qualified_name
        if suffix:
            dotted_filename += suffix
        if not include:
            names = qualified_name.split('.')
            package_names = names[:-1]
            module_name = names[-1]
            module_filename = module_name + suffix
            package_filename = "__init__" + suffix

        for dir in dirs:
            path = os.path.join(dir, dotted_filename)
            if Utils.path_exists(path):
                return path
            if not include:
                package_dir = self.check_package_dir(dir, package_names)
                if package_dir is not None:
                    path = os.path.join(package_dir, module_filename)
                    if Utils.path_exists(path):
                        return path
                    # package_dir already starts with dir; joining dir
                    # again would double a relative include directory.
                    path = os.path.join(package_dir, module_name,
                                        package_filename)
                    if Utils.path_exists(path):
                        return path
        return None

    def find_root_package_dir(self, file_path):
        # Walk up from the file's directory while it is a package directory
        # and return the first non-package ancestor (or the filesystem root).
        dir = os.path.dirname(file_path)
        while self.is_package_dir(dir):
            parent = os.path.dirname(dir)
            if parent == dir:
                break
            dir = parent
        return dir

    def check_package_dir(self, dir, package_names):
        # Return dir/<package_names...> if every level is a package
        # directory, else None.
        for dirname in package_names:
            dir = os.path.join(dir, dirname)
            if not self.is_package_dir(dir):
                return None
        return dir

    def c_file_out_of_date(self, source_path):
        # Return 1 if the .c file for source_path is missing or older than
        # the source, its .pxd, or any dependency listed in the .dep file;
        # otherwise 0.
        c_path = Utils.replace_suffix(source_path, ".c")
        if not os.path.exists(c_path):
            return 1
        c_time = Utils.modification_time(c_path)
        if Utils.file_newer_than(source_path, c_time):
            return 1
        # search_include_directories() requires pos[0] to be a
        # FileSourceDescriptor; a bare path string makes it raise.
        pos = (FileSourceDescriptor(source_path), 1, 0)
        pxd_path = Utils.replace_suffix(source_path, ".pxd")
        if os.path.exists(pxd_path) and Utils.file_newer_than(pxd_path, c_time):
            return 1
        for kind, name in self.read_dependency_file(source_path):
            if kind == "cimport":
                dep_path = self.find_pxd_file(name, pos)
            elif kind == "include":
                # Same lookup as find_include_file(), without reporting.
                dep_path = self.search_include_directories(name, "", pos,
                                                           include=True)
            else:
                continue
            if dep_path and Utils.file_newer_than(dep_path, c_time):
                return 1
        return 0

    def find_cimported_module_names(self, source_path):
        return [ name for kind, name in self.read_dependency_file(source_path)
                 if kind == "cimport" ]

    def is_package_dir(self, dir_path):
        #  Return true if the given directory is a package directory.
        for filename in ("__init__.py",
                         "__init__.pyx",
                         "__init__.pxd"):
            path = os.path.join(dir_path, filename)
            if Utils.path_exists(path):
                return 1

    def read_dependency_file(self, source_path):
        # Parse the .dep file next to source_path into [kind, name] pairs,
        # one per line that contains a space; returns () if there is none.
        dep_path = Utils.replace_suffix(source_path, ".dep")
        if os.path.exists(dep_path):
            f = open(dep_path, "rU")
            try:
                chunks = [ line.strip().split(" ", 1)
                           for line in f.readlines()
                           if " " in line.strip() ]
            finally:
                f.close()
            return chunks
        else:
            return ()

    def lookup_submodule(self, name):
        # Look up a top-level module. Returns None if not found.
        return self.modules.get(name, None)

    def find_submodule(self, name):
        # Find a top-level module, creating a new one if needed.
        scope = self.lookup_submodule(name)
        if not scope:
            scope = ModuleScope(name,
                parent_module = None, context = self)
            self.modules[name] = scope
        return scope

    def parse(self, source_desc, scope, pxd, full_module_name):
        # Parse the given source file and return a parse tree.
        # Raises CompileError if any error has been reported.
        if not isinstance(source_desc, FileSourceDescriptor):
            raise RuntimeError("Only file sources for code supported")
        source_filename = source_desc.filename
        scope.cpp = self.cpp
        try:
            f = Utils.open_source_file(source_filename, "rU")
            try:
                import Parsing
                s = PyrexScanner(f, source_desc, source_encoding = f.encoding,
                                 scope = scope, context = self)
                tree = Parsing.p_module(s, pxd, full_module_name)
            finally:
                f.close()
        except UnicodeDecodeError:
            msg = sys.exc_info()[1]
            error((source_desc, 0, 0), "Decoding error, missing or incorrect coding=<encoding-name> at top of source (%s)" % msg)
        if Errors.num_errors > 0:
            raise CompileError
        return tree

    def extract_module_name(self, path, options):
        # Find fully_qualified module name from the full pathname
        # of a source file.  A dotted file name is taken as-is.
        dir, filename = os.path.split(path)
        module_name, _ = os.path.splitext(filename)
        if "." in module_name:
            return module_name
        if module_name == "__init__":
            dir, module_name = os.path.split(dir)
        names = [module_name]
        while self.is_package_dir(dir):
            parent, package_name = os.path.split(dir)
            if parent == dir:
                break
            names.append(package_name)
            dir = parent
        names.reverse()
        return ".".join(names)

    def setup_errors(self, options, result):
        # Reset global error state and open the listing file, if requested.
        Errors.reset() # clear any remaining error state
        if options.use_listing_file:
            # The listing file sits next to the main source file.
            result.listing_file = Utils.replace_suffix(result.main_source_file, ".lis")
            path = result.listing_file
        else:
            path = None
        Errors.open_listing_file(path=path,
                                 echo_to_stderr=options.errors_to_stderr)

    def teardown_errors(self, err, options, result):
        # Close the listing file, record the error count and, on failure,
        # neutralise any generated C file via Utils.castrate_file().
        source_desc = result.compilation_source.source_desc
        if not isinstance(source_desc, FileSourceDescriptor):
            raise RuntimeError("Only file sources for code supported")
        Errors.close_listing_file()
        result.num_errors = Errors.num_errors
        if result.num_errors > 0:
            err = True
        if err and result.c_file:
            try:
                Utils.castrate_file(result.c_file, os.stat(source_desc.filename))
            except EnvironmentError:
                pass
            result.c_file = None
555                 pass
556             result.c_file = None
557
558 def create_parse(context):
559     def parse(compsrc):
560         source_desc = compsrc.source_desc
561         full_module_name = compsrc.full_module_name
562         initial_pos = (source_desc, 1, 0)
563         scope = context.find_module(full_module_name, pos = initial_pos, need_pxd = 0)
564         tree = context.parse(source_desc, scope, pxd = 0, full_module_name = full_module_name)
565         tree.compilation_source = compsrc
566         tree.scope = scope
567         tree.is_pxd = False
568         return tree
569     return parse
570
571 def create_default_resultobj(compilation_source, options):
572     result = CompilationResult()
573     result.main_source_file = compilation_source.source_desc.filename
574     result.compilation_source = compilation_source
575     source_desc = compilation_source.source_desc
576     if options.output_file:
577         result.c_file = os.path.join(compilation_source.cwd, options.output_file)
578     else:
579         if options.cplus:
580             c_suffix = ".cpp"
581         else:
582             c_suffix = ".c"
583         result.c_file = Utils.replace_suffix(source_desc.filename, c_suffix)
584     return result
585
586 def run_pipeline(source, options, full_module_name = None):
587     # Set up context
588     context = options.create_context()
589
590     # Set up source object
591     cwd = os.getcwd()
592     source_desc = FileSourceDescriptor(os.path.join(cwd, source))
593     full_module_name = full_module_name or context.extract_module_name(source, options)
594     source = CompilationSource(source_desc, full_module_name, cwd)
595
596     # Set up result object
597     result = create_default_resultobj(source, options)
598     
599     # Get pipeline
600     if source_desc.filename.endswith(".py"):
601         pipeline = context.create_py_pipeline(options, result)
602     else:
603         pipeline = context.create_pyx_pipeline(options, result)
604
605     context.setup_errors(options, result)
606     err, enddata = context.run_pipeline(pipeline, source)
607     context.teardown_errors(err, options, result)
608     return result
609     
610
611 #------------------------------------------------------------------------
612 #
613 #  Main Python entry points
614 #
615 #------------------------------------------------------------------------
616
617 class CompilationSource(object):
618     """
619     Contains the data necesarry to start up a compilation pipeline for
620     a single compilation unit.
621     """
622     def __init__(self, source_desc, full_module_name, cwd):
623         self.source_desc = source_desc
624         self.full_module_name = full_module_name
625         self.cwd = cwd
626
627 class CompilationOptions(object):
628     """
629     Options to the Cython compiler:
630     
631     show_version      boolean   Display version number
632     use_listing_file  boolean   Generate a .lis file
633     errors_to_stderr  boolean   Echo errors to stderr when using .lis
634     include_path      [string]  Directories to search for include files
635     output_file       string    Name of generated .c file
636     generate_pxi      boolean   Generate .pxi file for public declarations
637     recursive         boolean   Recursively find and compile dependencies
638     timestamps        boolean   Only compile changed source files. If None,
639                                 defaults to true when recursive is true.
640     verbose           boolean   Always print source names being compiled
641     quiet             boolean   Don't print source names in recursive mode
642     compiler_directives  dict      Overrides for pragma options (see Options.py)
643     evaluate_tree_assertions boolean  Test support: evaluate parse tree assertions
644     language_level    integer   The Python language level: 2 or 3
645     
646     cplus             boolean   Compile as c++ code
647     """
648     
649     def __init__(self, defaults = None, **kw):
650         self.include_path = []
651         if defaults:
652             if isinstance(defaults, CompilationOptions):
653                 defaults = defaults.__dict__
654         else:
655             defaults = default_options
656         self.__dict__.update(defaults)
657         self.__dict__.update(kw)
658
659     def create_context(self):
660         return Context(self.include_path, self.compiler_directives,
661                       self.cplus, self.language_level)
662
663
664 class CompilationResult(object):
665     """
666     Results from the Cython compiler:
667     
668     c_file           string or None   The generated C source file
669     h_file           string or None   The generated C header file
670     i_file           string or None   The generated .pxi file
671     api_file         string or None   The generated C API .h file
672     listing_file     string or None   File of error messages
673     object_file      string or None   Result of compiling the C file
674     extension_file   string or None   Result of linking the object file
675     num_errors       integer          Number of compilation errors
676     compilation_source CompilationSource
677     """
678     
679     def __init__(self):
680         self.c_file = None
681         self.h_file = None
682         self.i_file = None
683         self.api_file = None
684         self.listing_file = None
685         self.object_file = None
686         self.extension_file = None
687         self.main_source_file = None
688
689
690 class CompilationResultSet(dict):
691     """
692     Results from compiling multiple Pyrex source files. A mapping
693     from source file paths to CompilationResult instances. Also
694     has the following attributes:
695     
696     num_errors   integer   Total number of compilation errors
697     """
698     
699     num_errors = 0
700
701     def add(self, source, result):
702         self[source] = result
703         self.num_errors += result.num_errors
704
705
706 def compile_single(source, options, full_module_name = None):
707     """
708     compile_single(source, options, full_module_name)
709     
710     Compile the given Pyrex implementation file and return a CompilationResult.
711     Always compiles a single file; does not perform timestamp checking or
712     recursion.
713     """
714     return run_pipeline(source, options, full_module_name)
715
716
717 def compile_multiple(sources, options):
718     """
719     compile_multiple(sources, options)
720     
721     Compiles the given sequence of Pyrex implementation files and returns
722     a CompilationResultSet. Performs timestamp checking and/or recursion
723     if these are specified in the options.
724     """
725     context = options.create_context()
726     sources = [os.path.abspath(source) for source in sources]
727     processed = set()
728     results = CompilationResultSet()
729     recursive = options.recursive
730     timestamps = options.timestamps
731     if timestamps is None:
732         timestamps = recursive
733     verbose = options.verbose or ((recursive or timestamps) and not options.quiet)
734     for source in sources:
735         if source not in processed:
736             # Compiling multiple sources in one context doesn't quite
737             # work properly yet.
738             if not timestamps or context.c_file_out_of_date(source):
739                 if verbose:
740                     sys.stderr.write("Compiling %s\n" % source)
741
742                 result = run_pipeline(source, options)
743                 results.add(source, result)
744             processed.add(source)
745             if recursive:
746                 for module_name in context.find_cimported_module_names(source):
747                     path = context.find_pyx_file(module_name, [source])
748                     if path:
749                         sources.append(path)
750                     else:
751                         sys.stderr.write(
752                             "Cannot find .pyx file for cimported module '%s'\n" % module_name)
753     return results
754
def compile(source, options = None, full_module_name = None, **kwds):
    """
    compile(source [, options], [, <option> = <value>]...)

    Compile one or more Pyrex implementation files, with optional timestamp
    checking and recursion into dependencies. The source argument may be a
    string or a sequence of strings. If it is a string and no recursion or
    timestamp checking is requested, a CompilationResult is returned;
    otherwise a CompilationResultSet is returned.

    full_module_name is only used when a single CompilationResult is
    produced.
    """
    options = CompilationOptions(defaults = options, **kwds)
    if isinstance(source, basestring):
        if not options.timestamps and not options.recursive:
            return compile_single(source, options, full_module_name)
        # compile_multiple() iterates over its argument; passing the bare
        # string would make every character a separate "source file".
        source = [source]
    return compile_multiple(source, options)
771
772 #------------------------------------------------------------------------
773 #
774 #  Main command-line entry point
775 #
776 #------------------------------------------------------------------------
def setuptools_main():
    """Console-script entry point: run main() with command-line parsing."""
    parse_args = 1
    return main(command_line = parse_args)
779
780 def main(command_line = 0):
781     args = sys.argv[1:]
782     any_failures = 0
783     if command_line:
784         from CmdLine import parse_command_line
785         options, sources = parse_command_line(args)
786     else:
787         options = CompilationOptions(default_options)
788         sources = args
789
790     if options.show_version:
791         sys.stderr.write("Cython version %s\n" % Version.version)
792     if options.working_path!="":
793         os.chdir(options.working_path)
794     try:
795         result = compile(sources, options)
796         if result.num_errors > 0:
797             any_failures = 1
798     except (EnvironmentError, PyrexError), e:
799         sys.stderr.write(str(e) + '\n')
800         any_failures = 1
801     if any_failures:
802         sys.exit(1)
803
804
805
806 #------------------------------------------------------------------------
807 #
808 #  Set the default options depending on the platform
809 #
810 #------------------------------------------------------------------------
811
default_options = {
    # Read by main(): write the version banner to stderr.
    'show_version': 0,
    'use_listing_file': 0,
    'errors_to_stderr': 1,
    'cplus': 0,
    'output_file': None,
    'annotate': False,
    'generate_pxi': 0,
    # Read by main(): non-empty means os.chdir() here before compiling.
    'working_path': "",
    # Read by compile_multiple(): follow cimported modules.
    'recursive': 0,
    # None means "same as recursive" in compile_multiple().
    'timestamps': None,
    'verbose': 0,
    'quiet': 0,
    # NOTE(review): this {} is a single dict object shared by everything
    # built from default_options; confirm CompilationOptions copies it
    # before anything mutates it.
    'compiler_directives': {},
    'evaluate_tree_assertions': False,
    'emit_linenums': False,
    'language_level': 2,
    'gdb_debug': False,
}