From: Dag Sverre Seljebotn Date: Tue, 1 Jul 2008 20:22:10 +0000 (+0200) Subject: Buffer parsing complete; small transform factorizations and renaming of PostParse X-Git-Tag: 0.9.8.1~49^2~111^2~1 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=1788299ebce49f4c002b02e2ba13150d3262e41a;p=cython.git Buffer parsing complete; small transform factorizations and renaming of PostParse --- diff --git a/Cython/Compiler/Errors.py b/Cython/Compiler/Errors.py index 81414627..bc105faf 100644 --- a/Cython/Compiler/Errors.py +++ b/Cython/Compiler/Errors.py @@ -32,6 +32,7 @@ class CompileError(PyrexError): def __init__(self, position = None, message = ""): self.position = position + self.message_only = message # Deprecated and withdrawn in 2.6: # self.message = message if position: diff --git a/Cython/Compiler/Main.py b/Cython/Compiler/Main.py index 5c51e632..e3fbcb31 100644 --- a/Cython/Compiler/Main.py +++ b/Cython/Compiler/Main.py @@ -334,17 +334,18 @@ def create_generate_code(context, options, result): return generate_code def create_default_pipeline(context, options, result): - from ParseTreeTransforms import WithTransform, PostParse + from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ModuleNode import check_c_classes return [ create_parse(context), - PostParse(), - WithTransform(), - AnalyseDeclarationsTransform(), + NormalizeTree(context), + PostParse(context), + WithTransform(context), + AnalyseDeclarationsTransform(context), check_c_classes, - AnalyseExpressionsTransform(), + AnalyseExpressionsTransform(context), create_generate_code(context, options, result) ] diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index 75db65e5..698d57c6 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -565,7 +565,6 @@ class CBaseTypeNode(Node): pass - class CSimpleBaseTypeNode(CBaseTypeNode): # name string # module_path [string] Qualifying name components diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index dbb30d24..b31642dd 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -1,10 +1,13 @@ -from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle +from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform from Cython.Compiler.Nodes import * from Cython.Compiler.ExprNodes import * from Cython.Compiler.TreeFragment import TreeFragment +from Cython.Utils import EncodedString +from Cython.Compiler.Errors import CompileError +from sets import Set as set -class PostParse(VisitorTransform): +class NormalizeTree(CythonTransform): """ This transform fixes up a few things after parsing in order to make the parse tree more suitable for @@ -25,15 +28,11 @@ class PostParse(VisitorTransform): StatListNode has no children to see if the block is empty). """ - def __init__(self): - super(PostParse, self).__init__() + def __init__(self, context): + super(NormalizeTree, self).__init__(context) self.is_in_statlist = False self.is_in_expr = False - def visit_Node(self, node): - self.visitchildren(node) - return node - def visit_ExprNode(self, node): stacktmp = self.is_in_expr self.is_in_expr = True @@ -73,7 +72,80 @@ class PostParse(VisitorTransform): return self.visit_StatNode(node, True) -class WithTransform(VisitorTransform): +class PostParseError(CompileError): pass + +# error strings checked by unit tests, so define them +ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option' +ERR_BUF_TOO_MANY = 'Too many buffer options' +ERR_BUF_DUP = '"%s" buffer option already supplied' +ERR_BUF_MISSING = '"%s" missing' +ERR_BUF_INT = '"%s" must be an integer' +ERR_BUF_NONNEG = '"%s" must be non-negative' + +class PostParse(CythonTransform): + """ + Basic interpretation of the parse tree, as well as validity + checking that can be done on a very basic level on the parse + tree (while still not being a problem with the basic syntax, + as such). + + Specifically: + - CBufferAccessTypeNode has its options interpreted: + Any first positional argument goes into the "dtype" attribute, + any "ndim" keyword argument goes into the "ndim" attribute and + so on. Also it is checked that the option combination is valid. + + Note: Currently Parsing.py does a lot of interpretation and + reorganization that can be refactored into this transform + if a more pure Abstract Syntax Tree is wanted. + """ + + buffer_options = ("dtype", "ndim") # ordered! + def visit_CBufferAccessTypeNode(self, node): + options = {} + # Fetch positional arguments + if len(node.positional_args) > len(self.buffer_options): + self.context.error(ERR_BUF_TOO_MANY) + for arg, unicode_name in zip(node.positional_args, self.buffer_options): + name = str(unicode_name) + options[name] = arg + # Fetch named arguments + for item in node.keyword_args.key_value_pairs: + name = str(item.key.value) + if not name in self.buffer_options: + raise PostParseError(item.key.pos, + ERR_BUF_UNKNOWN % name) + if name in options.keys(): + raise PostParseError(item.key.pos, + ERR_BUF_DUP % key) + options[name] = item.value + + provided = options.keys() + # get dtype + dtype = options.get("dtype") + if dtype is None: raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype') + node.dtype = dtype + + # get ndim + if "ndim" in provided: + ndimnode = options["ndim"] + if not isinstance(ndimnode, IntNode): + # Compile-time values (DEF) are currently resolved by the parser, + # so nothing more to do here + raise PostParseError(ndimnode.pos, ERR_BUF_INT % 'ndim') + ndim_value = int(ndimnode.value) + if ndim_value < 0: + raise PostParseError(ndimnode.pos, ERR_BUF_NONNEG % 'ndim') + node.ndim = int(ndimnode.value) + + # We're done with the parse tree args + node.positional_args = None + node.keyword_args = None + return node + + + +class WithTransform(CythonTransform): # EXCINFO is manually set to a variable that contains # the exc_info() tuple that can be generated by the enclosing except @@ -94,7 +166,7 @@ class WithTransform(VisitorTransform): if EXC: EXIT(None, None, None) """, temps=[u'MGR', u'EXC', u"EXIT", u"SYS)"], - pipeline=[PostParse()]) + pipeline=[NormalizeTree(None)]) template_with_target = TreeFragment(u""" MGR = EXPR @@ -113,11 +185,7 @@ class WithTransform(VisitorTransform): if EXC: EXIT(None, None, None) """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"], - pipeline=[PostParse()]) - - def visit_Node(self, node): - self.visitchildren(node) - return node + pipeline=[NormalizeTree(None)]) def visit_WithStatNode(self, node): excinfo_name = temp_name_handle('EXCINFO') @@ -143,7 +211,7 @@ class WithTransform(VisitorTransform): return result.stats -class AnalyseDeclarationsTransform(VisitorTransform): +class AnalyseDeclarationsTransform(CythonTransform): def __call__(self, root): self.env_stack = [root.scope] @@ -164,12 +232,7 @@ class AnalyseDeclarationsTransform(VisitorTransform): self.env_stack.pop() return node - def visit_Node(self, node): - self.visitchildren(node) - return node - - -class AnalyseExpressionsTransform(VisitorTransform): +class AnalyseExpressionsTransform(CythonTransform): def visit_ModuleNode(self, node): node.body.analyse_expressions(node.scope) @@ -181,7 +244,4 @@ class AnalyseExpressionsTransform(VisitorTransform): self.visitchildren(node) return node - def visit_Node(self, node): - self.visitchildren(node) - return node diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index 2308553a..de9c7810 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -281,20 +281,18 @@ def p_trailer(s, node1): return ExprNodes.AttributeNode(pos, obj = node1, attribute = name) -def p_positional_and_keyword_callargs(s, end_sy_set): - """ - Parses positional and keyword call arguments. end_sy_set - should contain any s.sy that terminate the argument chain - (this is ('*', '**', ')') for a normal function call, - and (']',) for buffers declarators). +# arglist: argument (',' argument)* [','] +# argument: [test '='] test # Really [keyword '='] test - Returns: (positional_args, keyword_args) - """ +def p_call(s, function): + # s.sy == '(' + pos = s.position() + s.next() positional_args = [] keyword_args = [] - while s.sy not in end_sy_set: - if s.sy == '*' or s.sy == '**': - s.error('Argument expansion not allowed here.') + star_arg = None + starstar_arg = None + while s.sy not in ('*', '**', ')'): arg = p_simple_expr(s) if s.sy == '=': s.next() @@ -314,20 +312,6 @@ def p_positional_and_keyword_callargs(s, end_sy_set): if s.sy != ',': break s.next() - return positional_args, keyword_args - - -# arglist: argument (',' argument)* [','] -# argument: [test '='] test # Really [keyword '='] test - -def p_call(s, function): - # s.sy == '(' - pos = s.position() - s.next() - star_arg = None - starstar_arg = None - positional_args, keyword_args = ( - p_positional_and_keyword_callargs(s,('*', '**', ')'))) if s.sy == '*': s.next() @@ -1473,6 +1457,71 @@ def p_suite(s, ctx = Ctx(), with_doc = 0, with_pseudo_doc = 0): else: return body +def p_positional_and_keyword_args(s, end_sy_set, type_positions=(), type_keywords=()): + """ + Parses positional and keyword arguments. end_sy_set + should contain any s.sy that terminate the argument list. + Argument expansion (* and **) are not allowed. + + type_positions and type_keywords specifies which argument + positions and/or names which should be interpreted as + types. Other arguments will be treated as expressions. + + Returns: (positional_args, keyword_args) + """ + positional_args = [] + keyword_args = [] + pos_idx = 0 + + while s.sy not in end_sy_set: + if s.sy == '*' or s.sy == '**': + s.error('Argument expansion not allowed here.') + + was_keyword = False + parsed_type = False + if s.sy == 'IDENT': + # Since we can have either types or expressions as positional args, + # we use a strategy of looking an extra step forward for a '=' and + # if it is a positional arg we backtrack. + ident = s.systring + s.next() + if s.sy == '=': + s.next() + # Is keyword arg + if ident in type_keywords: + arg = p_c_base_type(s) + parsed_type = True + else: + arg = p_simple_expr(s) + keyword_node = ExprNodes.IdentifierStringNode(arg.pos, + value = Utils.EncodedString(ident)) + keyword_args.append((keyword_node, arg)) + was_keyword = True + else: + s.put_back('IDENT', ident) + + if not was_keyword: + if pos_idx in type_positions: + arg = p_c_base_type(s) + parsed_type = True + else: + arg = p_simple_expr(s) + positional_args.append(arg) + pos_idx += 1 + if len(keyword_args) > 0: + s.error("Non-keyword arg following keyword arg", + pos = arg.pos) + + if s.sy != ',': + if s.sy not in end_sy_set: + if parsed_type: + s.error("Expected: type") + else: + s.error("Expected: expression") + break + s.next() + return positional_args, keyword_args + def p_c_base_type(s, self_flag = 0, nonempty = 0): # If self_flag is true, this is the base type for the # self argument of a C method of an extension type. @@ -1556,24 +1605,32 @@ def p_c_simple_base_type(s, self_flag, nonempty): if s.sy == '[': if is_basic: p.error("Basic C types do not support buffer access") - s.next() - positional_args, keyword_args = ( - p_positional_and_keyword_callargs(s, ('[]',))) - s.expect(']') - - keyword_dict = ExprNodes.DictNode(pos, - key_value_pairs = [ - ExprNodes.DictItemNode(pos=key.pos, key=key, value=value) - for key, value in keyword_args - ]) - - return Nodes.CBufferAccessTypeNode(pos, - positional_args = positional_args, - keyword_args = keyword_dict, - base_type_node = type_node) + return p_buffer_access(s, type_node) else: return type_node +def p_buffer_access(s, type_node): + # s.sy == '[' + pos = s.position() + s.next() + positional_args, keyword_args = ( + p_positional_and_keyword_args(s, (']',), (0,), ('dtype',)) + ) + s.expect(']') + + keyword_dict = ExprNodes.DictNode(pos, + key_value_pairs = [ + ExprNodes.DictItemNode(pos=key.pos, key=key, value=value) + for key, value in keyword_args + ]) + + result = Nodes.CBufferAccessTypeNode(pos, + positional_args = positional_args, + keyword_args = keyword_dict, + base_type_node = type_node) + return result + + def looking_at_type(s): return looking_at_base_type(s) or s.looking_at_type_name() diff --git a/Cython/Compiler/Tests/TestParseTreeTransforms.py b/Cython/Compiler/Tests/TestParseTreeTransforms.py index 3455dbc4..28accbbf 100644 --- a/Cython/Compiler/Tests/TestParseTreeTransforms.py +++ b/Cython/Compiler/Tests/TestParseTreeTransforms.py @@ -2,7 +2,7 @@ from Cython.TestUtils import TransformTest from Cython.Compiler.ParseTreeTransforms import * from Cython.Compiler.Nodes import * -class TestPostParse(TransformTest): +class TestNormalizeTree(TransformTest): def test_parserbehaviour_is_what_we_coded_for(self): t = self.fragment(u"if x: y").root self.assertLines(u""" @@ -15,7 +15,7 @@ class TestPostParse(TransformTest): """, self.treetypes(t)) def test_wrap_singlestat(self): - t = self.run_pipeline([PostParse()], u"if x: y") + t = self.run_pipeline([NormalizeTree(None)], u"if x: y") self.assertLines(u""" (root): StatListNode stats[0]: IfStatNode @@ -27,7 +27,7 @@ class TestPostParse(TransformTest): """, self.treetypes(t)) def test_wrap_multistat(self): - t = self.run_pipeline([PostParse()], u""" + t = self.run_pipeline([NormalizeTree(None)], u""" if z: x y @@ -45,7 +45,7 @@ class TestPostParse(TransformTest): """, self.treetypes(t)) def test_statinexpr(self): - t = self.run_pipeline([PostParse()], u""" + t = self.run_pipeline([NormalizeTree(None)], u""" a, b = x, y """) self.assertLines(u""" @@ -60,7 +60,7 @@ class TestPostParse(TransformTest): """, self.treetypes(t)) def test_wrap_offagain(self): - t = self.run_pipeline([PostParse()], u""" + t = self.run_pipeline([NormalizeTree(None)], u""" x y if z: @@ -82,13 +82,13 @@ class TestPostParse(TransformTest): def test_pass_eliminated(self): - t = self.run_pipeline([PostParse()], u"pass") + t = self.run_pipeline([NormalizeTree(None)], u"pass") self.assert_(len(t.stats) == 0) class TestWithTransform(TransformTest): def test_simplified(self): - t = self.run_pipeline([WithTransform()], u""" + t = self.run_pipeline([WithTransform(None)], u""" with x: y = z ** 3 """) @@ -113,7 +113,7 @@ class TestWithTransform(TransformTest): """, t) def test_basic(self): - t = self.run_pipeline([WithTransform()], u""" + t = self.run_pipeline([WithTransform(None)], u""" with x as y: y = z ** 3 """) diff --git a/Cython/Compiler/Visitor.py b/Cython/Compiler/Visitor.py index 0f9623d2..31b6386b 100644 --- a/Cython/Compiler/Visitor.py +++ b/Cython/Compiler/Visitor.py @@ -131,7 +131,6 @@ class VisitorTransform(TreeVisitor): was not, an exception will be raised. (Typically you want to ensure that you are within a StatListNode or similar before doing this.) """ - def visitchildren(self, parent, attrs=None): result = super(VisitorTransform, self).visitchildren(parent, attrs) for attr, newnode in result.iteritems(): @@ -152,6 +151,19 @@ class VisitorTransform(TreeVisitor): def __call__(self, root): return self.visit(root) +class CythonTransform(VisitorTransform): + """ + Certain common conventions and utilitues for Cython transforms. + """ + def __init__(self, context): + super(CythonTransform, self).__init__() + self.context = context + + def visit_Node(self, node): + self.visitchildren(node) + return node + + # Utils def ensure_statlist(node): if not isinstance(node, Nodes.StatListNode): diff --git a/Cython/TestUtils.py b/Cython/TestUtils.py index 87f072f8..0588ee4a 100644 --- a/Cython/TestUtils.py +++ b/Cython/TestUtils.py @@ -50,12 +50,12 @@ class CythonTest(unittest.TestCase): self.assertEqual(len(result_lines), len(expected_lines), "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected)) - def fragment(self, code, pxds={}): + def fragment(self, code, pxds={}, pipeline=[]): "Simply create a tree fragment using the name of the test-case in parse errors." name = self.id() if name.startswith("__main__."): name = name[len("__main__."):] name = name.replace(".", "_") - return TreeFragment(code, name, pxds) + return TreeFragment(code, name, pxds, pipeline=pipeline) def treetypes(self, root): """Returns a string representing the tree by class names. @@ -66,6 +66,26 @@ class CythonTest(unittest.TestCase): w.visit(root) return u"\n".join([u""] + w.result + [u""]) + def should_fail(self, func, exc_type=Exception): + """Calls "func" and fails if it doesn't raise the right exception + (any exception by default). Also returns the exception in question. + """ + try: + func() + self.fail("Expected an exception of type %r" % exc_type) + except exc_type, e: + self.assert_(isinstance(e, exc_type)) + return e + + def should_not_fail(self, func): + """Calls func and succeeds if and only if no exception is raised + (i.e. converts exception raising into a failed testcase). Returns + the return value of func.""" + try: + return func() + except: + self.fail() + class TransformTest(CythonTest): """ Utility base class for transform unit tests. It is based around constructing