Rewrite of the string literal handling code
[cython.git] / Cython / Compiler / StringEncoding.py
1 #
2 #   Cython -- encoding related tools
3 #
4
5 import re
6
7 class UnicodeLiteralBuilder(object):
8     """Assemble a unicode string.
9     """
10     def __init__(self):
11         self.chars = []
12
13     def append(self, characters):
14         if isinstance(characters, str):
15             # this came from a Py2 string literal in the parser code
16             characters = characters.decode("ASCII")
17         assert isinstance(characters, unicode), str(type(characters))
18         self.chars.append(characters)
19
20     def append_charval(self, char_number):
21         self.chars.append( unichr(char_number) )
22
23     def getstring(self):
24         return EncodedString(u''.join(self.chars))
25
26
27 class BytesLiteralBuilder(object):
28     """Assemble a byte string or char value.
29     """
30     def __init__(self, target_encoding):
31         self.chars = []
32         self.target_encoding = target_encoding
33
34     def append(self, characters):
35         if isinstance(characters, unicode):
36             characters = characters.encode(self.target_encoding)
37         assert isinstance(characters, str), str(type(characters))
38         self.chars.append(characters)
39
40     def append_charval(self, char_number):
41         self.chars.append( chr(char_number) )
42
43     def getstring(self):
44         # this *must* return a byte string! => fix it in Py3k!!
45         s = BytesLiteral(''.join(self.chars))
46         s.encoding = self.target_encoding
47         return s
48
49     def getchar(self):
50         # this *must* return a byte string! => fix it in Py3k!!
51         return self.getstring()
52
53 class EncodedString(unicode):
54     # unicode string subclass to keep track of the original encoding.
55     # 'encoding' is None for unicode strings and the source encoding
56     # otherwise
57     encoding = None
58
59     def byteencode(self):
60         assert self.encoding is not None
61         return self.encode(self.encoding)
62
63     def utf8encode(self):
64         assert self.encoding is None
65         return self.encode("UTF-8")
66
67     def is_unicode(self):
68         return self.encoding is None
69     is_unicode = property(is_unicode)
70
71 class BytesLiteral(str):
72     # str subclass that is compatible with EncodedString
73     encoding = None
74
75     def byteencode(self):
76         return str(self)
77
78     def utf8encode(self):
79         assert False, "this is not a unicode string: %r" % self
80
81     is_unicode = False
82
83 char_from_escape_sequence = {
84     r'\a' : u'\a',
85     r'\b' : u'\b',
86     r'\f' : u'\f',
87     r'\n' : u'\n',
88     r'\r' : u'\r',
89     r'\t' : u'\t',
90     r'\v' : u'\v',
91     }.get
92
93 def _to_escape_sequence(s):
94     if s in '\n\r\t':
95         return repr(s)[1:-1]
96     elif s == '"':
97         return r'\"'
98     else:
99         # within a character sequence, oct passes much better than hex
100         return ''.join(['\\%03o' % ord(c) for c in s])
101
102 _c_special = ('\0', '\n', '\r', '\t', '??', '"')
103 _c_special_replacements = zip(_c_special, map(_to_escape_sequence, _c_special))
104
105 def _build_specials_test():
106     subexps = []
107     for special in _c_special:
108         regexp = ''.join(['[%s]' % c for c in special])
109         subexps.append(regexp)
110     return re.compile('|'.join(subexps)).search
111
112 _has_specials = _build_specials_test()
113
114 def escape_character(c):
115     if c in '\n\r\t\\':
116         return repr(c)[1:-1]
117     elif c == "'":
118         return "\\'"
119     n = ord(c)
120     if n < 32 or n > 127:
121         # hex works well for characters
122         return "\\x%02X" % n
123     else:
124         return c
125
126 def escape_byte_string(s):
127     s = s.replace('\\', '\\\\')
128     if _has_specials(s):
129         for special, replacement in _c_special_replacements:
130             s = s.replace(special, replacement)
131     try:
132         s.decode("ASCII")
133         return s
134     except UnicodeDecodeError:
135         pass
136     l = []
137     append = l.append
138     for c in s:
139         o = ord(c)
140         if o >= 128:
141             append('\\%3o' % o)
142         else:
143             append(c)
144     return ''.join(l)