move 'with' statement implementation back into WithTransform to fix 'with' statement...
[cython.git] / Cython / Compiler / Main.py
index 51d0a08143e47e2b6b14a540fad06586857f673d..fb8d6fcb02d5d9ffdc7fea9b4774e763164ac3e4 100644 (file)
@@ -2,7 +2,7 @@
 #   Cython Top Level
 #
 
-import os, sys, re
+import os, sys, re, codecs
 if sys.version_info[:2] < (2, 3):
     sys.stderr.write("Sorry, Cython requires Python 2.3 or later\n")
     sys.exit(1)
@@ -13,13 +13,19 @@ except NameError:
     # Python 2.3
     from sets import Set as set
 
+import itertools
 from time import time
+
 import Code
 import Errors
-import Parsing
+# Do not import Parsing here, import it when needed, because Parsing imports
+# Nodes, which globally needs debug command line options initialized to set a
+# conditional metaclass. These options are processed by CmdLine called from
+# main() in this file.
+# import Parsing
 import Version
 from Scanning import PyrexScanner, FileSourceDescriptor
-from Errors import PyrexError, CompileError, InternalError, error, warning
+from Errors import PyrexError, CompileError, InternalError, AbortError, error, warning
 from Symtab import BuiltinScope, ModuleScope
 from Cython import Utils
 from Cython.Utils import open_new_file, replace_suffix
@@ -38,7 +44,7 @@ def dumptree(t):
 def abort_on_errors(node):
     # Stop the pipeline if there are any errors.
     if Errors.num_errors != 0:
-        raise InternalError, "abort"
+        raise AbortError, "pipeline break"
     return node
 
 class CompilationData(object):
@@ -66,9 +72,8 @@ class Context(object):
     #  include_directories   [string]
     #  future_directives     [object]
     #  language_level        int     currently 2 or 3 for Python 2/3
-    
+
     def __init__(self, include_directories, compiler_directives, cpp=False, language_level=2):
-        #self.modules = {"__builtin__" : BuiltinScope()}
         import Builtin, CythonScope
         self.modules = {"__builtin__" : Builtin.builtin_scope}
         self.modules["cython"] = CythonScope.create_cython_scope(self)
@@ -85,12 +90,15 @@ class Context(object):
 
         self.set_language_level(language_level)
 
+        self.gdb_debug_outputwriter = None
+
     def set_language_level(self, level):
         self.language_level = level
         if level >= 3:
             from Future import print_function, unicode_literals
             self.future_directives.add(print_function)
             self.future_directives.add(unicode_literals)
+            self.modules['builtins'] = self.modules['__builtin__']
 
     def create_pipeline(self, pxd, py=False):
         from Visitor import PrintTree
@@ -116,19 +124,19 @@ class Context(object):
         else:
             _check_c_declarations = check_c_declarations
             _specific_post_parse = None
-            
+
         if py and not pxd:
             _align_function_definitions = AlignFunctionDefinitions(self)
         else:
             _align_function_definitions = None
