source code encoding support (PEP 263) and UTF-8 default source encoding (PEP 3120)
authorStefan Behnel <scoder@users.berlios.de>
Tue, 22 Apr 2008 14:37:33 +0000 (16:37 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Tue, 22 Apr 2008 14:37:33 +0000 (16:37 +0200)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Main.py
Cython/Compiler/ModuleNode.py
Cython/Compiler/Nodes.py
Cython/Compiler/Parsing.py
Cython/Compiler/PyrexTypes.py
Cython/Compiler/Scanning.py
Cython/Compiler/Symtab.py

index e79ae49ca6c6519696525782061dc634e9ea7948..c2f5a475ef905b67718b5eec63accc56377c746b 100644 (file)
@@ -18,6 +18,29 @@ from Cython.Debugging import print_call_chain
 from DebugFlags import debug_disposal_code, debug_temp_alloc, \
     debug_coercion
 
+class EncodedString(unicode):
+    # unicode string subclass to keep track of the original encoding.
+    # 'encoding' is None for unicode strings and the source encoding
+    # otherwise
+    encoding = None
+
+    def byteencode(self):
+        assert self.encoding is not None
+        return self.encode(self.encoding)
+
+    def utf8encode(self):
+        assert self.encoding is None
+        return self.encode("UTF-8")
+
+    def is_unicode(self):
+        return self.encoding is None
+    is_unicode = property(is_unicode)
+
+#    def __eq__(self, other):
+#        return unicode.__eq__(self, other) and \
+#            getattr(other, 'encoding', '') == self.encoding
+
+
 class ExprNode(Node):
     #  subexprs     [string]     Class var holding names of subexpr node attrs
     #  type         PyrexType    Type of the result
@@ -696,15 +719,16 @@ class StringNode(ConstNode):
     type = PyrexTypes.c_char_ptr_type
 
     def compile_time_value(self, denv):
-        return eval('"%s"' % self.value)
+        return self.value
     
     def analyse_types(self, env):
         self.entry = env.add_string_const(self.value)
     
     def coerce_to(self, dst_type, env):
         if dst_type.is_int:
-            if not self.type.is_pyobject and len(self.value) == 1:
-                return CharNode(self.pos, value=self.value)
+            if not self.type.is_pyobject and len(self.entry.init) == 1:
+                # we use the *encoded* value here
+                return CharNode(self.pos, value=self.entry.init)
             else:
                 error(self.pos, "Only coerce single-character ascii strings can be used as ints.")
                 return self
index ca8dc88d7387a6549a44a6cf803c37da111b0215..c49d170e4cdccaab0a1ba716a372e38c33773f88 100644 (file)
@@ -2,12 +2,11 @@
 #   Cython Top Level
 #
 
-import os, sys, re
+import os, sys, re, codecs
 if sys.version_info[:2] < (2, 2):
     print >>sys.stderr, "Sorry, Cython requires Python 2.2 or later"
     sys.exit(1)
 
-import os
 from time import time
 import Version
 from Scanning import PyrexScanner
@@ -138,10 +137,27 @@ class Context:
             self.modules[name] = scope
         return scope
 
+    match_file_encoding = re.compile("coding[:=]\s*([-\w.]+)").search
+
+    def detect_file_encoding(self, source_filename):
+        # PEPs 263 and 3120
+        f = codecs.open(source_filename, "rU", encoding="UTF-8")
+        try:
+            for line_no, line in enumerate(f):
+                encoding = self.match_file_encoding(line)
+                if encoding:
+                    return encoding.group(1)
+                if line_no == 1:
+                    break
+        finally:
+            f.close()
+        return "UTF-8"
+
     def parse(self, source_filename, type_names, pxd, full_module_name):
         # Parse the given source file and return a parse tree.
-        f = open(source_filename, "rU")
-        s = PyrexScanner(f, source_filename, 
+        encoding = self.detect_file_encoding(source_filename)
+        f = codecs.open(source_filename, "rU", encoding=encoding)
+        s = PyrexScanner(f, source_filename, source_encoding = encoding,
             type_names = type_names, context = self)
         try:
             tree = Parsing.p_module(s, pxd, full_module_name)
index 73ec489a58805e503478ef1764617d3d0c91fc02..4a58885a527fba9e99e28ce4703b727bdd55c6b6 100644 (file)
@@ -1270,7 +1270,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
                         entry.pystring_cname,
                         entry.cname,
                         entry.cname,
-                        isinstance(entry.init, unicode)
+                        entry.type.is_unicode
                         ))
             code.putln(
                 "{0, 0, 0, 0}")
index eb593d2e0ab525c7155aeb5d96c2f54ef40ab7e7..c17f332a32bbd092326e9e4ff464ac990d1b1b0e 100644 (file)
@@ -1199,7 +1199,7 @@ class DefNode(FuncDefNode):
     # args          [CArgDeclNode]         formal arguments
     # star_arg      PyArgDeclNode or None  * argument
     # starstar_arg  PyArgDeclNode or None  ** argument
-    # doc           string or None
+    # doc           EncodedString or None
     # body          StatListNode
     #
     #  The following subnode is constructed internally
@@ -1358,12 +1358,15 @@ class DefNode(FuncDefNode):
         entry.pymethdef_cname = \
             Naming.pymethdef_prefix + prefix + name
         if not Options.docstrings:
-            self.entry.doc = None
+            entry.doc = None
         else:
             if Options.embed_pos_in_docstring:
-                entry.doc = 'File: %s (starting at line %s)'%relative_position(self.pos)
+                doc = u'File: %s (starting at line %s)'%relative_position(self.pos)
                 if not self.doc is None:
-                    entry.doc = entry.doc + '\\n' + self.doc
+                    doc = doc + u'\\n' + self.doc
+                doc = ExprNodes.EncodedString(doc)
+                doc.encoding = self.doc.encoding
+                entry.doc = doc
             else:
                 entry.doc = self.doc
             entry.doc_cname = \
@@ -1920,8 +1923,9 @@ class PyClassDefNode(StatNode, BlockNode):
         self.dict = ExprNodes.DictNode(pos, key_value_pairs = [])
         if self.doc and Options.docstrings:
             if Options.embed_pos_in_docstring:
-                doc = 'File: %s (starting at line %s)'%relative_position(self.pos)
-                doc = doc + '\\n' + self.doc
+                doc = u'File: %s (starting at line %s)'%relative_position(self.pos)
+                doc = ExprNodes.EncodedString(doc + 'u\\n' + self.doc)
+                doc.encoding = self.doc.encoding
             doc_node = ExprNodes.StringNode(pos, value = doc)
         else:
             doc_node = None
@@ -2073,7 +2077,7 @@ class PropertyNode(StatNode):
     #  Definition of a property in an extension type.
     #
     #  name   string
-    #  doc    string or None    Doc string
+    #  doc    EncodedString or None    Doc string
     #  body   StatListNode
     
     child_attrs = ["body"]
index 93491c9e0a382e00b02c2c22a17e36f85de87de6..c817ec3709ab2caac4c7805964f8d0d4f8cfba7d 100644 (file)
@@ -281,8 +281,10 @@ def p_call(s, function):
             if not arg.is_name:
                 s.error("Expected an identifier before '='",
                     pos = arg.pos)
+            encoded_name = ExprNodes.EncodedString(arg.name)
+            encoded_name.encoding = s.source_encoding
             keyword = ExprNodes.StringNode(arg.pos, 
-                value = arg.name)
+                value = encoded_name)
             arg = p_simple_expr(s)
             keyword_args.append((keyword, arg))
         else:
@@ -459,7 +461,7 @@ def p_atom(s):
         value = s.systring[:-1]
         s.next()
         return ExprNodes.ImagNode(pos, value = value)
-    elif sy == 'STRING' or sy == 'BEGIN_STRING':
+    elif sy == 'BEGIN_STRING':
         kind, value = p_cat_string_literal(s)
         if kind == 'c':
             return ExprNodes.CharNode(pos, value = value)
@@ -500,7 +502,12 @@ def p_name(s, name):
             elif isinstance(value, float):
                 return ExprNodes.FloatNode(pos, value = rep)
             elif isinstance(value, str):
-                return ExprNodes.StringNode(pos, value = rep[1:-1])
+                sval = ExprNodes.EncodedString(rep[1:-1])
+                sval.encoding = value.encoding
+                return ExprNodes.StringNode(pos, value = sval)
+            elif isinstance(value, unicode):
+                sval = ExprNodes.EncodedString(rep[2:-1])
+                return ExprNodes.StringNode(pos, value = sval)
             else:
                 error(pos, "Invalid type for compile-time constant: %s"
                     % value.__class__.__name__)
