remove optimisations for set([...]) and dict([...]) as they do not take side-effects into account: unhashable items lead to pre-mature exit from the loop
instead, transform set(genexp), list(genexp) and dict(genexp) into inlined comprehensions that do not leak loop variables
# this is called with the expr_scope as env
pass
+ def init_scope(self, outer_scope, expr_scope=None):
+ self.expr_scope = expr_scope
-class ComprehensionNode(ExprNode): # (ScopedExprNode)
+
+class ComprehensionNode(ScopedExprNode):
subexprs = ["target"]
child_attrs = ["loop", "append"]
+ # different behaviour in Py2 and Py3: leak loop variables or not?
+ has_local_scope = False # Py2 behaviour as default
+
def infer_type(self, env):
return self.target.infer_type(env)
def analyse_declarations(self, env):
self.append.target = self # this is used in the PyList_Append of the inner loop
- self.loop.analyse_declarations(env)
-# self.expr_scope = Symtab.GeneratorExpressionScope(env)
-# self.loop.analyse_declarations(self.expr_scope)
+ self.init_scope(env)
+ if self.expr_scope is not None:
+ self.loop.analyse_declarations(self.expr_scope)
+ else:
+ self.loop.analyse_declarations(env)
+
+ def init_scope(self, outer_scope, expr_scope=None):
+ if expr_scope is not None:
+ self.expr_scope = expr_scope
+ elif self.has_local_scope:
+ self.expr_scope = Symtab.GeneratorExpressionScope(outer_scope)
+ else:
+ self.expr_scope = None
def analyse_types(self, env):
self.target.analyse_expressions(env)
self.type = self.target.type
- self.loop.analyse_expressions(env)
+ if not self.has_local_scope:
+ self.loop.analyse_expressions(env)
-# def analyse_scoped_expressions(self, env):
-# self.loop.analyse_expressions(env)
+ def analyse_scoped_expressions(self, env):
+ if self.has_local_scope:
+ self.loop.analyse_expressions(env)
def may_be_none(self):
return False
self.loop.annotate(code)
-class ComprehensionAppendNode(ExprNode):
+class ComprehensionAppendNode(Node):
# Need to be careful to avoid infinite recursion:
# target must not be in child_attrs/subexprs
- subexprs = ['expr']
+
+ child_attrs = ['expr']
type = PyrexTypes.c_int_type
- def analyse_types(self, env):
- self.expr.analyse_types(env)
+ def analyse_expressions(self, env):
+ self.expr.analyse_expressions(env)
if not self.expr.type.is_pyobject:
self.expr = self.expr.coerce_to_pyobject(env)
- self.is_temp = 1
- def generate_result_code(self, code):
+ def generate_execution_code(self, code):
if self.target.type is list_type:
function = "PyList_Append"
elif self.target.type is set_type:
else:
raise InternalError(
"Invalid type for comprehension node: %s" % self.target.type)
-
- code.putln("%s = %s(%s, (PyObject*)%s); %s" %
- (self.result(),
- function,
- self.target.result(),
- self.expr.result(),
- code.error_goto_if(self.result(), self.pos)))
+
+ self.expr.generate_evaluation_code(code)
+ code.putln(code.error_goto_if("%s(%s, (PyObject*)%s)" % (
+ function,
+ self.target.result(),
+ self.expr.result()
+ ), self.pos))
+ self.expr.generate_disposal_code(code)
+ self.expr.free_temps(code)
+
+ def generate_function_definitions(self, env, code):
+ self.expr.generate_function_definitions(env, code)
+
+ def annotate(self, code):
+ self.expr.annotate(code)
class DictComprehensionAppendNode(ComprehensionAppendNode):
- subexprs = ['key_expr', 'value_expr']
+ child_attrs = ['key_expr', 'value_expr']
- def analyse_types(self, env):
- self.key_expr.analyse_types(env)
+ def analyse_expressions(self, env):
+ self.key_expr.analyse_expressions(env)
if not self.key_expr.type.is_pyobject:
self.key_expr = self.key_expr.coerce_to_pyobject(env)
- self.value_expr.analyse_types(env)
+ self.value_expr.analyse_expressions(env)
if not self.value_expr.type.is_pyobject:
self.value_expr = self.value_expr.coerce_to_pyobject(env)
- self.is_temp = 1
- def generate_result_code(self, code):
- code.putln("%s = PyDict_SetItem(%s, (PyObject*)%s, (PyObject*)%s); %s" %
- (self.result(),
- self.target.result(),
- self.key_expr.result(),
- self.value_expr.result(),
- code.error_goto_if(self.result(), self.pos)))
+ def generate_execution_code(self, code):
+ self.key_expr.generate_evaluation_code(code)
+ self.value_expr.generate_evaluation_code(code)
+ code.putln(code.error_goto_if("PyDict_SetItem(%s, (PyObject*)%s, (PyObject*)%s)" % (
+ self.target.result(),
+ self.key_expr.result(),
+ self.value_expr.result()
+ ), self.pos))
+ self.key_expr.generate_disposal_code(code)
+ self.key_expr.free_temps(code)
+ self.value_expr.generate_disposal_code(code)
+ self.value_expr.free_temps(code)
+
+ def generate_function_definitions(self, env, code):
+ self.key_expr.generate_function_definitions(env, code)
+ self.value_expr.generate_function_definitions(env, code)
+
+ def annotate(self, code):
+ self.key_expr.annotate(code)
+ self.value_expr.annotate(code)
class GeneratorExpressionNode(ScopedExprNode):
type = py_object_type
def analyse_declarations(self, env):
- self.expr_scope = Symtab.GeneratorExpressionScope(env)
+ self.init_scope(env)
self.loop.analyse_declarations(self.expr_scope)
+ def init_scope(self, outer_scope, expr_scope=None):
+ if expr_scope is not None:
+ self.expr_scope = expr_scope
+ else:
+ self.expr_scope = Symtab.GeneratorExpressionScope(outer_scope)
+
def analyse_types(self, env):
self.is_temp = True
# specific handlers for simple call nodes
- def _handle_simple_function_set(self, node, pos_args):
- """Replace set([a,b,...]) by a literal set {a,b,...} and
- set([ x for ... ]) by a literal { x for ... }.
- """
- arg_count = len(pos_args)
- if arg_count == 0:
- return ExprNodes.SetNode(node.pos, args=[],
- type=Builtin.set_type)
- if arg_count > 1:
- return node
- iterable = pos_args[0]
- if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
- return ExprNodes.SetNode(node.pos, args=iterable.args)
- elif isinstance(iterable, ExprNodes.ComprehensionNode) and \
- isinstance(iterable.target, (ExprNodes.ListNode,
- ExprNodes.SetNode)):
- iterable.target = ExprNodes.SetNode(node.pos, args=[])
- iterable.pos = node.pos
- return iterable
- else:
- return node
-
- def _handle_simple_function_dict(self, node, pos_args):
- """Replace dict([ (a,b) for ... ]) by a literal { a:b for ... }.
- """
- if len(pos_args) != 1:
- return node
- arg = pos_args[0]
- if isinstance(arg, ExprNodes.ComprehensionNode) and \
- isinstance(arg.target, (ExprNodes.ListNode,
- ExprNodes.SetNode)):
- append_node = arg.append
- if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \
- len(append_node.expr.args) == 2:
- key_node, value_node = append_node.expr.args
- target_node = ExprNodes.DictNode(
- pos=arg.target.pos, key_value_pairs=[])
- new_append_node = ExprNodes.DictComprehensionAppendNode(
- append_node.pos, target=target_node,
- key_expr=key_node, value_expr=value_node)
- arg.target = target_node
- arg.type = target_node.type
- replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
- return replace_in(arg)
- return node
-
def _handle_simple_function_float(self, node, pos_args):
if len(pos_args) == 0:
return ExprNodes.FloatNode(node.pos, value='0.0')
rhs = ExprNodes.BoolNode(yield_node.pos, value = not is_any,
constant_result = not is_any))
- Visitor.RecursiveNodeReplacer(yield_node, test_node).visitchildren(loop_node)
+ Visitor.recursively_replace_node(loop_node, yield_node, test_node)
return ExprNodes.InlinedGeneratorExpressionNode(
gen_expr_node.pos, loop = loop_node, result_node = result_ref,
rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
)
- Visitor.RecursiveNodeReplacer(yield_node, add_node).visitchildren(loop_node)
+ Visitor.recursively_replace_node(loop_node, yield_node, add_node)
exec_code = Nodes.StatListNode(
node.pos,
gen_expr_node.pos, loop = exec_code, result_node = result_ref,
expr_scope = gen_expr_node.expr_scope, orig_func = 'sum')
+ def _handle_simple_function_list(self, node, pos_args):
+ if len(pos_args) == 0:
+ return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
+ return self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
+
+ def _handle_simple_function_set(self, node, pos_args):
+ if len(pos_args) == 0:
+ return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
+ return self._transform_list_set_genexpr(node, pos_args, ExprNodes.SetNode)
+
+ def _transform_list_set_genexpr(self, node, pos_args, container_node_class):
+ """Replace set(genexpr) and list(genexpr) by a literal comprehension.
+ """
+ if len(pos_args) > 1:
+ return node
+ if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
+ return node
+ gen_expr_node = pos_args[0]
+ loop_node = gen_expr_node.loop
+
+ yield_node = self._find_single_yield_node(loop_node)
+ if yield_node is None:
+ return node
+ yield_expression = yield_node.arg
+
+ target_node = container_node_class(node.pos, args=[])
+ append_node = ExprNodes.ComprehensionAppendNode(
+ yield_node.pos,
+ expr = yield_expression,
+ target = ExprNodes.CloneNode(target_node),
+ is_temp = 1) # FIXME: why is this an ExprNode?
+
+ Visitor.recursively_replace_node(loop_node, yield_node, append_node)
+
+ setcomp = ExprNodes.ComprehensionNode(
+ node.pos,
+ has_local_scope = True,
+ expr_scope = gen_expr_node.expr_scope,
+ loop = loop_node,
+ append = append_node,
+ target = target_node)
+ append_node.target = setcomp
+ return setcomp
+
+ def _handle_simple_function_dict(self, node, pos_args):
+ """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.
+ """
+ if len(pos_args) == 0:
+ return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
+ if len(pos_args) > 1:
+ return node
+ if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
+ return node
+ gen_expr_node = pos_args[0]
+ loop_node = gen_expr_node.loop
+
+ yield_node = self._find_single_yield_node(loop_node)
+ if yield_node is None:
+ return node
+ yield_expression = yield_node.arg
+
+ if not isinstance(yield_expression, ExprNodes.TupleNode):
+ return node
+ if len(yield_expression.args) != 2:
+ return node
+
+ target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[])
+ append_node = ExprNodes.DictComprehensionAppendNode(
+ yield_node.pos,
+ key_expr = yield_expression.args[0],
+ value_expr = yield_expression.args[1],
+ target = ExprNodes.CloneNode(target_node),
+ is_temp = 1) # FIXME: why is this an ExprNode?
+
+ Visitor.recursively_replace_node(loop_node, yield_node, append_node)
+
+ dictcomp = ExprNodes.ComprehensionNode(
+ node.pos,
+ has_local_scope = True,
+ expr_scope = gen_expr_node.expr_scope,
+ loop = loop_node,
+ append = append_node,
+ target = target_node)
+ append_node.target = dictcomp
+ return dictcomp
+
+
+
+ arg = pos_args[0]
+ if isinstance(arg, ExprNodes.ComprehensionNode) and \
+ isinstance(arg.target, (ExprNodes.ListNode,
+ ExprNodes.SetNode)):
+ append_node = arg.append
+ if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \
+ len(append_node.expr.args) == 2:
+ key_node, value_node = append_node.expr.args
+ target_node = ExprNodes.DictNode(
+ pos=arg.target.pos, key_value_pairs=[])
+ new_append_node = ExprNodes.DictComprehensionAppendNode(
+ append_node.pos, target=target_node,
+ key_expr=key_node, value_expr=value_node)
+ arg.target = target_node
+ arg.type = target_node.type
+ replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
+ return replace_in(arg)
+ return node
+
# specific handlers for general call nodes
def _handle_general_function_dict(self, node, pos_args, kwargs):
return node
def visit_ScopedExprNode(self, node):
- node.expr_scope.infer_types()
- node.analyse_scoped_expressions(node.expr_scope)
+ if node.expr_scope is not None:
+ node.expr_scope.infer_types()
+ node.analyse_scoped_expressions(node.expr_scope)
self.visitchildren(node)
return node
target = ExprNodes.ListNode(pos, args = [])
append = ExprNodes.ComprehensionAppendNode(
pos, expr=expr, target=ExprNodes.CloneNode(target))
- loop = p_comp_for(s, Nodes.ExprStatNode(append.pos, expr=append))
+ loop = p_comp_for(s, append)
s.expect(']')
return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, target=target)
target = ExprNodes.SetNode(pos, args=[])
append = ExprNodes.ComprehensionAppendNode(
item.pos, expr=item, target=ExprNodes.CloneNode(target))
- loop = p_comp_for(s, Nodes.ExprStatNode(append.pos, expr=append))
+ loop = p_comp_for(s, append)
s.expect('}')
return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, target=target)
append = ExprNodes.DictComprehensionAppendNode(
item.pos, key_expr=key, value_expr=value,
target=ExprNodes.CloneNode(target))
- loop = p_comp_for(s, Nodes.ExprStatNode(append.pos, expr=append))
+ loop = p_comp_for(s, append)
s.expect('}')
return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, target=target)
else:
return node
-
+def recursively_replace_node(tree, old_node, new_node):
+ replace_in = RecursiveNodeReplacer(old_node, new_node)
+ replace_in(tree)
# Utils
-__doc__ = u"""
->>> type(smoketest_dict()) is dict
-True
->>> type(smoketest_list()) is dict
-True
-
->>> sorted(smoketest_dict().items())
-[(2, 0), (4, 4), (6, 8)]
->>> sorted(smoketest_list().items())
-[(2, 0), (4, 4), (6, 8)]
-
->>> list(typed().items())
-[(A, 1), (A, 1), (A, 1)]
->>> sorted(iterdict().items())
-[(1, 'a'), (2, 'b'), (3, 'c')]
-"""
cimport cython
-def smoketest_dict():
- return { x+2:x*2
- for x in range(5)
- if x % 2 == 0 }
+def dictcomp():
+ """
+ >>> sorted(dictcomp().items())
+ [(2, 0), (4, 4), (6, 8)]
+ >>> sorted(dictcomp().items())
+ [(2, 0), (4, 4), (6, 8)]
+ """
+ x = 'abc'
+ result = { x+2:x*2
+ for x in range(5)
+ if x % 2 == 0 }
+ assert x != 'abc'
+ return result
@cython.test_fail_if_path_exists(
- "//ComprehensionNode//ComprehensionAppendNode",
- "//SimpleCallNode//ComprehensionNode")
+ "//GeneratorExpressionNode",
+ "//SimpleCallNode")
@cython.test_assert_path_exists(
"//ComprehensionNode",
"//ComprehensionNode//DictComprehensionAppendNode")
-def smoketest_list():
- return dict([ (x+2,x*2)
- for x in range(5)
- if x % 2 == 0 ])
+def genexpr():
+ """
+ >>> type(genexpr()) is dict
+ True
+ >>> type(genexpr()) is dict
+ True
+ """
+ x = 'abc'
+ result = dict( (x+2,x*2)
+ for x in range(5)
+ if x % 2 == 0 )
+ assert x == 'abc'
+ return result
cdef class A:
def __repr__(self): return u"A"
def __richcmp__(one, other, op): return one is other
def __hash__(self): return id(self) % 65536
-def typed():
+def typed_dictcomp():
+ """
+ >>> list(typed_dictcomp().items())
+ [(A, 1), (A, 1), (A, 1)]
+ """
cdef A obj
return {obj:1 for obj in [A(), A(), A()]}
-def iterdict():
+def iterdict_dictcomp():
+ """
+ >>> sorted(iterdict_dictcomp().items())
+ [(1, 'a'), (2, 'b'), (3, 'c')]
+ """
cdef dict d = dict(a=1,b=2,c=3)
return {d[key]:key for key in d}
>>> smoketest()
[0, 4, 8]
"""
- print [x*2 for x in range(5) if x % 2 == 0]
+ x = 'abc'
+ result = [x*2 for x in range(5) if x % 2 == 0]
+ assert x != 'abc'
+ return result
+
+def list_genexp():
+ """
+ >>> list_genexp()
+ [0, 4, 8]
+ """
+ x = 'abc'
+ result = list(x*2 for x in range(5) if x % 2 == 0)
+ assert x == 'abc'
+ return result
def int_runvar():
"""
-__doc__ = u"""
->>> type(smoketest_set()) is not list
-True
->>> type(smoketest_set()) is _set
-True
->>> type(smoketest_list()) is _set
-True
-
->>> sorted(smoketest_set())
-[0, 4, 8]
->>> sorted(smoketest_list())
-[0, 4, 8]
-
->>> list(typed())
-[A, A, A]
->>> sorted(iterdict())
-[1, 2, 3]
-"""
cimport cython
# Py2.3 doesn't have the set type, but Cython does :)
_set = set
-def smoketest_set():
+def setcomp():
+ """
+ >>> type(setcomp()) is not list
+ True
+ >>> type(setcomp()) is _set
+ True
+ >>> sorted(setcomp())
+ [0, 4, 8]
+ """
return { x*2
for x in range(5)
if x % 2 == 0 }
-@cython.test_fail_if_path_exists("//SimpleCallNode//ComprehensionNode")
-@cython.test_assert_path_exists("//ComprehensionNode",
- "//ComprehensionNode//ComprehensionAppendNode")
-def smoketest_list():
- return set([ x*2
+@cython.test_fail_if_path_exists(
+ "//GeneratorExpressionNode",
+ "//SimpleCallNode")
+@cython.test_assert_path_exists(
+ "//ComprehensionNode",
+ "//ComprehensionNode//ComprehensionAppendNode")
+def genexp_set():
+ """
+ >>> type(genexp_set()) is _set
+ True
+ >>> sorted(genexp_set())
+ [0, 4, 8]
+ """
+ return set( x*2
for x in range(5)
- if x % 2 == 0 ])
+ if x % 2 == 0 )
cdef class A:
def __repr__(self): return u"A"
def __hash__(self): return id(self) % 65536
def typed():
+ """
+ >>> list(typed())
+ [A, A, A]
+ """
cdef A obj
return {obj for obj in {A(), A(), A()}}
def iterdict():
+ """
+ >>> sorted(iterdict())
+ [1, 2, 3]
+ """
cdef dict d = dict(a=1,b=2,c=3)
return {d[key] for key in d}