Better unicode/str handling for user-supplied code.
authorRobert Bradshaw <robertwb@math.washington.edu>
Sat, 11 Dec 2010 22:58:57 +0000 (14:58 -0800)
committerRobert Bradshaw <robertwb@math.washington.edu>
Sat, 11 Dec 2010 22:58:57 +0000 (14:58 -0800)
Cython/Build/Inline.py
Cython/Build/Tests/TestInline.py

index d5494e582fbc01241b76f349d9d3d8f8d271c891..e72e50953e5990aace1c9dc055f8aa5645d9a4cb 100644 (file)
@@ -17,6 +17,16 @@ from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclaration
 from Cython.Compiler.TreeFragment import parse_from_strings
 from Cython.Build.Dependencies import strip_string_literals, cythonize
 
+# A utility function to convert user-supplied ASCII strings to unicode.
+if sys.version_info[0] < 3:
+    def to_unicode(s):
+        if not isinstance(s, unicode):
+            return s.decode('ascii')
+        else:
+            return s
+else:
+    to_unicode = lambda x: x
+
 _code_cache = {}
 
 
@@ -28,11 +38,10 @@ class AllSymbols(CythonTransform, SkipDeclarations):
         self.names.add(node.name)
 
 def unbound_symbols(code, context=None):
+    code = to_unicode(code)
     if context is None:
         context = Context([], default_options)
     from Cython.Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform
-    if isinstance(code, str):
-        code = code.decode('ascii')
     tree = parse_from_strings('(tree fragment)', code)
     for phase in context.create_pipeline(pxd=False):
         if phase is None:
@@ -90,6 +99,7 @@ def cython_inline(code,
                   **kwds):
     if get_type is None:
         get_type = lambda x: 'object'
+    code = to_unicode(code)
     code, literals = strip_string_literals(code)
     code = strip_common_indent(code)
     ctx = Context(cython_include_dirs, default_options)
index e0ed3273d6ee3255f5b47ffa4d3e6ed19000535c..4a6376934bca3ba929ad316e008bd2ecd6ee9748 100644 (file)
@@ -12,7 +12,7 @@ test_kwds = dict(force=True, quiet=True)
 
 global_value = 100
 
-class TestStripLiterals(CythonTest):
+class TestInline(CythonTest):
 
     def test_simple(self):
         self.assertEquals(inline("return 1+2", **test_kwds), 3)