From: Dag Sverre Seljebotn <dagss@student.matnat.uio.no>
Date: Fri, 16 May 2008 14:12:20 +0000 (+0200)
Subject: Replace filename strings with more generic source descriptors.
X-Git-Tag: 0.9.8rc1~11^2~10^2~25
X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=0221862606e0f2f1d1ef94080bfbfe1abf8b44bf;p=cython.git

Replace filename strings with more generic source descriptors.

This makes it possible to use the parser and compiler with runtime sources (such as
strings) while still providing context for error messages and C debugging comments.
---

diff --git a/Cython/Compiler/Code.py b/Cython/Compiler/Code.py
index d107248f..aead4a05 100644
--- a/Cython/Compiler/Code.py
+++ b/Cython/Compiler/Code.py
@@ -8,6 +8,7 @@ import Options
 from Cython.Utils import open_new_file, open_source_file
 from PyrexTypes import py_object_type, typecast
 from TypeSlots import method_coexist
+from Scanning import SourceDescriptor
 
 class CCodeWriter:
     # f                file            output file
@@ -89,21 +90,22 @@ class CCodeWriter:
     def get_py_version_hex(self, pyversion):
         return "0x%02X%02X%02X%02X" % (tuple(pyversion) + (0,0,0,0))[:4]
 
-    def file_contents(self, file):
+    def file_contents(self, source_desc):
         try:
-            return self.input_file_contents[file]
+            return self.input_file_contents[source_desc]
         except KeyError:
             F = [line.encode('ASCII', 'replace').replace(
                     '*/', '*[inserted by cython to avoid comment closer]/')
-                 for line in open_source_file(file)]
-            self.input_file_contents[file] = F
+                 for line in source_desc.get_lines(decode=True)]
+            self.input_file_contents[source_desc] = F
             return F
 
     def mark_pos(self, pos):
         if pos is None:
             return
-        filename, line, col = pos
-        contents = self.file_contents(filename)
+        source_desc, line, col = pos
+        assert isinstance(source_desc, SourceDescriptor)
+        contents = self.file_contents(source_desc)
 
         context = ''
         for i in range(max(0,line-3), min(line+2, len(contents))):
@@ -112,7 +114,7 @@ class CCodeWriter:
                 s = s.rstrip() + '             # <<<<<<<<<<<<<< ' + '\n'
             context += " * " + s
 
-        marker = '"%s":%d\n%s' % (filename.encode('ASCII', 'replace'), line, context)
+        marker = '"%s":%d\n%s' % (str(source_desc).encode('ASCII', 'replace'), line, context)
         if self.last_marker != marker:
             self.marker = marker
 
diff --git a/Cython/Compiler/Errors.py b/Cython/Compiler/Errors.py
index 736d0326..f01feef2 100644
--- a/Cython/Compiler/Errors.py
+++ b/Cython/Compiler/Errors.py
@@ -12,13 +12,17 @@ class PyrexError(Exception):
 class PyrexWarning(Exception):
     pass
 
+
 def context(position):
-    F = open(position[0]).readlines()
-    s = ''.join(F[position[1]-6:position[1]])
+    source = position[0]
+    assert not (isinstance(source, unicode) or isinstance(source, str)), (
+        "Please replace filename strings with Scanning.FileSourceDescriptor instances %r" % source)
+    F = list(source.get_lines())
+    s = ''.join(F[min(0, position[1]-6):position[1]])
     s += ' '*(position[2]-1) + '^'
     s = '-'*60 + '\n...\n' + s + '\n' + '-'*60 + '\n'
     return s
-
+    
 class CompileError(PyrexError):
     
     def __init__(self, position = None, message = ""):
diff --git a/Cython/Compiler/Main.py b/Cython/Compiler/Main.py
index f0e1beb6..f44c9b79 100644
--- a/Cython/Compiler/Main.py
+++ b/Cython/Compiler/Main.py
@@ -9,7 +9,7 @@ if sys.version_info[:2] < (2, 2):
 
 from time import time
 import Version
-from Scanning import PyrexScanner
+from Scanning import PyrexScanner, FileSourceDescriptor
 import Errors
 from Errors import PyrexError, CompileError, error
 import Parsing
@@ -85,7 +85,8 @@ class Context:
                 try:
                     if debug_find_module:
                         print("Context.find_module: Parsing %s" % pxd_pathname)
