def release_temp(self, env):
pass
+
+class PersistentNode(ExprNode):
+ # A PersistentNode is like a CloneNode except it handles the temporary
+ # allocation itself by keeping track of the number of times it has been
+ # used.
+
+ subexprs = ["arg"]
+ temp_counter = 0
+ generate_counter = 0
+ result_code = None
+
+ def __init__(self, arg, uses):
+ self.pos = arg.pos
+ self.arg = arg
+ self.uses = uses
+
+ def analyse_types(self, env):
+ self.arg.analyse_types(env)
+ self.type = self.arg.type
+ self.result_ctype = self.arg.result_ctype
+ self.is_temp = 1
+
+ def generate_evaluation_code(self, code):
+ if self.generate_counter == 0:
+ self.arg.generate_evaluation_code(code)
+ code.putln("%s = %s;" % (
+ self.result_code, self.arg.result_as(self.ctype())))
+ if self.type.is_pyobject:
+ code.put_incref(self.result_code, self.ctype())
+ self.arg.generate_disposal_code(code)
+ self.generate_counter += 1
+
+ def generate_disposal_code(self, code):
+ if self.generate_counter == self.uses:
+ if self.type.is_pyobject:
+ code.put_decref_clear(self.result_code, self.ctype())
+
+ def allocate_temps(self, env, result=None):
+ if self.temp_counter == 0:
+ self.arg.allocate_temps(env)
+ if result is None:
+ self.result_code = env.allocate_temp(self.type)
+ else:
+ self.result_code = result
+ self.arg.release_temp(env)
+ self.temp_counter += 1
+
+ def release_temp(self, env):
+ if self.temp_counter == self.uses:
+ env.release_temp(self.result_code)
#------------------------------------------------------------------------------------
#
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
+ from Optimize import FlattenInListTransform, SwitchTransform
from Buffer import BufferTransform
from ModuleNode import check_c_classes
create_parse(context),
NormalizeTree(context),
PostParse(context),
+ FlattenInListTransform(),
WithTransform(context),
DecoratorTransform(context),
AnalyseDeclarationsTransform(context),
check_c_classes,
AnalyseExpressionsTransform(context),
BufferTransform(context),
+ SwitchTransform(),
# CreateClosureClasses(context),
create_generate_code(context, options, result)
]
import Nodes
import ExprNodes
+import PyrexTypes
import Visitor
+def unwrap_node(node):
+ while isinstance(node, ExprNodes.PersistentNode):
+ node = node.arg
+ return node
def is_common_value(a, b):
+ a = unwrap_node(a)
+ b = unwrap_node(b)
if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
return a.name == b.name
if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
return False
-class SwitchTransformVisitor(Visitor.VisitorTransform):
-
+class SwitchTransform(Visitor.VisitorTransform):
+ """
+ This transformation tries to turn long if statements into C switch statements.
+ The requirement is that every clause be an (or of) var == value, where the var
+ is common among all clauses and both var and value are not Python objects.
+ """
def extract_conditions(self, cond):
if isinstance(cond, ExprNodes.CoerceToTempNode):
cond = cond.arg
-
+
+ if isinstance(cond, ExprNodes.TypecastNode):
+ cond = cond.operand
+
if (isinstance(cond, ExprNodes.PrimaryCmpNode)
and cond.cascade is None
and cond.operator == '=='
return t1, c1+c2
return None, None
- def is_common_value(self, a, b):
- if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
- return a.name == b.name
- if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
- return not a.is_py_attr and is_common_value(a.obj, b.obj)
- return False
-
def visit_IfStatNode(self, node):
+ self.visitchildren(node)
if len(node.if_clauses) < 3:
return node
common_var = None
var, conditions = self.extract_conditions(if_clause.condition)
if var is None:
return node
- elif common_var is not None and not self.is_common_value(var, common_var):
+ elif common_var is not None and not is_common_value(var, common_var):
return node
else:
common_var = var
test = common_var,
cases = cases,
else_clause = node.else_clause)
-
+
+
def visit_Node(self, node):
self.visitchildren(node)
return node
+
+class FlattenInListTransform(Visitor.VisitorTransform):
+ """
+ This transformation flattens "x in [val1, ..., valn]" into a sequential list
+ of comparisons.
+ """
+
+ def visit_PrimaryCmpNode(self, node):
+ self.visitchildren(node)
+ if node.cascade is not None:
+ return node
+ elif node.operator == 'in':
+ conjunction = 'or'
+ eq_or_neq = '=='
+ elif node.operator == 'not_in':
+ conjunction = 'and'
+ eq_or_neq = '!='
+ else:
+ return node
+
+ args = node.operand2.args
+ if isinstance(node.operand2, ExprNodes.TupleNode) or isinstance(node.operand2, ExprNodes.ListNode):
+ if len(args) == 0:
+ return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')
+ else:
+ lhs = ExprNodes.PersistentNode(node.operand1, len(args))
+ conds = []
+ for arg in args:
+ cond = ExprNodes.PrimaryCmpNode(
+ pos = node.pos,
+ operand1 = lhs,
+ operator = eq_or_neq,
+ operand2 = arg,
+ cascade = None)
+ conds.append(ExprNodes.TypecastNode(
+ pos = node.pos,
+ operand = cond,
+ type = PyrexTypes.c_bint_type))
+ def concat(left, right):
+ return ExprNodes.BoolBinopNode(
+ pos = node.pos,
+ operator = conjunction,
+ operand1 = left,
+ operand2 = right)
+ return reduce(concat, conds)
+ else:
+ return node
+
+ def visit_Node(self, node):
+ self.visitchildren(node)
+ return node