-from Cython.Compiler.Visitor import TreeVisitor, get_temp_name_handle_desc
+from Cython.Compiler.Visitor import TreeVisitor
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
self.result = result
self.numindents = 0
self.tempnames = {}
+ self.tempblockindex = 0
def write(self, tree):
self.visit(tree)
self.startline(s)
self.endline()
- def putname(self, name):
- tmpdesc = get_temp_name_handle_desc(name)
- if tmpdesc is not None:
- name = self.tempnames.setdefault(tmpdesc, u"$" +tmpdesc)
- self.put(name)
-
def comma_seperated_list(self, items, output_rhs=False):
if len(items) > 0:
for item in items[:-1]:
self.endline()
def visit_NameNode(self, node):
- self.putname(node.name)
+ self.put(node.name)
def visit_IntNode(self, node):
self.put(node.value)
self.visit(node.operand)
self.put(u")")
+ def visit_TempsBlockNode(self, node):
+ """
+ Temporaries are output like $1_1', where the first number is
+ an index of the TempsBlockNode and the second number is an index
+ of the temporary which that block allocates.
+ """
+ idx = 0
+ for handle in node.handles:
+ self.tempnames[handle] = "$%d_%d" % (self.tempblockindex, idx)
+ idx += 1
+ self.tempblockindex += 1
+ self.visit(node.body)
+
+ def visit_TempRefNode(self, node):
+ self.put(self.tempnames[node.handle])
-from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
+from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
-from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Compiler.StringEncoding import EncodedString
from Cython.Compiler.Errors import CompileError
import Interpreter
-from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
+from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
return self.saved_subexpr_nodes
def result(self):
- if self.is_temp:
- return self.result_code
- else:
- return self.calculate_result_code()
+ if self.is_temp:
+ return self.result_code
+ else:
+ return self.calculate_result_code()
def result_as(self, type = None):
# Return the result code cast to the specified C type.
self.module.generate_disposal_code(code)
+
#------------------------------------------------------------------------------------
#
# Runtime support code
-from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
+from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
+from Cython.Compiler.UtilNodes import *
from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Compiler.StringEncoding import EncodedString
from Cython.Compiler.Errors import CompileError
finally:
if EXC:
EXIT(None, None, None)
- """, temps=[u'MGR', u'EXC', u"EXIT", u"SYS)"],
+ """, temps=[u'MGR', u'EXC', u"EXIT"],
pipeline=[NormalizeTree(None)])
template_with_target = TreeFragment(u"""
finally:
if EXC:
EXIT(None, None, None)
- """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"],
+ """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE"],
pipeline=[NormalizeTree(None)])
def visit_WithStatNode(self, node):
- excinfo_name = temp_name_handle('EXCINFO')
- excinfo_namenode = NameNode(pos=node.pos, name=excinfo_name)
- excinfo_target = NameNode(pos=node.pos, name=excinfo_name)
+ excinfo_tempblock = TempsBlockNode(node.pos, [PyrexTypes.py_object_type], None)
if node.target is not None:
result = self.template_with_target.substitute({
u'EXPR' : node.manager,
u'BODY' : node.body,
u'TARGET' : node.target,
- u'EXCINFO' : excinfo_namenode
+ u'EXCINFO' : excinfo_tempblock.get_ref_node(0, node.pos)
}, pos=node.pos)
# Set except excinfo target to EXCINFO
- result.stats[4].body.stats[0].except_clauses[0].excinfo_target = excinfo_target
+ result.body.stats[4].body.stats[0].except_clauses[0].excinfo_target = (
+ excinfo_tempblock.get_ref_node(0, node.pos))
else:
result = self.template_without_target.substitute({
u'EXPR' : node.manager,
u'BODY' : node.body,
- u'EXCINFO' : excinfo_namenode
+ u'EXCINFO' : excinfo_tempblock.get_ref_node(0, node.pos)
}, pos=node.pos)
# Set except excinfo target to EXCINFO
- result.stats[4].body.stats[0].except_clauses[0].excinfo_target = excinfo_target
-
- return result.stats
+ result.body.stats[4].body.stats[0].except_clauses[0].excinfo_target = (
+ excinfo_tempblock.get_ref_node(0, node.pos))
+
+ excinfo_tempblock.body = result
+ return excinfo_tempblock
class DecoratorTransform(CythonTransform):
with x:
y = z ** 3
""")
-
+
self.assertCode(u"""
- $MGR = x
- $EXIT = $MGR.__exit__
- $MGR.__enter__()
- $EXC = True
+ $1_0 = x
+ $1_2 = $1_0.__exit__
+ $1_0.__enter__()
+ $1_1 = True
try:
try:
y = z ** 3
except:
- $EXC = False
- if (not $EXIT($EXCINFO)):
+ $1_1 = False
+ if (not $1_2($0_0)):
raise
finally:
- if $EXC:
- $EXIT(None, None, None)
+ if $1_1:
+ $1_2(None, None, None)
""", t)
""")
self.assertCode(u"""
- $MGR = x
- $EXIT = $MGR.__exit__
- $VALUE = $MGR.__enter__()
- $EXC = True
+ $1_0 = x
+ $1_2 = $1_0.__exit__
+ $1_3 = $1_0.__enter__()
+ $1_1 = True
try:
try:
- y = $VALUE
+ y = $1_3
y = z ** 3
except:
- $EXC = False
- if (not $EXIT($EXCINFO)):
+ $1_1 = False
+ if (not $1_2($0_0)):
raise
finally:
- if $EXC:
- $EXIT(None, None, None)
+ if $1_1:
+ $1_2(None, None, None)
""", t)
from Cython.TestUtils import CythonTest
from Cython.Compiler.TreeFragment import *
from Cython.Compiler.Nodes import *
+from Cython.Compiler.UtilNodes import *
import Cython.Compiler.Naming as Naming
class TestTreeFragments(CythonTest):
x = TMP
""")
T = F.substitute(temps=[u"TMP"])
- s = T.stats
- self.assert_(s[0].expr.name == Naming.temp_prefix + u"1_TMP", s[0].expr.name)
- self.assert_(s[1].rhs.name == Naming.temp_prefix + u"1_TMP")
- self.assert_(s[0].expr.name != u"TMP")
+ s = T.body.stats
+ self.assert_(isinstance(s[0].expr, TempRefNode))
+ self.assert_(isinstance(s[1].rhs, TempRefNode))
+ self.assert_(s[0].expr.handle is s[1].rhs.handle)
if __name__ == "__main__":
import unittest
from Symtab import BuiltinScope, ModuleScope
import Symtab
import PyrexTypes
-from Visitor import VisitorTransform, temp_name_handle
+from Visitor import VisitorTransform
from Nodes import Node, StatListNode
from ExprNodes import NameNode
import Parsing
import Main
+import UtilNodes
"""
Support for parsing strings into code trees.
def __call__(self, node, substitutions, temps, pos):
self.substitutions = substitutions
- tempdict = {}
- for key in temps:
- tempdict[key] = temp_name_handle(key) # pending result_code refactor: Symtab.new_temp(PyrexTypes.py_object_type, key)
- self.temp_key_to_entries = tempdict
self.pos = pos
- return super(TemplateTransform, self).__call__(node)
+
+
+ self.temps = temps
+ if len(temps) > 0:
+ self.tempblock = UtilNodes.TempsBlockNode(self.get_pos(node),
+ [PyrexTypes.py_object_type for x in temps],
+ body=None)
+ self.tempblock.body = super(TemplateTransform, self).__call__(node)
+ return self.tempblock
+ else:
+ return super(TemplateTransform, self).__call__(node)
def get_pos(self, node):
if self.pos:
def visit_NameNode(self, node):
- tempentry = self.temp_key_to_entries.get(node.name)
- if tempentry is not None:
- # Replace name with temporary
- return NameNode(self.get_pos(node), name=tempentry)
- # Pending result_code refactor: return NameNode(self.get_pos(node), entry=tempentry)
- else:
+ try:
+ tmpidx = self.temps.index(node.name)
+ except:
return self.try_substitution(node, node.name)
+ else:
+ # Replace name with temporary
+ return self.tempblock.get_ref_node(tmpidx, self.get_pos(node))
def visit_ExprStatNode(self, node):
# If an expression-as-statement consists of only a replaceable
--- /dev/null
+#
+# Nodes used as utilities and support for transforms etc.
+# These often make up sets including both Nodes and ExprNodes
+# so it is convenient to have them in a seperate module.
+#
+
+import Nodes
+import ExprNodes
+from Nodes import Node
+from ExprNodes import ExprNode
+
+class TempHandle(object):
+ temp = None
+ def __init__(self, type):
+ self.type = type
+
+class TempRefNode(ExprNode):
+ # handle TempHandle
+ subexprs = []
+
+ def analyse_types(self, env):
+ assert self.type == self.handle.type
+
+ def analyse_target_types(self, env):
+ assert self.type == self.handle.type
+
+ def analyse_target_declaration(self, env):
+ pass
+
+ def calculate_result_code(self):
+ result = self.handle.temp
+ if result is None: result = "<error>" # might be called and overwritten
+ return result
+
+ def generate_result_code(self, code):
+ pass
+
+ def generate_assignment_code(self, rhs, code):
+ if self.type.is_pyobject:
+ rhs.make_owned_reference(code)
+ code.put_xdecref(self.result(), self.ctype())
+ code.putln('%s = %s;' % (self.result(), rhs.result_as(self.ctype())))
+ rhs.generate_post_assignment_code(code)
+
+class TempsBlockNode(Node):
+ """
+ Creates a block which allocates temporary variables.
+ This is used by transforms to output constructs that need
+ to make use of a temporary variable. Simply pass the types
+ of the needed temporaries to the constructor.
+
+ The variables can be referred to using a TempRefNode
+ (which can be constructed by calling get_ref_node).
+ """
+ child_attrs = ["body"]
+
+ def __init__(self, pos, types, body):
+ self.handles = [TempHandle(t) for t in types]
+ Node.__init__(self, pos, body=body)
+
+ def get_ref_node(self, index, pos):
+ handle = self.handles[index]
+ return TempRefNode(pos, handle=handle, type=handle.type)
+
+ def append_temp(self, type, pos):
+ """
+ Appends a new temporary which this block manages, and returns
+ its index.
+ """
+ self.handle.append(TempHandle(type))
+ return len(self.handle) - 1
+
+ def generate_execution_code(self, code):
+ for handle in self.handles:
+ handle.temp = code.funcstate.allocate_temp(handle.type)
+ self.body.generate_execution_code(code)
+ for handle in self.handles:
+ code.funcstate.release_temp(handle.temp)
+
+ def analyse_control_flow(self, env):
+ self.body.analyse_control_flow(env)
+
+ def analyse_declarations(self, env):
+ self.body.analyse_declarations(env)
+
+ def analyse_expressions(self, env):
+ self.body.analyse_expressions(env)
+
+ def generate_function_definitions(self, env, code):
+ self.body.generate_function_definitions(env, code)
+
+ def annotate(self, code):
+ self.body.annotate(code)
+
else:
getattr(parent, attrname)[listidx] = value
-tmpnamectr = 0
-def temp_name_handle(description=None):
- global tmpnamectr
- tmpnamectr += 1
- if description is not None:
- name = u"%d_%s" % (tmpnamectr, description)
- else:
- name = u"%d" % tmpnamectr
- return EncodedString(Naming.temp_prefix + name)
-
-def get_temp_name_handle_desc(handle):
- if not handle.startswith(u"__cyt_"):
- return None
- else:
- idx = handle.find(u"_", 6)
- return handle[idx+1:]
-
class PrintTree(TreeVisitor):
"""Prints a representation of the tree to standard output.
Subclass and override repr_of to provide more information
self.assertEqual(len(expected), len(result),
"Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), u"\n".join(result)))
- def assertCode(self, expected, result_tree):
+ def codeToLines(self, tree):
writer = CodeWriter()
- writer.write(result_tree)
- result_lines = writer.result.lines
+ writer.write(tree)
+ return writer.result.lines
+
+ def codeToString(self, tree):
+ return "\n".join(self.codeToLines(tree))
+
+ def assertCode(self, expected, result_tree):
+ result_lines = self.codeToLines(result_tree)
expected_lines = strip_common_indent(expected.split("\n"))