enable __future__ imports
authorStefan Behnel <scoder@users.berlios.de>
Thu, 15 May 2008 16:49:13 +0000 (18:49 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Thu, 15 May 2008 16:49:13 +0000 (18:49 +0200)
Cython/Compiler/Future.py [new file with mode: 0644]
Cython/Compiler/Main.py
Cython/Compiler/Parsing.py
tests/run/future_unicode_literals.pyx [new file with mode: 0644]

diff --git a/Cython/Compiler/Future.py b/Cython/Compiler/Future.py
new file mode 100644 (file)
index 0000000..0028d13
--- /dev/null
@@ -0,0 +1,5 @@
+import __future__
+
+unicode_literals = __future__.unicode_literals
+
+del __future__
index f0e1beb6c9723e685650ddec083d8f1db6648459..0782f52f43c8a2e0a73615fb6e1e8d84d3add49d 100644 (file)
@@ -7,6 +7,12 @@ if sys.version_info[:2] < (2, 2):
     sys.stderr.write("Sorry, Cython requires Python 2.2 or later\n")
     sys.exit(1)
 
+try:
+    set
+except NameError:
+    # Python 2.3
+    from sets import Set as set
+
 from time import time
 import Version
 from Scanning import PyrexScanner
@@ -30,12 +36,14 @@ class Context:
     #
     #  modules               {string : ModuleScope}
     #  include_directories   [string]
+    #  future_directives     [object]
     
     def __init__(self, include_directories):
         #self.modules = {"__builtin__" : BuiltinScope()}
         import Builtin
         self.modules = {"__builtin__" : Builtin.builtin_scope}
         self.include_directories = include_directories
+        self.future_directives = set()
         
     def find_module(self, module_name, 
             relative_to = None, pos = None, need_pxd = 1):
index 49019fcb9ada764ae97f28f16f78839418c83cc2..68179b12d3988cab6765f96b047cd6480aac5408 100644 (file)
@@ -11,6 +11,7 @@ import ExprNodes
 from ModuleNode import ModuleNode
 from Errors import error, InternalError
 from Cython import Utils
+import Future
 
 def p_ident(s, message = "Expected an identifier"):
     if s.sy == 'IDENT':
@@ -543,6 +544,12 @@ def p_string_literal(s):
     kind = s.systring[:1].lower()
     if kind not in "cru":
         kind = ''
+    if Future.unicode_literals in s.context.future_directives:
+        if kind == '':
+            kind = 'u'
+        elif kind == 'u':
+            s.error("string literal must not start with 'u' when importing __future__.unicode_literals")
+            return ('u', '')
     chars = []
     while 1:
         s.next()
@@ -921,7 +928,7 @@ def p_import_statement(s):
         stats.append(stat)
     return Nodes.StatListNode(pos, stats = stats)
 
-def p_from_import_statement(s):
+def p_from_import_statement(s, first_statement = 0):
     # s.sy == 'from'
     pos = s.position()
     s.next()
@@ -938,7 +945,19 @@ def p_from_import_statement(s):
     while s.sy == ',':
         s.next()
         imported_names.append(p_imported_name(s))
-    if kind == 'cimport':
+    if dotted_name == '__future__':
+        if not first_statement:
+            s.error("from __future__ imports must occur at the beginning of the file")
+        else:
+            for (name_pos, name, as_name) in imported_names:
+                try:
+                    directive = getattr(Future, name)
+                except AttributeError:
+                    s.error("future feature %s is not defined" % name)
+                    break
+                s.context.future_directives.add(directive)
+        return Nodes.PassStatNode(pos)
+    elif kind == 'cimport':
         for (name_pos, name, as_name) in imported_names:
             local_name = as_name or name
             s.add_type_name(local_name)
@@ -1200,7 +1219,7 @@ def p_with_statement(s):
         s.error("Only 'with gil' and 'with nogil' implemented",
                 pos = pos)
     
-def p_simple_statement(s):
+def p_simple_statement(s, first_statement = 0):
     #print "p_simple_statement:", s.sy, s.systring ###
     if s.sy == 'global':
         node = p_global_statement(s)
@@ -1219,7 +1238,7 @@ def p_simple_statement(s):
     elif s.sy in ('import', 'cimport'):
         node = p_import_statement(s)
     elif s.sy == 'from':
-        node = p_from_import_statement(s)
+        node = p_from_import_statement(s, first_statement = first_statement)
     elif s.sy == 'assert':
         node = p_assert_statement(s)
     elif s.sy == 'pass':
@@ -1228,10 +1247,10 @@ def p_simple_statement(s):
         node = p_expression_or_assignment(s)
     return node
 
-def p_simple_statement_list(s):
+def p_simple_statement_list(s, first_statement = 0):
     # Parse a series of simple statements on one line
     # separated by semicolons.
-    stat = p_simple_statement(s)
+    stat = p_simple_statement(s, first_statement = first_statement)
     if s.sy == ';':
         stats = [stat]
         while s.sy == ';':
@@ -1291,7 +1310,8 @@ def p_IF_statement(s, level, cdef_flag, visibility, api):
     s.compile_time_eval = saved_eval
     return result
 
-def p_statement(s, level, cdef_flag = 0, visibility = 'private', api = 0):
+def p_statement(s, level, cdef_flag = 0, visibility = 'private', api = 0,
+                first_statement = 0):
     if s.sy == 'ctypedef':
         if level not in ('module', 'module_pxd'):
             s.error("ctypedef statement not allowed here")
@@ -1354,16 +1374,18 @@ def p_statement(s, level, cdef_flag = 0, visibility = 'private', api = 0):
                 elif s.sy == 'with':
                     return p_with_statement(s)
                 else:
-                    return p_simple_statement_list(s)
+                    return p_simple_statement_list(s, first_statement = first_statement)
 
 def p_statement_list(s, level,
-        cdef_flag = 0, visibility = 'private', api = 0):
+        cdef_flag = 0, visibility = 'private', api = 0, first_statement = 0):
     # Parse a series of statements separated by newlines.
     pos = s.position()
     stats = []
     while s.sy not in ('DEDENT', 'EOF'):
         stats.append(p_statement(s, level,
-            cdef_flag = cdef_flag, visibility = visibility, api = api))
+            cdef_flag = cdef_flag, visibility = visibility, api = api,
+            first_statement = first_statement))
+        first_statement = 0
     if len(stats) == 1:
         return stats[0]
     else:
@@ -2118,7 +2140,7 @@ def p_module(s, pxd, full_module_name):
         level = 'module_pxd'
     else:
         level = 'module'
-    body = p_statement_list(s, level)
+    body = p_statement_list(s, level, first_statement = 1)
     if s.sy != 'EOF':
         s.error("Syntax error in statement [%s,%s]" % (
             repr(s.sy), repr(s.systring)))
diff --git a/tests/run/future_unicode_literals.pyx b/tests/run/future_unicode_literals.pyx
new file mode 100644 (file)
index 0000000..4e932cb
--- /dev/null
@@ -0,0 +1,19 @@
+from __future__ import unicode_literals
+
+import sys
+if sys.version_info[0] >= 3:
+    __doc__ = """
+    >>> u == 'test'
+    True
+    >>> isinstance(u, str)
+    True
+"""
+else:
+    __doc__ = """
+    >>> u == u'test'
+    True
+    >>> isinstance(u, unicode)
+    True
+"""
+
+u = "test"