Py3 fix
authorStefan Behnel <scoder@users.berlios.de>
Sun, 5 Jul 2009 19:30:32 +0000 (21:30 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Sun, 5 Jul 2009 19:30:32 +0000 (21:30 +0200)
Cython/Compiler/StringEncoding.py

index 12ae3588788cae62c9122ff2545af4757b40f10b..b9de328ed0f168a3dd1b0d414b0033dc1e944383 100644 (file)
@@ -3,6 +3,17 @@
 #
 
 import re
+import sys
+
+if sys.version_info[0] >= 3:
+    _str, _bytes = str, bytes
+else:
+    _str, _bytes = unicode, str
+
+empty_bytes = _bytes()
+empty_str = _str()
+
+join_bytes = empty_bytes.join
 
 class UnicodeLiteralBuilder(object):
     """Assemble a unicode string.
@@ -11,10 +22,10 @@ class UnicodeLiteralBuilder(object):
         self.chars = []
 
     def append(self, characters):
-        if isinstance(characters, str):
+        if isinstance(characters, _bytes):
             # this came from a Py2 string literal in the parser code
             characters = characters.decode("ASCII")
-        assert isinstance(characters, unicode), str(type(characters))
+        assert isinstance(characters, _str), str(type(characters))
         self.chars.append(characters)
 
     def append_charval(self, char_number):
@@ -32,9 +43,9 @@ class BytesLiteralBuilder(object):
         self.target_encoding = target_encoding
 
     def append(self, characters):
-        if isinstance(characters, unicode):
+        if isinstance(characters, _str):
             characters = characters.encode(self.target_encoding)
-        assert isinstance(characters, str), str(type(characters))
+        assert isinstance(characters, _bytes), str(type(characters))
         self.chars.append(characters)
 
     def append_charval(self, char_number):
@@ -42,7 +53,7 @@ class BytesLiteralBuilder(object):
 
     def getstring(self):
         # this *must* return a byte string! => fix it in Py3k!!
-        s = BytesLiteral(''.join(self.chars))
+        s = BytesLiteral(join_bytes(self.chars))
         s.encoding = self.target_encoding
         return s
 
@@ -50,7 +61,7 @@ class BytesLiteralBuilder(object):
         # this *must* return a byte string! => fix it in Py3k!!
         return self.getstring()
 
-class EncodedString(unicode):
+class EncodedString(_str):
     # unicode string subclass to keep track of the original encoding.
     # 'encoding' is None for unicode strings and the source encoding
     # otherwise
@@ -68,7 +79,7 @@ class EncodedString(unicode):
         return self.encoding is None
     is_unicode = property(is_unicode)
 
-class BytesLiteral(str):
+class BytesLiteral(_bytes):
     # str subclass that is compatible with EncodedString
     encoding = None
 
@@ -95,19 +106,23 @@ def _to_escape_sequence(s):
         return repr(s)[1:-1]
     elif s == '"':
         return r'\"'
+    elif s == '\\':
+        return r'\\'
     else:
         # within a character sequence, oct passes much better than hex
         return ''.join(['\\%03o' % ord(c) for c in s])
 
-_c_special = ('\0', '\n', '\r', '\t', '??', '"')
-_c_special_replacements = zip(_c_special, map(_to_escape_sequence, _c_special))
+_c_special = ('\\', '\0', '\n', '\r', '\t', '??', '"')
+_c_special_replacements = [(orig.encode('ASCII'),
+                            _to_escape_sequence(orig).encode('ASCII'))
+                           for orig in _c_special ]
 
 def _build_specials_test():
     subexps = []
     for special in _c_special:
         regexp = ''.join(['[%s]' % c for c in special])
         subexps.append(regexp)
-    return re.compile('|'.join(subexps)).search
+    return re.compile('|'.join(subexps).encode('ASCII')).search
 
 _has_specials = _build_specials_test()
 
@@ -124,24 +139,33 @@ def escape_character(c):
         return c
 
 def escape_byte_string(s):
-    s = s.replace('\\', '\\\\')
     if _has_specials(s):
         for special, replacement in _c_special_replacements:
             s = s.replace(special, replacement)
     try:
-        s.decode("ASCII")
+        s.decode("ASCII") # trial decoding: plain ASCII => done
         return s
     except UnicodeDecodeError:
         pass
-    l = []
-    append = l.append
-    for c in s:
-        o = ord(c)
-        if o >= 128:
-            append('\\%3o' % o)
-        else:
-            append(c)
-    return ''.join(l)
+    if sys.version_info[0] >= 3:
+        s_new = bytearray()
+        append, extend = s_new.append, s_new.extend
+        for b in s:
+            if b >= 128:
+                extend(('\\%3o' % b).encode('ASCII'))
+            else:
+                append(b)
+        return bytes(s_new)
+    else:
+        l = []
+        append = l.append
+        for c in s:
+            o = ord(c)
+            if o >= 128:
+                append('\\%3o' % o)
+            else:
+                append(c)
+        return join_bytes(l)
 
 def split_docstring(s):
     if len(s) < 2047: