From 671e6a53a4621f5a684f45ed9e4436e74af9975e Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Sat, 11 Dec 2010 14:58:57 -0800 Subject: [PATCH] Better unicode/str handling for user-supplied code. --- Cython/Build/Inline.py | 14 ++++++++++++-- Cython/Build/Tests/TestInline.py | 2 +- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/Cython/Build/Inline.py b/Cython/Build/Inline.py index d5494e58..e72e5095 100644 --- a/Cython/Build/Inline.py +++ b/Cython/Build/Inline.py @@ -17,6 +17,16 @@ from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclaration from Cython.Compiler.TreeFragment import parse_from_strings from Cython.Build.Dependencies import strip_string_literals, cythonize +# A utility function to convert user-supplied ASCII strings to unicode. +if sys.version_info[0] < 3: + def to_unicode(s): + if not isinstance(s, unicode): + return s.decode('ascii') + else: + return s +else: + to_unicode = lambda x: x + _code_cache = {} @@ -28,11 +38,10 @@ class AllSymbols(CythonTransform, SkipDeclarations): self.names.add(node.name) def unbound_symbols(code, context=None): + code = to_unicode(code) if context is None: context = Context([], default_options) from Cython.Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform - if isinstance(code, str): - code = code.decode('ascii') tree = parse_from_strings('(tree fragment)', code) for phase in context.create_pipeline(pxd=False): if phase is None: @@ -90,6 +99,7 @@ def cython_inline(code, **kwds): if get_type is None: get_type = lambda x: 'object' + code = to_unicode(code) code, literals = strip_string_literals(code) code = strip_common_indent(code) ctx = Context(cython_include_dirs, default_options) diff --git a/Cython/Build/Tests/TestInline.py b/Cython/Build/Tests/TestInline.py index e0ed3273..4a637693 100644 --- a/Cython/Build/Tests/TestInline.py +++ b/Cython/Build/Tests/TestInline.py @@ -12,7 +12,7 @@ test_kwds = dict(force=True, quiet=True) global_value = 100 -class TestStripLiterals(CythonTest): +class TestInline(CythonTest): def test_simple(self): self.assertEquals(inline("return 1+2", **test_kwds), 3) -- 2.26.2