From: Dag Sverre Seljebotn Date: Wed, 24 Sep 2008 21:54:56 +0000 (+0200) Subject: Introduce TempsBlockNode utility, improve TreeFragment-generated temps X-Git-Tag: 0.9.9.2.beta~86 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=a67aaf75bcb2108b09828aec9751994d63979554;p=cython.git Introduce TempsBlockNode utility, improve TreeFragment-generated temps --- diff --git a/Cython/CodeWriter.py b/Cython/CodeWriter.py index dca1f6a7..cbe91aab 100644 --- a/Cython/CodeWriter.py +++ b/Cython/CodeWriter.py @@ -1,4 +1,4 @@ -from Cython.Compiler.Visitor import TreeVisitor, get_temp_name_handle_desc +from Cython.Compiler.Visitor import TreeVisitor from Cython.Compiler.Nodes import * from Cython.Compiler.ExprNodes import * @@ -37,6 +37,7 @@ class CodeWriter(TreeVisitor): self.result = result self.numindents = 0 self.tempnames = {} + self.tempblockindex = 0 def write(self, tree): self.visit(tree) @@ -60,12 +61,6 @@ class CodeWriter(TreeVisitor): self.startline(s) self.endline() - def putname(self, name): - tmpdesc = get_temp_name_handle_desc(name) - if tmpdesc is not None: - name = self.tempnames.setdefault(tmpdesc, u"$" +tmpdesc) - self.put(name) - def comma_seperated_list(self, items, output_rhs=False): if len(items) > 0: for item in items[:-1]: @@ -132,7 +127,7 @@ class CodeWriter(TreeVisitor): self.endline() def visit_NameNode(self, node): - self.putname(node.name) + self.put(node.name) def visit_IntNode(self, node): self.put(node.value) @@ -312,3 +307,18 @@ class CodeWriter(TreeVisitor): self.visit(node.operand) self.put(u")") + def visit_TempsBlockNode(self, node): + """ + Temporaries are output like $1_1', where the first number is + an index of the TempsBlockNode and the second number is an index + of the temporary which that block allocates. + """ + idx = 0 + for handle in node.handles: + self.tempnames[handle] = "$%d_%d" % (self.tempblockindex, idx) + idx += 1 + self.tempblockindex += 1 + self.visit(node.body) + + def visit_TempRefNode(self, node): + self.put(self.tempnames[node.handle]) diff --git a/Cython/Compiler/Buffer.py b/Cython/Compiler/Buffer.py index aa6d09b0..e80e91f8 100644 --- a/Cython/Compiler/Buffer.py +++ b/Cython/Compiler/Buffer.py @@ -1,8 +1,7 @@ -from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform +from Cython.Compiler.Visitor import VisitorTransform, CythonTransform from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.Nodes import * from Cython.Compiler.ExprNodes import * -from Cython.Compiler.TreeFragment import TreeFragment from Cython.Compiler.StringEncoding import EncodedString from Cython.Compiler.Errors import CompileError import Interpreter diff --git a/Cython/Compiler/CodeGeneration.py b/Cython/Compiler/CodeGeneration.py index 9d9d555d..78a61fe0 100644 --- a/Cython/Compiler/CodeGeneration.py +++ b/Cython/Compiler/CodeGeneration.py @@ -1,4 +1,4 @@ -from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform +from Cython.Compiler.Visitor import VisitorTransform, CythonTransform from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.Nodes import * from Cython.Compiler.ExprNodes import * diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index cd9a5a46..30bce950 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -206,10 +206,10 @@ class ExprNode(Node): return self.saved_subexpr_nodes def result(self): - if self.is_temp: - return self.result_code - else: - return self.calculate_result_code() + if self.is_temp: + return self.result_code + else: + return self.calculate_result_code() def result_as(self, type = None): # Return the result code cast to the specified C type. diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index 6ecabcbd..baa34ae1 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -4188,6 +4188,7 @@ class FromImportStatNode(StatNode): self.module.generate_disposal_code(code) + #------------------------------------------------------------------------------------ # # Runtime support code diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index a8691788..26d0db2b 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -1,7 +1,8 @@ -from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform +from Cython.Compiler.Visitor import VisitorTransform, CythonTransform from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.Nodes import * from Cython.Compiler.ExprNodes import * +from Cython.Compiler.UtilNodes import * from Cython.Compiler.TreeFragment import TreeFragment from Cython.Compiler.StringEncoding import EncodedString from Cython.Compiler.Errors import CompileError @@ -409,7 +410,7 @@ class WithTransform(CythonTransform): finally: if EXC: EXIT(None, None, None) - """, temps=[u'MGR', u'EXC', u"EXIT", u"SYS)"], + """, temps=[u'MGR', u'EXC', u"EXIT"], pipeline=[NormalizeTree(None)]) template_with_target = TreeFragment(u""" @@ -428,32 +429,33 @@ class WithTransform(CythonTransform): finally: if EXC: EXIT(None, None, None) - """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"], + """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE"], pipeline=[NormalizeTree(None)]) def visit_WithStatNode(self, node): - excinfo_name = temp_name_handle('EXCINFO') - excinfo_namenode = NameNode(pos=node.pos, name=excinfo_name) - excinfo_target = NameNode(pos=node.pos, name=excinfo_name) + excinfo_tempblock = TempsBlockNode(node.pos, [PyrexTypes.py_object_type], None) if node.target is not None: result = self.template_with_target.substitute({ u'EXPR' : node.manager, u'BODY' : node.body, u'TARGET' : node.target, - u'EXCINFO' : excinfo_namenode + u'EXCINFO' : excinfo_tempblock.get_ref_node(0, node.pos) }, pos=node.pos) # Set except excinfo target to EXCINFO - result.stats[4].body.stats[0].except_clauses[0].excinfo_target = excinfo_target + result.body.stats[4].body.stats[0].except_clauses[0].excinfo_target = ( + excinfo_tempblock.get_ref_node(0, node.pos)) else: result = self.template_without_target.substitute({ u'EXPR' : node.manager, u'BODY' : node.body, - u'EXCINFO' : excinfo_namenode + u'EXCINFO' : excinfo_tempblock.get_ref_node(0, node.pos) }, pos=node.pos) # Set except excinfo target to EXCINFO - result.stats[4].body.stats[0].except_clauses[0].excinfo_target = excinfo_target - - return result.stats + result.body.stats[4].body.stats[0].except_clauses[0].excinfo_target = ( + excinfo_tempblock.get_ref_node(0, node.pos)) + + excinfo_tempblock.body = result + return excinfo_tempblock class DecoratorTransform(CythonTransform): diff --git a/Cython/Compiler/Tests/TestParseTreeTransforms.py b/Cython/Compiler/Tests/TestParseTreeTransforms.py index 28accbbf..93e3d452 100644 --- a/Cython/Compiler/Tests/TestParseTreeTransforms.py +++ b/Cython/Compiler/Tests/TestParseTreeTransforms.py @@ -92,23 +92,23 @@ class TestWithTransform(TransformTest): with x: y = z ** 3 """) - + self.assertCode(u""" - $MGR = x - $EXIT = $MGR.__exit__ - $MGR.__enter__() - $EXC = True + $1_0 = x + $1_2 = $1_0.__exit__ + $1_0.__enter__() + $1_1 = True try: try: y = z ** 3 except: - $EXC = False - if (not $EXIT($EXCINFO)): + $1_1 = False + if (not $1_2($0_0)): raise finally: - if $EXC: - $EXIT(None, None, None) + if $1_1: + $1_2(None, None, None) """, t) @@ -119,21 +119,21 @@ class TestWithTransform(TransformTest): """) self.assertCode(u""" - $MGR = x - $EXIT = $MGR.__exit__ - $VALUE = $MGR.__enter__() - $EXC = True + $1_0 = x + $1_2 = $1_0.__exit__ + $1_3 = $1_0.__enter__() + $1_1 = True try: try: - y = $VALUE + y = $1_3 y = z ** 3 except: - $EXC = False - if (not $EXIT($EXCINFO)): + $1_1 = False + if (not $1_2($0_0)): raise finally: - if $EXC: - $EXIT(None, None, None) + if $1_1: + $1_2(None, None, None) """, t) diff --git a/Cython/Compiler/Tests/TestTreeFragment.py b/Cython/Compiler/Tests/TestTreeFragment.py index 2214cd14..32bcc37f 100644 --- a/Cython/Compiler/Tests/TestTreeFragment.py +++ b/Cython/Compiler/Tests/TestTreeFragment.py @@ -1,6 +1,7 @@ from Cython.TestUtils import CythonTest from Cython.Compiler.TreeFragment import * from Cython.Compiler.Nodes import * +from Cython.Compiler.UtilNodes import * import Cython.Compiler.Naming as Naming class TestTreeFragments(CythonTest): @@ -54,10 +55,10 @@ class TestTreeFragments(CythonTest): x = TMP """) T = F.substitute(temps=[u"TMP"]) - s = T.stats - self.assert_(s[0].expr.name == Naming.temp_prefix + u"1_TMP", s[0].expr.name) - self.assert_(s[1].rhs.name == Naming.temp_prefix + u"1_TMP") - self.assert_(s[0].expr.name != u"TMP") + s = T.body.stats + self.assert_(isinstance(s[0].expr, TempRefNode)) + self.assert_(isinstance(s[1].rhs, TempRefNode)) + self.assert_(s[0].expr.handle is s[1].rhs.handle) if __name__ == "__main__": import unittest diff --git a/Cython/Compiler/TreeFragment.py b/Cython/Compiler/TreeFragment.py index 81865f81..fa301de8 100644 --- a/Cython/Compiler/TreeFragment.py +++ b/Cython/Compiler/TreeFragment.py @@ -8,11 +8,12 @@ from Scanning import PyrexScanner, StringSourceDescriptor from Symtab import BuiltinScope, ModuleScope import Symtab import PyrexTypes -from Visitor import VisitorTransform, temp_name_handle +from Visitor import VisitorTransform from Nodes import Node, StatListNode from ExprNodes import NameNode import Parsing import Main +import UtilNodes """ Support for parsing strings into code trees. @@ -111,12 +112,18 @@ class TemplateTransform(VisitorTransform): def __call__(self, node, substitutions, temps, pos): self.substitutions = substitutions - tempdict = {} - for key in temps: - tempdict[key] = temp_name_handle(key) # pending result_code refactor: Symtab.new_temp(PyrexTypes.py_object_type, key) - self.temp_key_to_entries = tempdict self.pos = pos - return super(TemplateTransform, self).__call__(node) + + + self.temps = temps + if len(temps) > 0: + self.tempblock = UtilNodes.TempsBlockNode(self.get_pos(node), + [PyrexTypes.py_object_type for x in temps], + body=None) + self.tempblock.body = super(TemplateTransform, self).__call__(node) + return self.tempblock + else: + return super(TemplateTransform, self).__call__(node) def get_pos(self, node): if self.pos: @@ -145,13 +152,13 @@ class TemplateTransform(VisitorTransform): def visit_NameNode(self, node): - tempentry = self.temp_key_to_entries.get(node.name) - if tempentry is not None: - # Replace name with temporary - return NameNode(self.get_pos(node), name=tempentry) - # Pending result_code refactor: return NameNode(self.get_pos(node), entry=tempentry) - else: + try: + tmpidx = self.temps.index(node.name) + except: return self.try_substitution(node, node.name) + else: + # Replace name with temporary + return self.tempblock.get_ref_node(tmpidx, self.get_pos(node)) def visit_ExprStatNode(self, node): # If an expression-as-statement consists of only a replaceable diff --git a/Cython/Compiler/UtilNodes.py b/Cython/Compiler/UtilNodes.py new file mode 100644 index 00000000..9def39aa --- /dev/null +++ b/Cython/Compiler/UtilNodes.py @@ -0,0 +1,94 @@ +# +# Nodes used as utilities and support for transforms etc. +# These often make up sets including both Nodes and ExprNodes +# so it is convenient to have them in a seperate module. +# + +import Nodes +import ExprNodes +from Nodes import Node +from ExprNodes import ExprNode + +class TempHandle(object): + temp = None + def __init__(self, type): + self.type = type + +class TempRefNode(ExprNode): + # handle TempHandle + subexprs = [] + + def analyse_types(self, env): + assert self.type == self.handle.type + + def analyse_target_types(self, env): + assert self.type == self.handle.type + + def analyse_target_declaration(self, env): + pass + + def calculate_result_code(self): + result = self.handle.temp + if result is None: result = "" # might be called and overwritten + return result + + def generate_result_code(self, code): + pass + + def generate_assignment_code(self, rhs, code): + if self.type.is_pyobject: + rhs.make_owned_reference(code) + code.put_xdecref(self.result(), self.ctype()) + code.putln('%s = %s;' % (self.result(), rhs.result_as(self.ctype()))) + rhs.generate_post_assignment_code(code) + +class TempsBlockNode(Node): + """ + Creates a block which allocates temporary variables. + This is used by transforms to output constructs that need + to make use of a temporary variable. Simply pass the types + of the needed temporaries to the constructor. + + The variables can be referred to using a TempRefNode + (which can be constructed by calling get_ref_node). + """ + child_attrs = ["body"] + + def __init__(self, pos, types, body): + self.handles = [TempHandle(t) for t in types] + Node.__init__(self, pos, body=body) + + def get_ref_node(self, index, pos): + handle = self.handles[index] + return TempRefNode(pos, handle=handle, type=handle.type) + + def append_temp(self, type, pos): + """ + Appends a new temporary which this block manages, and returns + its index. + """ + self.handle.append(TempHandle(type)) + return len(self.handle) - 1 + + def generate_execution_code(self, code): + for handle in self.handles: + handle.temp = code.funcstate.allocate_temp(handle.type) + self.body.generate_execution_code(code) + for handle in self.handles: + code.funcstate.release_temp(handle.temp) + + def analyse_control_flow(self, env): + self.body.analyse_control_flow(env) + + def analyse_declarations(self, env): + self.body.analyse_declarations(env) + + def analyse_expressions(self, env): + self.body.analyse_expressions(env) + + def generate_function_definitions(self, env, code): + self.body.generate_function_definitions(env, code) + + def annotate(self, code): + self.body.annotate(code) + diff --git a/Cython/Compiler/Visitor.py b/Cython/Compiler/Visitor.py index 6c711ffd..f8aad6b4 100644 --- a/Cython/Compiler/Visitor.py +++ b/Cython/Compiler/Visitor.py @@ -199,23 +199,6 @@ def replace_node(ptr, value): else: getattr(parent, attrname)[listidx] = value -tmpnamectr = 0 -def temp_name_handle(description=None): - global tmpnamectr - tmpnamectr += 1 - if description is not None: - name = u"%d_%s" % (tmpnamectr, description) - else: - name = u"%d" % tmpnamectr - return EncodedString(Naming.temp_prefix + name) - -def get_temp_name_handle_desc(handle): - if not handle.startswith(u"__cyt_"): - return None - else: - idx = handle.find(u"_", 6) - return handle[idx+1:] - class PrintTree(TreeVisitor): """Prints a representation of the tree to standard output. Subclass and override repr_of to provide more information diff --git a/Cython/TestUtils.py b/Cython/TestUtils.py index f3ceda0d..9ed1b3b3 100644 --- a/Cython/TestUtils.py +++ b/Cython/TestUtils.py @@ -47,10 +47,16 @@ class CythonTest(unittest.TestCase): self.assertEqual(len(expected), len(result), "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), u"\n".join(result))) - def assertCode(self, expected, result_tree): + def codeToLines(self, tree): writer = CodeWriter() - writer.write(result_tree) - result_lines = writer.result.lines + writer.write(tree) + return writer.result.lines + + def codeToString(self, tree): + return "\n".join(self.codeToLines(tree)) + + def assertCode(self, expected, result_tree): + result_lines = self.codeToLines(result_tree) expected_lines = strip_common_indent(expected.split("\n"))