@@ -508,21 +515,25 @@ def p_name(s, name):
 
 def p_cat_string_literal(s):
     # A sequence of one or more adjacent string literals.
-    # Returns (kind, value) where kind in ('', 'c', 'r')
+    # Returns (kind, value) where kind in ('', 'c', 'r', 'u')
     kind, value = p_string_literal(s)
     if kind != 'c':
         strings = [value]
-        while s.sy == 'STRING' or s.sy == 'BEGIN_STRING':
+        while s.sy == 'BEGIN_STRING':
             next_kind, next_value = p_string_literal(s)
             if next_kind == 'c':
                 self.error(
                     "Cannot concatenate char literal with another string or char literal")
+            elif next_kind == 'u':
+                kind = 'u'
             strings.append(next_value)
-        value = ''.join(strings)
+        value = ExprNodes.EncodedString( u''.join(strings) )
+        if kind != 'u':
+            value.encoding = s.source_encoding
     return kind, value
 
 def p_opt_string_literal(s):
-    if s.sy == 'STRING' or s.sy == 'BEGIN_STRING':
+    if s.sy == 'BEGIN_STRING':
         return p_string_literal(s)
     else:
         return None
@@ -530,10 +541,6 @@ def p_opt_string_literal(s):
 def p_string_literal(s):
     # A single string or char literal.
     # Returns (kind, value) where kind in ('', 'c', 'r', 'u')
-    if s.sy == 'STRING':
-        value = unquote(s.systring)
-        s.next()
-        return value
     # s.sy == 'BEGIN_STRING'
     pos = s.position()
     #is_raw = s.systring[:1].lower() == "r"
@@ -549,8 +556,6 @@ def p_string_literal(s):
             systr = s.systring
             if len(systr) == 1 and systr in "'\"\n":
                 chars.append('\\')
-            if kind == 'u' and not isinstance(systr, unicode):
-                systr = systr.decode("UTF-8")
             chars.append(systr)
         elif sy == 'ESCAPE':
             systr = s.systring
@@ -572,7 +577,8 @@ def p_string_literal(s):
                 elif c in 'ux':
                     if kind == 'u':
                         try:
-                            chars.append(systr.decode('unicode_escape'))
+                            chars.append(
+                                systr.encode("ASCII").decode('unicode_escape'))
                         except UnicodeDecodeError:
                             s.error("Invalid unicode escape '%s'" % systr,
                                     pos = pos)
@@ -593,50 +599,12 @@ def p_string_literal(s):
                 "Unexpected token %r:%r in string literal" %
                     (sy, s.systring))
     s.next()
-    value = ''.join(chars)
+    value = ExprNodes.EncodedString( u''.join(chars) )
+    if kind != 'u':
+        value.encoding = s.source_encoding
     #print "p_string_literal: value =", repr(value) ###
     return kind, value
 
-def unquote(s):
-    is_raw = 0
-    if s[:1].lower() == "r":
-        is_raw = 1
-        s = s[1:]
-    q = s[:3]
-    if q == '"""' or q == "'''":
-        s = s[3:-3]
-    else:
-        s = s[1:-1]
-    if is_raw:
-        s = s.replace('\\', '\\\\')
-        s = s.replace('\n', '\\\n')
-    else:
-        # Split into double quotes, newlines, escape sequences 
-        # and spans of regular chars
-        l1 = re.split(r'((?:\\[0-7]{1,3})|(?:\\x[0-9A-Fa-f]{2})|(?:\\.)|(?:\\\n)|(?:\n)|")', s)
-        #print "unquote: l1 =", l1 ###
-        l2 = []
-        for item in l1:
-            if item == '"' or item == '\n':
-                l2.append('\\' + item)
-            elif item == '\\\n':
-                pass
-            elif item[:1] == '\\':
-                if len(item) == 2:
-                    if item[1] in '"\\abfnrtv':
-                        l2.append(item)
-                    else:
-                        l2.append(item[1])
-                elif item[1:2] == 'x':
-                    l2.append('\\x0' + item[2:])
-                else:
-                    # octal escape
-                    l2.append(item)
-            else:
-                l2.append(item)
-        s = "".join(l2)
-    return s
-        
 # list_display         ::=     "[" [listmaker] "]"
 # listmaker    ::=     expression ( list_for | ( "," expression )* [","] )
 # list_iter    ::=     list_for | list_if
@@ -946,6 +914,8 @@ def p_import_statement(s):
                     ExprNodes.StringNode(pos, value = "*")])
             else:
                 name_list = None
+            dotted_name = ExprNodes.EncodedString(dotted_name)
+            dotted_name.encoding = s.source_encoding
             stat = Nodes.SingleAssignmentNode(pos,
                 lhs = ExprNodes.NameNode(pos, 
                     name = as_name or target_name),
@@ -984,14 +954,18 @@ def p_from_import_statement(s):
         imported_name_strings = []
         items = []
         for (name_pos, name, as_name) in imported_names:
+            encoded_name = ExprNodes.EncodedString(name)
+            encoded_name.encoding = s.source_encoding
             imported_name_strings.append(
-                ExprNodes.StringNode(name_pos, value = name))
+                ExprNodes.StringNode(name_pos, value = encoded_name))
             items.append(
                 (name,
                  ExprNodes.NameNode(name_pos, 
                        name = as_name or name)))
         import_list = ExprNodes.ListNode(
             imported_names[0][0], args = imported_name_strings)
+        dotted_name = ExprNodes.EncodedString(dotted_name)
+        dotted_name.encoding = s.source_encoding
         return Nodes.FromImportStatNode(pos,
             module = ExprNodes.ImportNode(dotted_name_pos,
                 module_name = ExprNodes.StringNode(dotted_name_pos,
@@ -1996,7 +1970,8 @@ def p_class_statement(s):
     # s.sy == 'class'
     pos = s.position()
     s.next()
-    class_name = p_ident(s)
+    class_name = ExprNodes.EncodedString( p_ident(s) )
+    class_name.encoding = s.source_encoding
     if s.sy == '(':
         s.next()
         base_list = p_simple_expr_list(s)
@@ -2113,7 +2088,7 @@ def p_property_decl(s):
     return Nodes.PropertyNode(pos, name = name, doc = doc, body = body)
 
 def p_doc_string(s):
-    if s.sy == 'STRING' or s.sy == 'BEGIN_STRING':
+    if s.sy == 'BEGIN_STRING':
         _, result = p_cat_string_literal(s)
         if s.sy != 'EOF':
             s.expect_newline("Syntax error in doc string")
index bf3e6f9f8a3a9d4d26faf65a90687f980112a95f..d7427d4cbfc603b2be780f809ee46993cef2c5c8 100644 (file)
@@ -37,6 +37,7 @@ class PyrexType(BaseType):
     #  is_enum               boolean     Is a C enum type
     #  is_typedef            boolean     Is a typedef type
     #  is_string             boolean     Is a C char * type
+    #  is_unicode            boolean     Is a UTF-8 encoded C char * type
     #  is_returncode         boolean     Is used only to signal exceptions
     #  is_error              boolean     Is the dummy error type
     #  has_attributes        boolean     Has C dot-selectable attributes
@@ -83,6 +84,7 @@ class PyrexType(BaseType):
     is_enum = 0
     is_typedef = 0
     is_string = 0
+    is_unicode = 0
     is_returncode = 0
     is_error = 0
     has_attributes = 0
@@ -875,19 +877,49 @@ class CEnumType(CType):
             return self.base_declaration_code(public_decl(base, dll_linkage), entity_code)
 
 
+def _escape_byte_string(s):
+    try:
+        s.decode("ASCII")
+        return s
+    except UnicodeDecodeError:
+        pass
+    l = []
+    append = l.append
+    for c in s:
+        o = ord(c)
+        if o >= 128:
+            append('\\x%X' % o)
+        else:
+            append(c)
+    return ''.join(l)
+
 class CStringType:
     #  Mixin class for C string types.
 
     is_string = 1
+    is_unicode = 0
     
     to_py_function = "PyString_FromString"
     from_py_function = "PyString_AsString"
     exception_value = "NULL"
 
     def literal_code(self, value):
-        if isinstance(value, unicode):
-            value = value.encode("UTF-8")
-        return '"%s"' % value
+        assert isinstance(value, str)
+        return '"%s"' % _escape_byte_string(value)
+
+
+class CUTF8StringType:
+    #  Mixin class for C unicode types.
+
+    is_string = 1
+    is_unicode = 1
+    
+    to_py_function = "PyUnicode_DecodeUTF8"
+    exception_value = "NULL"
+
+    def literal_code(self, value):
+        assert isinstance(value, str)
+        return '"%s"' % _escape_byte_string(value)
 
 
 class CCharArrayType(CStringType, CArrayType):
@@ -898,6 +930,16 @@ class CCharArrayType(CStringType, CArrayType):
     
     def __init__(self, size):
         CArrayType.__init__(self, c_char_type, size)
+
+
+class CUTF8CharArrayType(CUTF8StringType, CArrayType):
+    #  C 'char []' type.
+    
+    parsetuple_format = "s"
+    pymemberdef_typecode = "T_STRING_INPLACE"
+    
+    def __init__(self, size):
+        CArrayType.__init__(self, c_char_type, size)
     
 
 class CCharPtrType(CStringType, CPtrType):
@@ -910,6 +952,16 @@ class CCharPtrType(CStringType, CPtrType):
         CPtrType.__init__(self, c_char_type)
 
 
+class CUTF8CharPtrType(CUTF8StringType, CPtrType):
+    # C 'char *' type, encoded in UTF-8.
+    
+    parsetuple_format = "s"
+    pymemberdef_typecode = "T_STRING"
+    
+    def __init__(self):
+        CPtrType.__init__(self, c_char_type)
+
+
 class ErrorType(PyrexType):
     # Used to prevent propagation of error messages.
     
@@ -974,7 +1026,9 @@ c_longdouble_type =  CFloatType(8)
 
 c_null_ptr_type =     CNullPtrType(c_void_type)
 c_char_array_type =   CCharArrayType(None)
+c_utf8_char_array_type =   CUTF8CharArrayType(None)
 c_char_ptr_type =     CCharPtrType()
+c_utf8_char_ptr_type =     CUTF8CharPtrType()
 c_char_ptr_ptr_type = CPtrType(c_char_ptr_type)
 c_int_ptr_type =      CPtrType(c_int_type)
 
index e48c8dcea590480bf5c5778ed99f473860579903..e91e343aaadd6c5e1e7a9773e58c900ef603083d 100644 (file)
@@ -212,7 +212,7 @@ class PyrexScanner(Scanner):
     resword_dict = build_resword_dict()
 
     def __init__(self, file, filename, parent_scanner = None, 
-            type_names = None, context = None):
+            type_names = None, context = None, source_encoding=None):
         Scanner.__init__(self, get_lexicon(), file, filename)
         if parent_scanner:
             self.context = parent_scanner.context
@@ -226,6 +226,7 @@ class PyrexScanner(Scanner):
             self.compile_time_env = initial_compile_time_env()
             self.compile_time_eval = 1
             self.compile_time_expr = 0
+        self.source_encoding = source_encoding
         self.trace = trace_scanner
         self.indentation_stack = [0]
         self.indentation_char = None
index 0a6e429364a46203fe604e6d0fd357c4e571f098..4140d80f50dadcd7607cd8be5ea1d0172f342416 100644 (file)
@@ -434,15 +434,21 @@ class Scope:
         if not entry:
             entry = self.declare_var(name, py_object_type, None)
         return entry
-    
+
     def add_string_const(self, value):
         # Add an entry for a string constant.
         cname = self.new_const_cname()
-        entry = Entry("", cname, c_char_array_type, init = value)
+        if value.is_unicode:
+            c_type = c_utf8_char_array_type
+            value = value.utf8encode()
+        else:
+            c_type = c_char_array_type
+            value = value.byteencode()
+        entry = Entry("", cname, c_type, init = value)
         entry.used = 1
         self.const_entries.append(entry)
         return entry
-    
+
     def get_string_const(self, value):
         # Get entry for string constant. Returns an existing
         # one if possible, otherwise creates a new one.
@@ -452,7 +458,7 @@ class Scope:
             entry = self.add_string_const(value)
             genv.string_to_entry[value] = entry
         return entry
-    
+
     def add_py_string(self, entry):
         # If not already done, allocate a C name for a Python version of
         # a string literal, and add it to the list of Python strings to
@@ -460,7 +466,7 @@ class Scope:
         # Python identifier, it will be interned.
         if not entry.pystring_cname:
             value = entry.init
-            if identifier_pattern.match(value) and isinstance(value, str):
+            if not entry.type.is_unicode and identifier_pattern.match(value):
                 entry.pystring_cname = self.intern(value)
                 entry.is_interned = 1
             else: