From: Stefan Behnel Date: Sun, 5 Jul 2009 19:30:32 +0000 (+0200) Subject: Py3 fix X-Git-Tag: 0.12.alpha0~263 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=bd0195333c501d48387d8e23a35bf1ab5d79b6a9;p=cython.git Py3 fix --- diff --git a/Cython/Compiler/StringEncoding.py b/Cython/Compiler/StringEncoding.py index 12ae3588..b9de328e 100644 --- a/Cython/Compiler/StringEncoding.py +++ b/Cython/Compiler/StringEncoding.py @@ -3,6 +3,17 @@ # import re +import sys + +if sys.version_info[0] >= 3: + _str, _bytes = str, bytes +else: + _str, _bytes = unicode, str + +empty_bytes = _bytes() +empty_str = _str() + +join_bytes = empty_bytes.join class UnicodeLiteralBuilder(object): """Assemble a unicode string. @@ -11,10 +22,10 @@ class UnicodeLiteralBuilder(object): self.chars = [] def append(self, characters): - if isinstance(characters, str): + if isinstance(characters, _bytes): # this came from a Py2 string literal in the parser code characters = characters.decode("ASCII") - assert isinstance(characters, unicode), str(type(characters)) + assert isinstance(characters, _str), str(type(characters)) self.chars.append(characters) def append_charval(self, char_number): @@ -32,9 +43,9 @@ class BytesLiteralBuilder(object): self.target_encoding = target_encoding def append(self, characters): - if isinstance(characters, unicode): + if isinstance(characters, _str): characters = characters.encode(self.target_encoding) - assert isinstance(characters, str), str(type(characters)) + assert isinstance(characters, _bytes), str(type(characters)) self.chars.append(characters) def append_charval(self, char_number): @@ -42,7 +53,7 @@ class BytesLiteralBuilder(object): def getstring(self): # this *must* return a byte string! => fix it in Py3k!! - s = BytesLiteral(''.join(self.chars)) + s = BytesLiteral(join_bytes(self.chars)) s.encoding = self.target_encoding return s @@ -50,7 +61,7 @@ class BytesLiteralBuilder(object): # this *must* return a byte string! => fix it in Py3k!! return self.getstring() -class EncodedString(unicode): +class EncodedString(_str): # unicode string subclass to keep track of the original encoding. # 'encoding' is None for unicode strings and the source encoding # otherwise @@ -68,7 +79,7 @@ class EncodedString(unicode): return self.encoding is None is_unicode = property(is_unicode) -class BytesLiteral(str): +class BytesLiteral(_bytes): # str subclass that is compatible with EncodedString encoding = None @@ -95,19 +106,23 @@ def _to_escape_sequence(s): return repr(s)[1:-1] elif s == '"': return r'\"' + elif s == '\\': + return r'\\' else: # within a character sequence, oct passes much better than hex return ''.join(['\\%03o' % ord(c) for c in s]) -_c_special = ('\0', '\n', '\r', '\t', '??', '"') -_c_special_replacements = zip(_c_special, map(_to_escape_sequence, _c_special)) +_c_special = ('\\', '\0', '\n', '\r', '\t', '??', '"') +_c_special_replacements = [(orig.encode('ASCII'), + _to_escape_sequence(orig).encode('ASCII')) + for orig in _c_special ] def _build_specials_test(): subexps = [] for special in _c_special: regexp = ''.join(['[%s]' % c for c in special]) subexps.append(regexp) - return re.compile('|'.join(subexps)).search + return re.compile('|'.join(subexps).encode('ASCII')).search _has_specials = _build_specials_test() @@ -124,24 +139,33 @@ def escape_character(c): return c def escape_byte_string(s): - s = s.replace('\\', '\\\\') if _has_specials(s): for special, replacement in _c_special_replacements: s = s.replace(special, replacement) try: - s.decode("ASCII") + s.decode("ASCII") # trial decoding: plain ASCII => done return s except UnicodeDecodeError: pass - l = [] - append = l.append - for c in s: - o = ord(c) - if o >= 128: - append('\\%3o' % o) - else: - append(c) - return ''.join(l) + if sys.version_info[0] >= 3: + s_new = bytearray() + append, extend = s_new.append, s_new.extend + for b in s: + if b >= 128: + extend(('\\%3o' % b).encode('ASCII')) + else: + append(b) + return bytes(s_new) + else: + l = [] + append = l.append + for c in s: + o = ord(c) + if o >= 128: + append('\\%3o' % o) + else: + append(c) + return join_bytes(l) def split_docstring(s): if len(s) < 2047: