reject non-ASCII literal characters in Python 3 byte strings
authorStefan Behnel <scoder@users.berlios.de>
Tue, 7 Sep 2010 18:25:41 +0000 (20:25 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Tue, 7 Sep 2010 18:25:41 +0000 (20:25 +0200)
Cython/Compiler/Parsing.pxd
Cython/Compiler/Parsing.py
tests/errors/cython3_bytes.pyx [new file with mode: 0644]
tests/run/cython2_bytes.pyx [new file with mode: 0644]

index 5a0a34734aa4b5cf65beef7434928f5dc79390d0..ac0ad879d3d1afb032ba18c8bc002d42292953bb 100644 (file)
@@ -47,6 +47,7 @@ cpdef p_atom(PyrexScanner s)
 cpdef p_name(PyrexScanner s, name)
 cpdef p_cat_string_literal(PyrexScanner s)
 cpdef p_opt_string_literal(PyrexScanner s, required_type=*)
+cpdef bint check_for_non_ascii_characters(unicode string)
 cpdef p_string_literal(PyrexScanner s, kind_override=*)
 cpdef p_list_maker(PyrexScanner s)
 cpdef p_comp_iter(PyrexScanner s, body)
index 93c7413b66cd45525f3c1460c3cc603567c3b76d..87a8ebf758641a6abe4d45d7f6c7642130f3e243 100644 (file)
@@ -661,9 +661,9 @@ def p_cat_string_literal(s):
             bstrings.append(next_bytes_value)
             ustrings.append(next_unicode_value)
     # join and rewrap the partial literals
-    if kind in ('b', 'c', '') or kind == 'u' and bstrings[0] is not None:
+    if kind in ('b', 'c', '') or kind == 'u' and None not in bstrings:
         # Py3 enforced unicode literals are parsed as bytes/unicode combination
-        bytes_value = BytesLiteral( StringEncoding.join_bytes([ b for b in bstrings if b is not None ]) )
+        bytes_value = BytesLiteral( StringEncoding.join_bytes(bstrings) )
         bytes_value.encoding = s.source_encoding
     if kind in ('u', ''):
         unicode_value = EncodedString( u''.join([ u for u in ustrings if u is not None ]) )
@@ -681,6 +681,12 @@ def p_opt_string_literal(s, required_type='u'):
     else:
         return None
 
+def check_for_non_ascii_characters(string):
+    for c in string:
+        if c >= u'\x80':
+            return True
+    return False
+
 def p_string_literal(s, kind_override=None):
     # A single string or char literal.  Returns (kind, bvalue, uvalue)
     # where kind in ('b', 'c', 'u', '').  The 'bvalue' is the source
@@ -692,6 +698,7 @@ def p_string_literal(s, kind_override=None):
     # s.sy == 'BEGIN_STRING'
     pos = s.position()
     is_raw = 0
+    has_non_ASCII_literal_characters = False
     kind = s.systring[:1].lower()
     if kind == 'r':
         kind = ''
@@ -715,12 +722,13 @@ def p_string_literal(s, kind_override=None):
     while 1:
         s.next()
         sy = s.sy
+        systr = s.systring
         #print "p_string_literal: sy =", sy, repr(s.systring) ###
         if sy == 'CHARS':
-            chars.append(s.systring)
+            chars.append(systr)
+            if not has_non_ASCII_literal_characters and check_for_non_ascii_characters(systr):
+                has_non_ASCII_literal_characters = True
         elif sy == 'ESCAPE':
-            has_escape = True
-            systr = s.systring
             if is_raw:
                 if systr == u'\\\n':
                     chars.append(u'\\\n')
@@ -730,6 +738,8 @@ def p_string_literal(s, kind_override=None):
                     chars.append(u"'")
                 else:
                     chars.append(systr)
+                    if not has_non_ASCII_literal_characters and check_for_non_ascii_characters(systr):
+                        has_non_ASCII_literal_characters = True
             else:
                 c = systr[1]
                 if c in u"01234567":
@@ -755,6 +765,8 @@ def p_string_literal(s, kind_override=None):
                     chars.append_uescape(chrval, systr)
                 else:
                     chars.append(u'\\' + systr[1:])
+                    if not has_non_ASCII_literal_characters and check_for_non_ascii_characters(systr):
+                        has_non_ASCII_literal_characters = True
         elif sy == 'NEWLINE':
             chars.append(u'\n')
         elif sy == 'END_STRING':
@@ -772,6 +784,11 @@ def p_string_literal(s, kind_override=None):
             error(pos, u"invalid character literal: %r" % bytes_value)
     else:
         bytes_value, unicode_value = chars.getstrings()
+        if has_non_ASCII_literal_characters and s.context.language_level >= 3:
+            # Python 3 forbids literal non-ASCII characters in byte strings
+            if kind != 'u':
+                s.error("bytes can only contain ASCII literal characters.", pos = pos)
+            bytes_value = None
     s.next()
     return (kind, bytes_value, unicode_value)
 
diff --git a/tests/errors/cython3_bytes.pyx b/tests/errors/cython3_bytes.pyx
new file mode 100644 (file)
index 0000000..a8ad4f6
--- /dev/null
@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+# cython: language_level=3
+
+escaped = b'abc\xc3\xbc\xc3\xb6\xc3\xa4'
+invalid = b'abcüöä'
+
+_ERRORS = """
+5:10: bytes can only contain ASCII literal characters.
+"""
diff --git a/tests/run/cython2_bytes.pyx b/tests/run/cython2_bytes.pyx
new file mode 100644 (file)
index 0000000..84eec1c
--- /dev/null
@@ -0,0 +1,13 @@
+# -*- coding: utf-8 -*-
+# cython: language_level=2
+
+b = b'abcüöä \x12'
+
+cdef char* cs = 'abcüöä \x12'
+
+def compare_cs():
+    """
+    >>> b == compare_cs()
+    True
+    """
+    return cs