-                    pxd_tree = self.parse(pxd_pathname, scope.type_names, pxd = 1,
+                    source_desc = FileSourceDescriptor(pxd_pathname)
+                    pxd_tree = self.parse(source_desc, scope.type_names, pxd = 1,
                                           full_module_name = module_name)
                     pxd_tree.analyse_declarations(scope)
                 except CompileError:
@@ -116,7 +117,10 @@ class Context:
         # None if not found, but does not report an error.
         dirs = self.include_directories
         if pos:
-            here_dir = os.path.dirname(pos[0])
+            file_desc = pos[0]
+            if not isinstance(file_desc, FileSourceDescriptor):
+                raise RuntimeError("Only file sources for code supported")
+            here_dir = os.path.dirname(file_desc.filename)
             dirs = [here_dir] + dirs
         for dir in dirs:
             path = os.path.join(dir, filename)
@@ -137,19 +141,21 @@ class Context:
             self.modules[name] = scope
         return scope
 
-    def parse(self, source_filename, type_names, pxd, full_module_name):
-        name = Utils.encode_filename(source_filename)
+    def parse(self, source_desc, type_names, pxd, full_module_name):
+        if not isinstance(source_desc, FileSourceDescriptor):
+            raise RuntimeError("Only file sources for code supported")
+        source_filename = Utils.encode_filename(source_desc.filename)
         # Parse the given source file and return a parse tree.
         try:
             f = Utils.open_source_file(source_filename, "rU")
             try:
-                s = PyrexScanner(f, name, source_encoding = f.encoding,
+                s = PyrexScanner(f, source_desc, source_encoding = f.encoding,
                                  type_names = type_names, context = self)
                 tree = Parsing.p_module(s, pxd, full_module_name)
             finally:
                 f.close()
         except UnicodeDecodeError, msg:
-            error((name, 0, 0), "Decoding error, missing or incorrect coding=<encoding-name> at top of source (%s)" % msg)
+            error((source_desc, 0, 0), "Decoding error, missing or incorrect coding=<encoding-name> at top of source (%s)" % msg)
         if Errors.num_errors > 0:
             raise CompileError
         return tree
@@ -197,6 +203,7 @@ class Context:
             except EnvironmentError:
                 pass
         module_name = full_module_name # self.extract_module_name(source, options)
+        source = FileSourceDescriptor(source)
         initial_pos = (source, 1, 0)
         scope = self.find_module(module_name, pos = initial_pos, need_pxd = 0)
         errors_occurred = False
@@ -339,6 +346,8 @@ def main(command_line = 0):
     if any_failures:
         sys.exit(1)
 
+
+
 #------------------------------------------------------------------------
 #
 #  Set the default options depending on the platform
diff --git a/Cython/Compiler/ModuleNode.py b/Cython/Compiler/ModuleNode.py
index 2acb1016..15f397fa 100644
--- a/Cython/Compiler/ModuleNode.py
+++ b/Cython/Compiler/ModuleNode.py
@@ -427,8 +427,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         code.putln("")
         code.putln("static char *%s[] = {" % Naming.filenames_cname)
         if code.filename_list:
-            for filename in code.filename_list:
-                filename = os.path.basename(filename)
+            for source_desc in code.filename_list:
+                filename = os.path.basename(str(source_desc))
                 escaped_filename = filename.replace("\\", "\\\\").replace('"', r'\"')
                 code.putln('"%s",' % 
                     escaped_filename)
diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py
index b8ad781b..2a4b44b6 100644
--- a/Cython/Compiler/Parsing.py
+++ b/Cython/Compiler/Parsing.py
@@ -5,7 +5,7 @@
 import os, re
 from string import join, replace
 from types import ListType, TupleType
-from Scanning import PyrexScanner
+from Scanning import PyrexScanner, FileSourceDescriptor
 import Nodes
 import ExprNodes
 from ModuleNode import ModuleNode
@@ -1182,7 +1182,8 @@ def p_include_statement(s, level):
         include_file_path = s.context.find_include_file(include_file_name, pos)
         if include_file_path:
             f = Utils.open_source_file(include_file_path, mode="rU")
-            s2 = PyrexScanner(f, include_file_path, s, source_encoding=f.encoding)
+            source_desc = FileSourceDescriptor(include_file_path)
+            s2 = PyrexScanner(f, source_desc, s, source_encoding=f.encoding)
             try:
                 tree = p_statement_list(s2, level)
             finally:
diff --git a/Cython/Compiler/Scanning.py b/Cython/Compiler/Scanning.py
index 278b8a73..15ab7938 100644
--- a/Cython/Compiler/Scanning.py
+++ b/Cython/Compiler/Scanning.py
@@ -17,6 +17,8 @@ from Cython.Plex.Errors import UnrecognizedInput
 from Errors import CompileError, error
 from Lexicon import string_prefixes, make_lexicon
 
+from Cython import Utils
+
 plex_version = getattr(Plex, '_version', None)
 #print "Plex version:", plex_version ###
 
@@ -203,6 +205,57 @@ def initial_compile_time_env():
 
 #------------------------------------------------------------------
 
+class SourceDescriptor:
+    pass
+
+class FileSourceDescriptor(SourceDescriptor):
+    """
+    Represents a code source. A code source is a more generic abstraction
+    for a "filename" (as sometimes the code doesn't come from a file).
+    Instances of code sources are passed to Scanner.__init__ as the
+    optional name argument and will be passed back when asking for
+    the position()-tuple.
+    """
+    def __init__(self, filename):
+        self.filename = filename
+    
+    def get_lines(self, decode=False):
+        # decode is True when called from Code.py (which reserializes in a standard way to ASCII),
+        # while decode is False when called from Errors.py.
+        #
+        # Note that if changing Errors.py in this respect, raising errors over wrong encoding
+        # will no longer be able to produce the line where the encoding problem occurs ...
+        if decode:
+            return Utils.open_source_file(self.filename)
+        else:
+            return open(self.filename)
+    
+    def __str__(self):
+        return self.filename
+    
+    def __repr__(self):
+        return "<FileSourceDescriptor:%s>" % self
+
+class StringSourceDescriptor(SourceDescriptor):
+    """
+    Instances of this class can be used instead of a filename if the
+    code originates from a string object.
+    """
+    def __init__(self, name, code):
+        self.name = name
+        self.codelines = [x + "\n" for x in code.split("\n")]
+    
+    def get_lines(self, decode=False):
+        return self.codelines
+    
+    def __str__(self):
+        return self.name
+
+    def __repr__(self):
+        return "<StringSourceDescriptor:%s>" % self
+
+#------------------------------------------------------------------
+
 class PyrexScanner(Scanner):
     #  context            Context  Compilation context
     #  type_names         set      Identifiers to be treated as type names