+
         return [
             NormalizeTree(self),
             PostParse(self),
             _specific_post_parse,
             InterpretCompilerDirectives(self, self.compiler_directives),
-            _align_function_definitions,
             MarkClosureVisitor(self),
+            _align_function_definitions,
             ConstantFolding(),
             FlattenInListTransform(),
             WithTransform(self),
@@ -178,13 +186,22 @@ class Context(object):
             from Cython.TestUtils import TreeAssertVisitor
             test_support.append(TreeAssertVisitor())
 
-        return ([
-                create_parse(self),
-            ] + self.create_pipeline(pxd=False, py=py) + test_support + [
-                inject_pxd_code,
-                abort_on_errors,
-                generate_pyx_code,
-            ])
+        if options.gdb_debug:
+            from Cython.Debugger import DebugWriter
+            from ParseTreeTransforms import DebugTransform
+            self.gdb_debug_outputwriter = DebugWriter.CythonDebugWriter(
+                options.output_dir)
+            debug_transform = [DebugTransform(self, options, result)]
+        else:
+            debug_transform = []
+
+        return list(itertools.chain(
+            [create_parse(self)],
+            self.create_pipeline(pxd=False, py=py),
+            test_support,
+            [inject_pxd_code, abort_on_errors],
+            debug_transform,
+            [generate_pyx_code]))
 
     def create_pxd_pipeline(self, scope, module_name):
         def parse_pxd(source_desc):
@@ -201,7 +218,7 @@ class Context(object):
         return [parse_pxd] + self.create_pipeline(pxd=True) + [
             ExtractPxdCode(self),
             ]
-            
+
     def create_py_pipeline(self, options, result):
         return self.create_pyx_pipeline(options, result, py=True)
 
@@ -210,7 +227,7 @@ class Context(object):
         pipeline = self.create_pxd_pipeline(scope, module_name)
         result = self.run_pipeline(pipeline, source_desc)
         return result
-    
+
     def nonfatal_error(self, exc):
         return Errors.report_error(exc)
 
@@ -218,26 +235,29 @@ class Context(object):
         error = None
         data = source
         try:
-            for phase in pipeline:
-                if phase is not None:
-                    if DebugFlags.debug_verbose_pipeline:
-                        t = time()
-                        print "Entering pipeline phase %r" % phase
-                    data = phase(data)
-                    if DebugFlags.debug_verbose_pipeline:
-                        print "    %.3f seconds" % (time() - t)
-        except CompileError, err:
-            # err is set
-            Errors.report_error(err)
-            error = err
+            try:
+                for phase in pipeline:
+                    if phase is not None:
+                        if DebugFlags.debug_verbose_pipeline:
+                            t = time()
+                            print "Entering pipeline phase %r" % phase
+                        data = phase(data)
+                        if DebugFlags.debug_verbose_pipeline:
+                            print "    %.3f seconds" % (time() - t)
+            except CompileError, err:
+                # err is set
+                Errors.report_error(err)
+                error = err
         except InternalError, err:
             # Only raise if there was not an earlier error
             if Errors.num_errors == 0:
                 raise
             error = err
+        except AbortError, err:
+            error = err
         return (error, data)
 
-    def find_module(self, module_name, 
+    def find_module(self, module_name,
             relative_to = None, pos = None, need_pxd = 1):
         # Finds and returns the module scope corresponding to
         # the given relative or absolute module name. If this
@@ -295,7 +315,10 @@ class Context(object):
                 try:
                     if debug_find_module:
                         print("Context.find_module: Parsing %s" % pxd_pathname)
-                    source_desc = FileSourceDescriptor(pxd_pathname)
+                    rel_path = module_name.replace('.', os.sep) + os.path.splitext(pxd_pathname)[1]
+                    if not pxd_pathname.endswith(rel_path):
+                        rel_path = pxd_pathname # safety measure to prevent printing incorrect paths
+                    source_desc = FileSourceDescriptor(pxd_pathname, rel_path)
                     err, result = self.process_pxd(source_desc, scope, module_name)
                     if err:
                         raise err
@@ -304,7 +327,7 @@ class Context(object):
                 except CompileError:
                     pass
         return scope
-    
+
     def find_pxd_file(self, qualified_name, pos):
         # Search include path for the .pxd file corresponding to the
         # given fully-qualified module name.
@@ -339,7 +362,7 @@ class Context(object):
         # Search include path for the .pyx file corresponding to the
         # given fully-qualified module name, as for find_pxd_file().
         return self.search_include_directories(qualified_name, ".pyx", pos)
-    
+
     def find_include_file(self, filename, pos):
         # Search list of include directories for filename.
         # Reports an error and returns None if not found.
@@ -348,7 +371,7 @@ class Context(object):
         if not path:
             error(pos, "'%s' not found" % filename)
         return path
-    
+
     def search_include_directories(self, qualified_name, suffix, pos,
                                    include=False):
         # Search the list of include directories for the given
@@ -429,15 +452,15 @@ class Context(object):
             if dep_path and Utils.file_newer_than(dep_path, c_time):
                 return 1
         return 0
-    
+
     def find_cimported_module_names(self, source_path):
         return [ name for kind, name in self.read_dependency_file(source_path)
                  if kind == "cimport" ]
 
     def is_package_dir(self, dir_path):
         #  Return true if the given directory is a package directory.
-        for filename in ("__init__.py", 
-                         "__init__.pyx", 
+        for filename in ("__init__.py",
+                         "__init__.pyx",
                          "__init__.pxd"):
             path = os.path.join(dir_path, filename)
             if Utils.path_exists(path):
@@ -463,7 +486,7 @@ class Context(object):
         # Find a top-level module, creating a new one if needed.
         scope = self.lookup_submodule(name)
         if not scope:
-            scope = ModuleScope(name, 
+            scope = ModuleScope(name,
                 parent_module = None, context = self)
             self.modules[name] = scope
         return scope
@@ -477,6 +500,7 @@ class Context(object):
         try:
             f = Utils.open_source_file(source_filename, "rU")
             try:
+                import Parsing
                 s = PyrexScanner(f, source_desc, source_encoding = f.encoding,
                                  scope = scope, context = self)
                 tree = Parsing.p_module(s, pxd, full_module_name)
@@ -564,19 +588,35 @@ def create_default_resultobj(compilation_source, options):
 
 def run_pipeline(source, options, full_module_name = None):
     # Set up context
-    context = optons.create_context()
+    context = options.create_context()
 
     # Set up source object
     cwd = os.getcwd()
-    source_desc = FileSourceDescriptor(os.path.join(cwd, source))
+    abs_path = os.path.abspath(source)
+    source_ext = os.path.splitext(source)[1]
     full_module_name = full_module_name or context.extract_module_name(source, options)
+    if options.relative_path_in_code_position_comments:
+        rel_path = full_module_name.replace('.', os.sep) + source_ext
+        if not abs_path.endswith(rel_path):
+            rel_path = source # safety measure to prevent printing incorrect paths
+    else:
+        rel_path = abs_path
+    source_desc = FileSourceDescriptor(abs_path, rel_path)
     source = CompilationSource(source_desc, full_module_name, cwd)
 
     # Set up result object
     result = create_default_resultobj(source, options)
-    
+
+    if options.annotate is None:
+        # By default, decide based on whether an html file already exists.
+        html_filename = os.path.splitext(result.c_file)[0] + ".html"
+        if os.path.exists(html_filename):
+            line = codecs.open(html_filename, "r", encoding="UTF-8").readline()
+            if line.startswith(u'<!-- Generated by Cython'):
+                options.annotate = True
+
     # Get pipeline
-    if source_desc.filename.endswith(".py"):
+    if source_ext.lower() == '.py':
         pipeline = context.create_py_pipeline(options, result)
     else:
         pipeline = context.create_pyx_pipeline(options, result)
@@ -585,7 +625,7 @@ def run_pipeline(source, options, full_module_name = None):
     err, enddata = context.run_pipeline(pipeline, source)
     context.teardown_errors(err, options, result)
     return result
-    
+
 
 #------------------------------------------------------------------------
 #
@@ -606,7 +646,7 @@ class CompilationSource(object):
 class CompilationOptions(object):
     """
     Options to the Cython compiler:
-    
+
     show_version      boolean   Display version number
     use_listing_file  boolean   Generate a .lis file
     errors_to_stderr  boolean   Echo errors to stderr when using .lis
@@ -621,10 +661,10 @@ class CompilationOptions(object):
     compiler_directives  dict      Overrides for pragma options (see Options.py)
     evaluate_tree_assertions boolean  Test support: evaluate parse tree assertions
     language_level    integer   The Python language level: 2 or 3
-    
+
     cplus             boolean   Compile as c++ code
     """
-    
+
     def __init__(self, defaults = None, **kw):
         self.include_path = []
         if defaults:
@@ -643,7 +683,7 @@ class CompilationOptions(object):
 class CompilationResult(object):
     """
     Results from the Cython compiler:
-    
+
     c_file           string or None   The generated C source file
     h_file           string or None   The generated C header file
     i_file           string or None   The generated .pxi file
@@ -654,7 +694,7 @@ class CompilationResult(object):
     num_errors       integer          Number of compilation errors
     compilation_source CompilationSource
     """
-    
+
     def __init__(self):
         self.c_file = None
         self.h_file = None
@@ -671,10 +711,10 @@ class CompilationResultSet(dict):
     Results from compiling multiple Pyrex source files. A mapping
     from source file paths to CompilationResult instances. Also
     has the following attributes:
-    
+
     num_errors   integer   Total number of compilation errors
     """
-    
+
     num_errors = 0
 
     def add(self, source, result):
@@ -685,7 +725,7 @@ class CompilationResultSet(dict):
 def compile_single(source, options, full_module_name = None):
     """
     compile_single(source, options, full_module_name)
-    
+
     Compile the given Pyrex implementation file and return a CompilationResult.
     Always compiles a single file; does not perform timestamp checking or
     recursion.
@@ -696,11 +736,12 @@ def compile_single(source, options, full_module_name = None):
 def compile_multiple(sources, options):
     """
     compile_multiple(sources, options)
-    
+
     Compiles the given sequence of Pyrex implementation files and returns
     a CompilationResultSet. Performs timestamp checking and/or recursion
     if these are specified in the options.
     """
+    context = options.create_context()
     sources = [os.path.abspath(source) for source in sources]
     processed = set()
     results = CompilationResultSet()
@@ -733,7 +774,7 @@ def compile_multiple(sources, options):
 def compile(source, options = None, full_module_name = None, **kwds):
     """
     compile(source [, options], [, <option> = <value>]...)
-    
+
     Compile one or more Pyrex implementation files, with optional timestamp
     checking and recursing on dependecies. The source argument may be a string
     or a sequence of strings If it is a string and no recursion or timestamp
@@ -793,7 +834,7 @@ default_options = dict(
     errors_to_stderr = 1,
     cplus = 0,
     output_file = None,
-    annotate = False,
+    annotate = None,
     generate_pxi = 0,
     working_path = "",
     recursive = 0,
@@ -803,5 +844,8 @@ default_options = dict(
     compiler_directives = {},
     evaluate_tree_assertions = False,
     emit_linenums = False,
+    relative_path_in_code_position_comments = True,
+    c_line_in_traceback = True,
     language_level = 2,
+    gdb_debug = False,
 )