fix generators after raise-from merge
[cython.git] / Cython / Compiler / StringEncoding.py
1 #
2 #   Cython -- encoding related tools
3 #
4
5 import re
6 import sys
7
8 if sys.version_info[0] >= 3:
9     _unicode, _str, _bytes = str, str, bytes
10     IS_PYTHON3 = True
11 else:
12     _unicode, _str, _bytes = unicode, str, str
13     IS_PYTHON3 = False
14
15 empty_bytes = _bytes()
16 empty_unicode = _unicode()
17
18 join_bytes = empty_bytes.join
19
20 class UnicodeLiteralBuilder(object):
21     """Assemble a unicode string.
22     """
23     def __init__(self):
24         self.chars = []
25
26     def append(self, characters):
27         if isinstance(characters, _bytes):
28             # this came from a Py2 string literal in the parser code
29             characters = characters.decode("ASCII")
30         assert isinstance(characters, _unicode), str(type(characters))
31         self.chars.append(characters)
32
33     if sys.maxunicode == 65535:
34         def append_charval(self, char_number):
35             if char_number > 65535:
36                 # wide Unicode character on narrow platform => replace
37                 # by surrogate pair
38                 char_number -= 0x10000
39                 self.chars.append( unichr((char_number // 1024) + 0xD800) )
40                 self.chars.append( unichr((char_number  % 1024) + 0xDC00) )
41             else:
42                 self.chars.append( unichr(char_number) )
43     else:
44         def append_charval(self, char_number):
45             self.chars.append( unichr(char_number) )
46
47     def append_uescape(self, char_number, escape_string):
48         self.append_charval(char_number)
49
50     def getstring(self):
51         return EncodedString(u''.join(self.chars))
52
53     def getstrings(self):
54         return (None, self.getstring())
55
56
57 class BytesLiteralBuilder(object):
58     """Assemble a byte string or char value.
59     """
60     def __init__(self, target_encoding):
61         self.chars = []
62         self.target_encoding = target_encoding
63
64     def append(self, characters):
65         if isinstance(characters, _unicode):
66             characters = characters.encode(self.target_encoding)
67         assert isinstance(characters, _bytes), str(type(characters))
68         self.chars.append(characters)
69
70     def append_charval(self, char_number):
71         self.chars.append( unichr(char_number).encode('ISO-8859-1') )
72
73     def append_uescape(self, char_number, escape_string):
74         self.append(escape_string)
75
76     def getstring(self):
77         # this *must* return a byte string!
78         s = BytesLiteral(join_bytes(self.chars))
79         s.encoding = self.target_encoding
80         return s
81
82     def getchar(self):
83         # this *must* return a byte string!
84         return self.getstring()
85
86     def getstrings(self):
87         return (self.getstring(), None)
88
89 class StrLiteralBuilder(object):
90     """Assemble both a bytes and a unicode representation of a string.
91     """
92     def __init__(self, target_encoding):
93         self._bytes   = BytesLiteralBuilder(target_encoding)
94         self._unicode = UnicodeLiteralBuilder()
95
96     def append(self, characters):
97         self._bytes.append(characters)
98         self._unicode.append(characters)
99
100     def append_charval(self, char_number):
101         self._bytes.append_charval(char_number)
102         self._unicode.append_charval(char_number)
103
104     def append_uescape(self, char_number, escape_string):
105         self._bytes.append(escape_string)
106         self._unicode.append_charval(char_number)
107
108     def getstrings(self):
109         return (self._bytes.getstring(), self._unicode.getstring())
110
111
112 class EncodedString(_unicode):
113     # unicode string subclass to keep track of the original encoding.
114     # 'encoding' is None for unicode strings and the source encoding
115     # otherwise
116     encoding = None
117
118     def byteencode(self):
119         assert self.encoding is not None
120         return self.encode(self.encoding)
121
122     def utf8encode(self):
123         assert self.encoding is None
124         return self.encode("UTF-8")
125
126     def is_unicode(self):
127         return self.encoding is None
128     is_unicode = property(is_unicode)
129
130 class BytesLiteral(_bytes):
131     # bytes subclass that is compatible with EncodedString
132     encoding = None
133
134     def byteencode(self):
135         if IS_PYTHON3:
136             return _bytes(self)
137         else:
138             # fake-recode the string to make it a plain bytes object
139             return self.decode('ISO-8859-1').encode('ISO-8859-1')
140
141     def utf8encode(self):
142         assert False, "this is not a unicode string: %r" % self
143
144     def __str__(self):
145         """Fake-decode the byte string to unicode to support %
146         formatting of unicode strings.
147         """
148         return self.decode('ISO-8859-1')
149
150     is_unicode = False
151
152 char_from_escape_sequence = {
153     r'\a' : u'\a',
154     r'\b' : u'\b',
155     r'\f' : u'\f',
156     r'\n' : u'\n',
157     r'\r' : u'\r',
158     r'\t' : u'\t',
159     r'\v' : u'\v',
160     }.get
161
162 def _to_escape_sequence(s):
163     if s in '\n\r\t':
164         return repr(s)[1:-1]
165     elif s == '"':
166         return r'\"'
167     elif s == '\\':
168         return r'\\'
169     else:
170         # within a character sequence, oct passes much better than hex
171         return ''.join(['\\%03o' % ord(c) for c in s])
172
173 _c_special = ('\\', '??', '"') + tuple(map(chr, range(32)))
174 _c_special_replacements = [(orig.encode('ASCII'),
175                             _to_escape_sequence(orig).encode('ASCII'))
176                            for orig in _c_special ]
177
178 def _build_specials_test():
179     subexps = []
180     for special in _c_special:
181         regexp = ''.join(['[%s]' % c.replace('\\', '\\\\') for c in special])
182         subexps.append(regexp)
183     return re.compile('|'.join(subexps).encode('ASCII')).search
184
185 _has_specials = _build_specials_test()
186
187 def escape_char(c):
188     if IS_PYTHON3:
189         c = c.decode('ISO-8859-1')
190     if c in '\n\r\t\\':
191         return repr(c)[1:-1]
192     elif c == "'":
193         return "\\'"
194     n = ord(c)
195     if n < 32 or n > 127:
196         # hex works well for characters
197         return "\\x%02X" % n
198     else:
199         return c
200
201 def escape_byte_string(s):
202     """Escape a byte string so that it can be written into C code.
203     Note that this returns a Unicode string instead which, when
204     encoded as ISO-8859-1, will result in the correct byte sequence
205     being written.
206     """
207     if _has_specials(s):
208         for special, replacement in _c_special_replacements:
209             if special in s:
210                 s = s.replace(special, replacement)
211     try:
212         return s.decode("ASCII") # trial decoding: plain ASCII => done
213     except UnicodeDecodeError:
214         pass
215     if IS_PYTHON3:
216         s_new = bytearray()
217         append, extend = s_new.append, s_new.extend
218         for b in s:
219             if b >= 128:
220                 extend(('\\%3o' % b).encode('ASCII'))
221             else:
222                 append(b)
223         return s_new.decode('ISO-8859-1')
224     else:
225         l = []
226         append = l.append
227         for c in s:
228             o = ord(c)
229             if o >= 128:
230                 append('\\%3o' % o)
231             else:
232                 append(c)
233         return join_bytes(l).decode('ISO-8859-1')
234
235 def split_string_literal(s, limit=2000):
236     # MSVC can't handle long string literals.
237     if len(s) < limit:
238         return s
239     else:
240         start = 0
241         chunks = []
242         while start < len(s):
243             end = start + limit
244             if len(s) > end-4 and '\\' in s[end-4:end]:
245                 end -= 4 - s[end-4:end].find('\\') # just before the backslash
246                 while s[end-1] == '\\':
247                     end -= 1
248                     if end == start:
249                         # must have been a long line of backslashes
250                         end = start + limit - (limit % 2) - 4
251                         break
252             chunks.append(s[start:end])
253             start = end
254         return '""'.join(chunks)