merge
[cython.git] / Cython / Compiler / Scanning.py
index ee01e08f9867667d86c83ebbc79d3073e89dd6b2..456b3dce4ef52d415c548bd877a306ae1950a9b5 100644 (file)
@@ -11,15 +11,21 @@ import stat
 import sys
 from time import time
 
+import cython
+cython.declare(EncodedString=object, string_prefixes=object, raw_prefixes=object, IDENT=object)
+
 from Cython import Plex, Utils
-from Cython.Plex import Scanner
+from Cython.Plex.Scanners import Scanner
 from Cython.Plex.Errors import UnrecognizedInput
 from Errors import CompileError, error
-from Lexicon import string_prefixes, raw_prefixes, make_lexicon
+from Lexicon import string_prefixes, raw_prefixes, make_lexicon, IDENT
 
-from Cython import Utils
+from StringEncoding import EncodedString
 
-plex_version = getattr(Plex, '_version', None)
+try:
+    plex_version = Plex._version
+except AttributeError:
+    plex_version = None
 #print "Plex version:", plex_version ###
 
 debug_scanner = 0
@@ -38,7 +44,11 @@ def hash_source_file(path):
     # Try to calculate a hash code for the given source file.
     # Returns an empty string if the file cannot be accessed.
     #print "Hashing", path ###
-    import md5
+    try:
+        from hashlib import md5 as new_md5
+    except ImportError:
+        from md5 import new as new_md5
+    f = None
     try:
         try:
             f = open(path, "rU")
@@ -47,14 +57,15 @@ def hash_source_file(path):
             print("Unable to hash scanner source file (%s)" % e)
             return ""
     finally:
-        f.close()
+        if f:
+            f.close()
     # Normalise spaces/tabs. We don't know what sort of
     # space-tab substitution the file may have been
     # through, so we replace all spans of spaces and
     # tabs by a single space.
     import re
     text = re.sub("[ \t]+", " ", text)
-    hash = md5.new(text).hexdigest()
+    hash = new_md5(text.encode("ASCII")).hexdigest()
     return hash
 
 def open_pickled_lexicon(expected_hash):
@@ -75,7 +86,7 @@ def open_pickled_lexicon(expected_hash):
                 print("Lexicon hash mismatch:")       ###
                 print("   expected " + expected_hash) ###
                 print("   got     " + actual_hash)    ###
-        except IOError, e:
+        except (IOError, pickle.UnpicklingError), e:
             print("Warning: Unable to read pickled lexicon " + lexicon_pickle)
             print(e)
     if f:
@@ -88,12 +99,17 @@ def try_to_unpickle_lexicon():
     source_file = os.path.join(dir, "Lexicon.py")
     lexicon_hash = hash_source_file(source_file)
     lexicon_pickle = os.path.join(dir, "Lexicon.pickle")
-    f = open_pickled_lexicon(expected_hash = lexicon_hash)
+    f = open_pickled_lexicon(lexicon_hash)
     if f:
         if notify_lexicon_unpickling:
             t0 = time()
             print("Unpickling lexicon...")
-        lexicon = pickle.load(f)
+        try:
+            lexicon = pickle.load(f)
+        except Exception, e:
+            print "WARNING: Exception while loading lexicon pickle, regenerating"
+            print e
+            lexicon = None
         f.close()
         if notify_lexicon_unpickling:
             t1 = time()
@@ -141,11 +157,11 @@ reserved_words = [
     "print", "del", "pass", "break", "continue", "return",
     "raise", "import", "exec", "try", "except", "finally",
     "while", "if", "elif", "else", "for", "in", "assert",
-    "and", "or", "not", "is", "in", "lambda", "from",
-    "NULL", "cimport", "by", "with", "cpdef", "DEF", "IF", "ELIF", "ELSE"
+    "and", "or", "not", "is", "in", "lambda", "from", "yield",
+    "cimport", "by", "with", "cpdef", "DEF", "IF", "ELIF", "ELSE"
 ]
 
-class Method:
+class Method(object):
 
     def __init__(self, name):
         self.name = name
@@ -162,6 +178,9 @@ def build_resword_dict():
         d[word] = 1
     return d
 
+cython.declare(resword_dict=object)
+resword_dict = build_resword_dict()
+
 #------------------------------------------------------------------
 
 class CompileTimeScope(object):
@@ -175,6 +194,9 @@ class CompileTimeScope(object):
     
     def lookup_here(self, name):
         return self.entries[name]
+        
+    def __contains__(self, name):
+        return name in self.entries
     
     def lookup(self, name):
         try:
@@ -192,24 +214,29 @@ def initial_compile_time_env():
         'UNAME_VERSION', 'UNAME_MACHINE')
     for name, value in zip(names, platform.uname()):
         benv.declare(name, value)
-    import __builtin__
+    import __builtin__ as builtins
     names = ('False', 'True',
         'abs', 'bool', 'chr', 'cmp', 'complex', 'dict', 'divmod', 'enumerate',
         'float', 'hash', 'hex', 'int', 'len', 'list', 'long', 'map', 'max', 'min',
         'oct', 'ord', 'pow', 'range', 'reduce', 'repr', 'round', 'slice', 'str',
         'sum', 'tuple', 'xrange', 'zip')
     for name in names:
-        benv.declare(name, getattr(__builtin__, name))
+        try:
+            benv.declare(name, getattr(builtins, name))
+        except AttributeError:
+            # ignore, likely Py3
+            pass
     denv = CompileTimeScope(benv)
     return denv
 
 #------------------------------------------------------------------
 
-class SourceDescriptor:
+class SourceDescriptor(object):
     """
     A SourceDescriptor should be considered immutable.
     """
     _escaped_description = None
+    _cmp_name = ''
     def __str__(self):
         assert False # To catch all places where a descriptor is used directly as a filename
     
@@ -219,6 +246,27 @@ class SourceDescriptor:
                 self.get_description().encode('ASCII', 'replace').decode("ASCII")
         return self._escaped_description
 
+    def __gt__(self, other):
+        # this is only used to provide some sort of order
+        try:
+            return self._cmp_name > other._cmp_name
+        except AttributeError:
+            return False
+
+    def __lt__(self, other):
+        # this is only used to provide some sort of order
+        try:
+            return self._cmp_name < other._cmp_name
+        except AttributeError:
+            return False
+
+    def __le__(self, other):
+        # this is only used to provide some sort of order
+        try:
+            return self._cmp_name <= other._cmp_name
+        except AttributeError:
+            return False
+
 class FileSourceDescriptor(SourceDescriptor):
     """
     Represents a code source. A code source is a more generic abstraction
@@ -229,6 +277,7 @@ class FileSourceDescriptor(SourceDescriptor):
     """
     def __init__(self, filename):
         self.filename = filename
+        self._cmp_name = filename
     
     def get_lines(self):
         return Utils.open_source_file(self.filename)
@@ -256,6 +305,7 @@ class StringSourceDescriptor(SourceDescriptor):
     def __init__(self, name, code):
         self.name = name
         self.codelines = [x + "\n" for x in code.split("\n")]
+        self._cmp_name = name
     
     def get_lines(self):
         return self.codelines
@@ -279,30 +329,27 @@ class StringSourceDescriptor(SourceDescriptor):
 
 class PyrexScanner(Scanner):
     #  context            Context  Compilation context
-    #  type_names         set      Identifiers to be treated as type names
     #  included_files     [string] Files included with 'include' statement
     #  compile_time_env   dict     Environment for conditional compilation
     #  compile_time_eval  boolean  In a true conditional compilation context
     #  compile_time_expr  boolean  In a compile-time expression context
-    resword_dict = build_resword_dict()
 
     def __init__(self, file, filename, parent_scanner = None, 
-                 scope = None, context = None, source_encoding=None):
-        Scanner.__init__(self, get_lexicon(), file, filename)
+                 scope = None, context = None, source_encoding=None, parse_comments=True, initial_pos=None):
+        Scanner.__init__(self, get_lexicon(), file, filename, initial_pos)
         if parent_scanner:
             self.context = parent_scanner.context
-            self.type_names = parent_scanner.type_names
             self.included_files = parent_scanner.included_files
             self.compile_time_env = parent_scanner.compile_time_env
             self.compile_time_eval = parent_scanner.compile_time_eval
             self.compile_time_expr = parent_scanner.compile_time_expr
         else:
             self.context = context
-            self.type_names = scope.type_names
             self.included_files = scope.included_files
             self.compile_time_env = initial_compile_time_env()
             self.compile_time_eval = 1
             self.compile_time_expr = 0
+        self.parse_comments = parse_comments
         self.source_encoding = source_encoding
         self.trace = trace_scanner
         self.indentation_stack = [0]
@@ -311,6 +358,10 @@ class PyrexScanner(Scanner):
         self.begin('INDENT')
         self.sy = ''
         self.next()
+
+    def commentline(self, text):
+        if self.parse_comments:
+            self.produce('commentline', text)    
     
     def current_level(self):
         return self.indentation_stack[-1]
@@ -355,7 +406,7 @@ class PyrexScanner(Scanner):
         self.begin('')
         # Indentation within brackets should be ignored.
         #if self.bracket_nesting_level > 0:
-        #      return
+        #    return
         # Check that tabs and spaces are being used consistently.
         if text:
             c = text[0]
@@ -398,15 +449,14 @@ class PyrexScanner(Scanner):
             sy, systring = self.read()
         except UnrecognizedInput:
             self.error("Unrecognized character")
-        if sy == 'IDENT':
-            if systring in self.resword_dict:
+        if sy == IDENT:
+            if systring in resword_dict:
                 sy = systring
             else:
-                systring = Utils.EncodedString(systring)
-                systring.encoding = self.source_encoding
+                systring = EncodedString(systring)
         self.sy = sy
         self.systring = systring
-        if debug_scanner:
+        if False: # debug_scanner:
             _, line, col = self.position()
             if not self.systring or self.sy == self.systring:
                 t = self.sy
@@ -423,12 +473,6 @@ class PyrexScanner(Scanner):
         # This method should be added to Plex
         self.queue.insert(0, (token, value))
     
-    def add_type_name(self, name):
-        self.type_names[name] = 1
-    
-    def looking_at_type_name(self):
-        return self.sy == 'IDENT' and self.systring in self.type_names
-    
     def error(self, message, pos = None, fatal = True):
         if pos is None:
             pos = self.position()
@@ -444,12 +488,12 @@ class PyrexScanner(Scanner):
             self.expected(what, message)
     
     def expect_keyword(self, what, message = None):
-        if self.sy == 'IDENT' and self.systring == what:
+        if self.sy == IDENT and self.systring == what:
             self.next()
         else:
             self.expected(what, message)
     
-    def expected(self, what, message):
+    def expected(self, what, message = None):
         if message:
             self.error(message)
         else: