Don't store yield expressions list as yields can be copied
[cython.git] / Cython / Compiler / ParseTreeTransforms.py
1
2 import cython
3 cython.declare(PyrexTypes=object, Naming=object, ExprNodes=object, Nodes=object,
4                Options=object, UtilNodes=object, ModuleNode=object,
5                LetNode=object, LetRefNode=object, TreeFragment=object,
6                TemplateTransform=object, EncodedString=object,
7                error=object, warning=object, copy=object)
8
9 import PyrexTypes
10 import Naming
11 import ExprNodes
12 import Nodes
13 import Options
14
15 from Cython.Compiler.Visitor import VisitorTransform, TreeVisitor
16 from Cython.Compiler.Visitor import CythonTransform, EnvTransform, ScopeTrackingTransform
17 from Cython.Compiler.ModuleNode import ModuleNode
18 from Cython.Compiler.UtilNodes import LetNode, LetRefNode
19 from Cython.Compiler.TreeFragment import TreeFragment, TemplateTransform
20 from Cython.Compiler.StringEncoding import EncodedString
21 from Cython.Compiler.Errors import error, warning, CompileError, InternalError
22
23 import copy
24
25
class NameNodeCollector(TreeVisitor):
    """Tree visitor that gathers the NameNodes of a (sub-)tree.

    Every NameNode that is visited gets appended to the ``name_nodes``
    list attribute.
    """
    def __init__(self):
        super(NameNodeCollector, self).__init__()
        # filled by visit_NameNode()
        self.name_nodes = []

    def visit_Node(self, node):
        # any other node: just descend into its children
        self._visitchildren(node, None)

    def visit_NameNode(self, node):
        # record the node; its children are not visited from here
        self.name_nodes.append(node)
39
40
class SkipDeclarations(object):
    """
    Mixin that stops a transform from descending into declarations.

    Variable and function declarations often form deep subtrees that
    most transformations have no reason to walk.  Each handler below
    returns the node unchanged without visiting its children.

    Declaration nodes are removed after AnalyseDeclarationsTransform, so
    transformations that run after that point do not need this mixin.
    """
    def visit_CBaseTypeNode(self, node):
        """Return the base type node as is, children unvisited."""
        return node

    def visit_CDeclaratorNode(self, node):
        """Return the declarator node as is, children unvisited."""
        return node

    def visit_CVarDefNode(self, node):
        """Return the cdef variable definition as is, children unvisited."""
        return node

    def visit_CTypeDefNode(self, node):
        """Return the ctypedef node as is, children unvisited."""
        return node

    def visit_CStructOrUnionDefNode(self, node):
        """Return the struct/union definition as is, children unvisited."""
        return node

    def visit_CEnumDefNode(self, node):
        """Return the enum definition as is, children unvisited."""
        return node
66
67
class NormalizeTree(CythonTransform):
    """
    Fixes up a few things right after parsing so that the parse tree
    becomes more convenient for the following transforms.

    a) The parser represents a block that holds a single statement by
    that statement alone, not by a StatListNode.  This is annoying and
    inconsistent for transforms, since e.g. a statement cannot be
    removed in a uniform way.  This transform therefore wraps such
    single statements in a StatListNode with one entry.

    b) The PassStatNode is a noop whose only purpose is to plug such
    one-statement blocks; once parsed, a "pass" can just as well be
    represented by an empty StatListNode.  That leaves fewer special
    cases for later transforms (a block is empty exactly when its
    StatListNode has no children).
    """

    def __init__(self, context):
        super(NormalizeTree, self).__init__(context)
        # True while the children of a statement list container are visited
        self.is_in_statlist = False
        # True while the children of an expression node are visited
        self.is_in_expr = False

    def visit_ExprNode(self, node):
        # remember and restore the flag so that nested expressions work
        was_in_expr = self.is_in_expr
        self.is_in_expr = True
        self.visitchildren(node)
        self.is_in_expr = was_in_expr
        return node

    def visit_StatNode(self, node, is_listcontainer=False):
        # 'is_listcontainer' tells whether the children of 'node' count
        # as being inside a statement list
        was_in_statlist = self.is_in_statlist
        self.is_in_statlist = is_listcontainer
        self.visitchildren(node)
        self.is_in_statlist = was_in_statlist
        if self.is_in_statlist or self.is_in_expr:
            return node
        # a lone statement outside of any list/expression gets wrapped
        return Nodes.StatListNode(pos=node.pos, stats=[node])

    def visit_StatListNode(self, node):
        self.is_in_statlist = True
        self.visitchildren(node)
        self.is_in_statlist = False
        return node

    # The following nodes act as list containers for their children.

    def visit_ParallelAssignmentNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CEnumDefNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CStructOrUnionDefNode(self, node):
        return self.visit_StatNode(node, True)

    # Eliminate PassStatNode
    def visit_PassStatNode(self, node):
        if self.is_in_statlist:
            # inside a list: replace the node by nothing
            return []
        # a block of its own: replace it by an empty statement list
        return Nodes.StatListNode(pos=node.pos, stats=[])

    def visit_CDeclaratorNode(self, node):
        # declarators are returned as is, without visiting their children
        return node
135
136
class PostParseError(CompileError):
    """CompileError subclass used for the errors that the tree
    transforms in this module detect."""


# The unit tests check for these error strings, which is why they are
# defined as module level constants.
ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions'
ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)'
ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared'

143 class PostParse(ScopeTrackingTransform):
144     """
145     Basic interpretation of the parse tree, as well as validity
146     checking that can be done on a very basic level on the parse
147     tree (while still not being a problem with the basic syntax,
148     as such).
149
150     Specifically:
151     - Default values to cdef assignments are turned into single
152     assignments following the declaration (everywhere but in class
153     bodies, where they raise a compile error)
154
155     - Interpret some node structures into Python runtime values.
156     Some nodes take compile-time arguments (currently:
157     TemplatedTypeNode[args] and __cythonbufferdefaults__ = {args}),
158     which should be interpreted. This happens in a general way
159     and other steps should be taken to ensure validity.
160
161     Type arguments cannot be interpreted in this way.
162
163     - For __cythonbufferdefaults__ the arguments are checked for
164     validity.
165
166     TemplatedTypeNode has its directives interpreted:
167     Any first positional argument goes into the "dtype" attribute,
168     any "ndim" keyword argument goes into the "ndim" attribute and
169     so on. Also it is checked that the directive combination is valid.
170     - __cythonbufferdefaults__ attributes are parsed and put into the
171     type information.
172
173     Note: Currently Parsing.py does a lot of interpretation and
174     reorganization that can be refactored into this transform
175     if a more pure Abstract Syntax Tree is wanted.
176     """
177
178     def __init__(self, context):
179         super(PostParse, self).__init__(context)
180         self.specialattribute_handlers = {
181             '__cythonbufferdefaults__' : self.handle_bufferdefaults
182         }
183
184     def visit_ModuleNode(self, node):
185         self.lambda_counter = 1
186         self.genexpr_counter = 1
187         return super(PostParse, self).visit_ModuleNode(node)
188
189     def visit_LambdaNode(self, node):
190         # unpack a lambda expression into the corresponding DefNode
191         lambda_id = self.lambda_counter
192         self.lambda_counter += 1
193         node.lambda_name = EncodedString(u'lambda%d' % lambda_id)
194
195         body = Nodes.ReturnStatNode(
196             node.result_expr.pos, value = node.result_expr)
197         node.def_node = Nodes.DefNode(
198             node.pos, name=node.name, lambda_name=node.lambda_name,
199             args=node.args, star_arg=node.star_arg,
200             starstar_arg=node.starstar_arg,
201             body=body)
202         self.visitchildren(node)
203         return node
204
205     def visit_GeneratorExpressionNode(self, node):
206         # unpack a generator expression into the corresponding DefNode
207         genexpr_id = self.genexpr_counter
208         self.genexpr_counter += 1
209         node.genexpr_name = EncodedString(u'genexpr%d' % genexpr_id)
210
211         node.def_node = Nodes.DefNode(node.pos, name=node.genexpr_name,
212                                       doc=None,
213                                       args=[], star_arg=None,
214                                       starstar_arg=None,
215                                       body=node.loop)
216         self.visitchildren(node)
217         return node
218
219     # cdef variables
220     def handle_bufferdefaults(self, decl):
221         if not isinstance(decl.default, ExprNodes.DictNode):
222             raise PostParseError(decl.pos, ERR_BUF_DEFAULTS)
223         self.scope_node.buffer_defaults_node = decl.default
224         self.scope_node.buffer_defaults_pos = decl.pos
225
226     def visit_CVarDefNode(self, node):
227         # This assumes only plain names and pointers are assignable on
228         # declaration. Also, it makes use of the fact that a cdef decl
229         # must appear before the first use, so we don't have to deal with
230         # "i = 3; cdef int i = i" and can simply move the nodes around.
231         try:
232             self.visitchildren(node)
233             stats = [node]
234             newdecls = []
235             for decl in node.declarators:
236                 declbase = decl
237                 while isinstance(declbase, Nodes.CPtrDeclaratorNode):
238                     declbase = declbase.base
239                 if isinstance(declbase, Nodes.CNameDeclaratorNode):
240                     if declbase.default is not None:
241                         if self.scope_type in ('cclass', 'pyclass', 'struct'):
242                             if isinstance(self.scope_node, Nodes.CClassDefNode):
243                                 handler = self.specialattribute_handlers.get(decl.name)
244                                 if handler:
245                                     if decl is not declbase:
246                                         raise PostParseError(decl.pos, ERR_INVALID_SPECIALATTR_TYPE)
247                                     handler(decl)
248                                     continue # Remove declaration
249                             raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
250                         first_assignment = self.scope_type != 'module'
251                         stats.append(Nodes.SingleAssignmentNode(node.pos,
252                             lhs=ExprNodes.NameNode(node.pos, name=declbase.name),
253                             rhs=declbase.default, first=first_assignment))
254                         declbase.default = None
255                 newdecls.append(decl)
256             node.declarators = newdecls
257             return stats
258         except PostParseError, e:
259             # An error in a cdef clause is ok, simply remove the declaration
260             # and try to move on to report more errors
261             self.context.nonfatal_error(e)
262             return None
263
264     # Split parallel assignments (a,b = b,a) into separate partial
265     # assignments that are executed rhs-first using temps.  This
266     # restructuring must be applied before type analysis so that known
267     # types on rhs and lhs can be matched directly.  It is required in
268     # the case that the types cannot be coerced to a Python type in
269     # order to assign from a tuple.
270
271     def visit_SingleAssignmentNode(self, node):
272         self.visitchildren(node)
273         return self._visit_assignment_node(node, [node.lhs, node.rhs])
274
275     def visit_CascadedAssignmentNode(self, node):
276         self.visitchildren(node)
277         return self._visit_assignment_node(node, node.lhs_list + [node.rhs])
278
279     def _visit_assignment_node(self, node, expr_list):
280         """Flatten parallel assignments into separate single
281         assignments or cascaded assignments.
282         """
283         if sum([ 1 for expr in expr_list if expr.is_sequence_constructor ]) < 2:
284             # no parallel assignments => nothing to do
285             return node
286
287         expr_list_list = []
288         flatten_parallel_assignments(expr_list, expr_list_list)
289         temp_refs = []
290         eliminate_rhs_duplicates(expr_list_list, temp_refs)
291
292         nodes = []
293         for expr_list in expr_list_list:
294             lhs_list = expr_list[:-1]
295             rhs = expr_list[-1]
296             if len(lhs_list) == 1:
297                 node = Nodes.SingleAssignmentNode(rhs.pos,
298                     lhs = lhs_list[0], rhs = rhs)
299             else:
300                 node = Nodes.CascadedAssignmentNode(rhs.pos,
301                     lhs_list = lhs_list, rhs = rhs)
302             nodes.append(node)
303
304         if len(nodes) == 1:
305             assign_node = nodes[0]
306         else:
307             assign_node = Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)
308
309         if temp_refs:
310             duplicates_and_temps = [ (temp.expression, temp)
311                                      for temp in temp_refs ]
312             sort_common_subsequences(duplicates_and_temps)
313             for _, temp_ref in duplicates_and_temps[::-1]:
314                 assign_node = LetNode(temp_ref, assign_node)
315
316         return assign_node
317
def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
    """Replace rhs items by LetRefNodes if they appear more than once.

    For every such item, a LetRefNode is created and appended to
    ref_node_sequence, in the order in which the duplicates are found.
    The lists in expr_list_list (whose last entry is the rhs) and the
    'args' of the sequence constructors involved are modified in-place.
    """
    visited = cython.set()
    replacements = {}

    def find_duplicates(node):
        if node.is_literal or node.is_name:
            # no need to replace those; can't include attributes here
            # as their access is not necessarily side-effect free
            return
        if node not in visited:
            visited.add(node)
            if node.is_sequence_constructor:
                for item in node.args:
                    find_duplicates(item)
        elif node not in replacements:
            # second occurrence: set up a temp reference for the node
            ref_node = LetRefNode(node)
            replacements[node] = ref_node
            ref_node_sequence.append(ref_node)

    for expr_list in expr_list_list:
        find_duplicates(expr_list[-1])
    if not replacements:
        return

    def substitute_nodes(node):
        if node in replacements:
            return replacements[node]
        if node.is_sequence_constructor:
            node.args = [ substitute_nodes(arg) for arg in node.args ]
        return node

    # replace nodes inside of the common subexpressions
    for node in replacements:
        if node.is_sequence_constructor:
            node.args = [ substitute_nodes(arg) for arg in node.args ]

    # replace common subexpressions on all rhs items
    for expr_list in expr_list_list:
        expr_list[-1] = substitute_nodes(expr_list[-1])

363
def sort_common_subsequences(items):
    """Sort items/subsequences so that all items and subsequences that
    an item contains appear before the item itself.  This is needed
    because each rhs item must only be evaluated once, so its value
    must be evaluated first and then reused when packing sequences
    that contain it.

    'items' is a list of (expression, temp reference) pairs and is
    reordered in-place; nothing is returned.

    This implies a partial order, and the sort must be stable to
    preserve the original order as much as possible, so we use a
    simple insertion sort (which is very fast for short sequences, the
    normal case in practice).
    """
    def contains(seq, x):
        # identity based, recursive search for x in (nested) sequences
        for item in seq:
            if item is x:
                return True
            elif item.is_sequence_constructor and contains(item.args, x):
                return True
        return False

    def lower_than(a, b):
        # a must come before b if the sequence b contains a
        return b.is_sequence_constructor and contains(b.args, a)

    for pos, item in enumerate(items):
        # the temp reference node, which has already been injected
        # into the sequences
        key = item[1]
        new_pos = pos
        # range() instead of xrange(): same iteration, but also
        # available on Python 3
        for i in range(pos-1, -1, -1):
            if lower_than(key, items[i][0]):
                new_pos = i
        if new_pos != pos:
            # shift the items in between up by one and insert
            for i in range(pos, new_pos, -1):
                items[i] = items[i-1]
            items[new_pos] = item

396
def flatten_parallel_assignments(input, output):
    """Split an assignment between sequences into element assignments.

    The input is a list of expression nodes, representing the LHSs
    and RHS of one (possibly cascaded) assignment statement.  For
    sequence constructors, rearranges the matching parts of both
    sides into a list of equivalent assignments between the
    individual elements.  This transformation is applied
    recursively, so that nested structures get matched as well.

    Each resulting assignment is appended to 'output' as a list of
    the form [lhs, ..., rhs].  Nothing is returned.  LHS sequences
    that cannot be matched against the RHS are reported through
    error() and appended unsplit as [lhs, rhs].
    """
    rhs = input[-1]
    if not rhs.is_sequence_constructor or not sum([lhs.is_sequence_constructor for lhs in input[:-1]]):
        # nothing to match element-wise, keep the assignment as it is
        output.append(input)
        return

    # plain (non-sequence) targets that receive the complete rhs
    complete_assignments = []

    rhs_size = len(rhs.args)
    # one target list per rhs item; range() instead of xrange() gives
    # the same result and is also available on Python 3
    lhs_targets = [ [] for _ in range(rhs_size) ]
    starred_assignments = []
    for lhs in input[:-1]:
        if not lhs.is_sequence_constructor:
            if lhs.is_starred:
                error(lhs.pos, "starred assignment target must be in a list or tuple")
            complete_assignments.append(lhs)
            continue
        lhs_size = len(lhs.args)
        starred_targets = sum([1 for expr in lhs.args if expr.is_starred])
        if starred_targets > 1:
            error(lhs.pos, "more than 1 starred expression in assignment")
            output.append([lhs,rhs])
            continue
        elif lhs_size - starred_targets > rhs_size:
            error(lhs.pos, "need more than %d value%s to unpack"
                  % (rhs_size, (rhs_size != 1) and 's' or ''))
            output.append([lhs,rhs])
            continue
        elif starred_targets:
            map_starred_assignment(lhs_targets, starred_assignments,
                                   lhs.args, rhs.args)
        elif lhs_size < rhs_size:
            error(lhs.pos, "too many values to unpack (expected %d, got %d)"
                  % (lhs_size, rhs_size))
            output.append([lhs,rhs])
            continue
        else:
            # same size on both sides: match the items by position
            for targets, expr in zip(lhs_targets, lhs.args):
                targets.append(expr)

    if complete_assignments:
        complete_assignments.append(rhs)
        output.append(complete_assignments)

    # recursively flatten partial assignments
    for cascade, rhs in zip(lhs_targets, rhs.args):
        if cascade:
            cascade.append(rhs)
            flatten_parallel_assignments(cascade, output)

    # recursively flatten starred assignments
    for cascade in starred_assignments:
        if cascade[0].is_sequence_constructor:
            flatten_parallel_assignments(cascade, output)
        else:
            output.append(cascade)

459
def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args):
    """Distribute the targets of an LHS with one starred item.

    Appends the fixed-position LHS targets that appear left and right
    of the starred argument to the matching lists in 'lhs_targets'
    (one list per RHS item).

    The 'starred_assignments' list receives a new entry
    [lhs_target, rhs_values_list] that maps the remaining RHS items
    (those that match the starred target) to a list node.

    Raises InternalError if 'lhs_args' has no starred item.  The caller
    is expected to have checked that there are at least as many RHS
    items as non-starred targets.
    """
    # Locate the starred target by looking at the LHS items only.
    # Searching a zip() with lhs_targets would stop early (and wrongly
    # report a missing star) when the star is the last item and has no
    # RHS item at its own position, as in "a, *b = (1,)".
    for i, expr in enumerate(lhs_args):
        if expr.is_starred:
            starred = i
            break
    else:
        raise InternalError("no starred arg found when splitting starred assignment")
    lhs_remaining = len(lhs_args) - starred - 1

    # left side of the starred target
    for targets, expr in zip(lhs_targets, lhs_args[:starred]):
        targets.append(expr)

    # right side of the starred target; must be skipped when nothing
    # follows the star, since seq[-0:] is the complete sequence and
    # would add every target (including the starred one) once more
    if lhs_remaining:
        for targets, expr in zip(lhs_targets[-lhs_remaining:],
                                 lhs_args[starred + 1:]):
            targets.append(expr)

    # the starred target itself, must be assigned a (potentially empty) list
    target = lhs_args[starred].target # unpack starred node
    starred_rhs = rhs_args[starred:]
    if lhs_remaining:
        starred_rhs = starred_rhs[:-lhs_remaining]
    if starred_rhs:
        pos = starred_rhs[0].pos
    else:
        pos = target.pos
    starred_assignments.append([
        target, ExprNodes.ListNode(pos=pos, args=starred_rhs)])


494
495
class PxdPostParse(CythonTransform, SkipDeclarations):
    """
    Basic interpretation/validity checking that only applies to
    pxd trees.

    Much of this checking is currently done by the parser already;
    the parts listed below are done here.

    - "def" functions are only accepted if they fill the
    getbuffer/releasebuffer slots

    - cdef functions are only accepted if they are at the top
    level and are declared "inline"
    """
    ERR_INLINE_ONLY = "function definition in pxd file must be declared 'cdef inline'"
    ERR_NOGO_WITH_INLINE = "inline function definition in pxd file cannot be '%s'"

    def __call__(self, node):
        # the traversal starts out at the top level of the pxd file
        self.scope_type = 'pxd'
        return super(PxdPostParse, self).__call__(node)

    def visit_CClassDefNode(self, node):
        # track that we are inside of a cdef class while visiting its body
        outer_scope_type = self.scope_type
        self.scope_type = 'cclass'
        self.visitchildren(node)
        self.scope_type = outer_scope_type
        return node

    def visit_FuncDefNode(self, node):
        # FuncDefNode always come with an implementation (without
        # an imp they are CVarDefNodes..)
        err = self.ERR_INLINE_ONLY

        is_buffer_slot = (isinstance(node, Nodes.DefNode)
                          and self.scope_type == 'cclass'
                          and node.name in ('__getbuffer__', '__releasebuffer__'))
        if is_buffer_slot:
            err = None # allow these slots

        if isinstance(node, Nodes.CFuncDefNode):
            if not (u'inline' in node.modifiers and self.scope_type == 'pxd'):
                err = self.ERR_INLINE_ONLY
            else:
                node.inline_in_pxd = True
                if node.visibility != 'private':
                    err = self.ERR_NOGO_WITH_INLINE % node.visibility
                elif node.api:
                    err = self.ERR_NOGO_WITH_INLINE % 'api'
                else:
                    err = None # allow inline function

        if not err:
            return node
        # report the problem and drop the function from the tree
        self.context.nonfatal_error(PostParseError(node.pos, err))
        return None

550
551 class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
552     """
553     After parsing, directives can be stored in a number of places:
554     - #cython-comments at the top of the file (stored in ModuleNode)
555     - Command-line arguments overriding these
556     - @cython.directivename decorators
557     - with cython.directivename: statements
558
559     This transform is responsible for interpreting these various sources
560     and store the directive in two ways:
561     - Set the directives attribute of the ModuleNode for global directives.
562     - Use a CompilerDirectivesNode to override directives for a subtree.
563
564     (The first one is primarily so that the tree structure does not have
565     to be modified, i.e. so that the ModuleNode stays on top.)
566
567     The directives are stored in dictionaries from name to value in effect.
568     Each such dictionary is always filled in for all possible directives,
569     using default values where no value is given by the user.
570
571     The available directives are controlled in Options.py.
572
573     Note that we have to run this prior to analysis, and so some minor
574     duplication of functionality has to occur: We manually track cimports
575     and which names the "cython" module may have been imported to.
576     """
577     unop_method_nodes = {
578         'typeof': ExprNodes.TypeofNode,
579
580         'operator.address': ExprNodes.AmpersandNode,
581         'operator.dereference': ExprNodes.DereferenceNode,
582         'operator.preincrement' : ExprNodes.inc_dec_constructor(True, '++'),
583         'operator.predecrement' : ExprNodes.inc_dec_constructor(True, '--'),
584         'operator.postincrement': ExprNodes.inc_dec_constructor(False, '++'),
585         'operator.postdecrement': ExprNodes.inc_dec_constructor(False, '--'),
586
587         # For backwards compatibility.
588         'address': ExprNodes.AmpersandNode,
589     }
590
591     binop_method_nodes = {
592         'operator.comma'        : ExprNodes.c_binop_constructor(','),
593     }
594
595     special_methods = cython.set(['declare', 'union', 'struct', 'typedef', 'sizeof',
596                                   'cast', 'pointer', 'compiled', 'NULL'])
597     special_methods.update(unop_method_nodes.keys())
598
599     def __init__(self, context, compilation_directive_defaults):
600         super(InterpretCompilerDirectives, self).__init__(context)
601         self.compilation_directive_defaults = {}
602         for key, value in compilation_directive_defaults.items():
603             self.compilation_directive_defaults[unicode(key)] = value
604         self.cython_module_names = cython.set()
605         self.directive_names = {}
606
607     def check_directive_scope(self, pos, directive, scope):
608         legal_scopes = Options.directive_scopes.get(directive, None)
609         if legal_scopes and scope not in legal_scopes:
610             self.context.nonfatal_error(PostParseError(pos, 'The %s compiler directive '
611                                         'is not allowed in %s scope' % (directive, scope)))
612             return False
613         else:
614             return True
615
616     # Set up processing and handle the cython: comments.
617     def visit_ModuleNode(self, node):
618         for key, value in node.directive_comments.items():
619             if not self.check_directive_scope(node.pos, key, 'module'):
620                 self.wrong_scope_error(node.pos, key, 'module')
621                 del node.directive_comments[key]
622
623         directives = copy.copy(Options.directive_defaults)
624         directives.update(self.compilation_directive_defaults)
625         directives.update(node.directive_comments)
626         self.directives = directives
627         node.directives = directives
628         self.visitchildren(node)
629         node.cython_module_names = self.cython_module_names
630         return node
631
632     # The following four functions track imports and cimports that
633     # begin with "cython"
634     def is_cython_directive(self, name):
635         return (name in Options.directive_types or
636                 name in self.special_methods or
637                 PyrexTypes.parse_basic_type(name))
638
639     def visit_CImportStatNode(self, node):
640         if node.module_name == u"cython":
641             self.cython_module_names.add(node.as_name or u"cython")
642         elif node.module_name.startswith(u"cython."):
643             if node.as_name:
644                 self.directive_names[node.as_name] = node.module_name[7:]
645             else:
646                 self.cython_module_names.add(u"cython")
647             # if this cimport was a compiler directive, we don't
648             # want to leave the cimport node sitting in the tree
649             return None
650         return node
651
652     def visit_FromCImportStatNode(self, node):
653         if (node.module_name == u"cython") or \
654                node.module_name.startswith(u"cython."):
655             submodule = (node.module_name + u".")[7:]
656             newimp = []
657             for pos, name, as_name, kind in node.imported_names:
658                 full_name = submodule + name
659                 if self.is_cython_directive(full_name):
660                     if as_name is None:
661                         as_name = full_name
662                     self.directive_names[as_name] = full_name
663                     if kind is not None:
664                         self.context.nonfatal_error(PostParseError(pos,
665                             "Compiler directive imports must be plain imports"))
666                 else:
667                     newimp.append((pos, name, as_name, kind))
668             if not newimp:
669                 return None
670             node.imported_names = newimp
671         return node
672
673     def visit_FromImportStatNode(self, node):
674         if (node.module.module_name.value == u"cython") or \
675                node.module.module_name.value.startswith(u"cython."):
676             submodule = (node.module.module_name.value + u".")[7:]
677             newimp = []
678             for name, name_node in node.items:
679                 full_name = submodule + name
680                 if self.is_cython_directive(full_name):
681                     self.directive_names[name_node.name] = full_name
682                 else:
683                     newimp.append((name, name_node))
684             if not newimp:
685                 return None
686             node.items = newimp
687         return node
688
689     def visit_SingleAssignmentNode(self, node):
690         if (isinstance(node.rhs, ExprNodes.ImportNode) and
691                 node.rhs.module_name.value == u'cython'):
692             node = Nodes.CImportStatNode(node.pos,
693                                          module_name = u'cython',
694                                          as_name = node.lhs.name)
695             self.visit_CImportStatNode(node)
696         else:
697             self.visitchildren(node)
698         return node
699
700     def visit_NameNode(self, node):
701         if node.name in self.cython_module_names:
702             node.is_cython_module = True
703         else:
704             node.cython_attribute = self.directive_names.get(node.name)
705         return node
706
707     def try_to_parse_directives(self, node):
708         # If node is the contents of an directive (in a with statement or
709         # decorator), returns a list of (directivename, value) pairs.
710         # Otherwise, returns None
711         if isinstance(node, ExprNodes.CallNode):
712             self.visit(node.function)
713             optname = node.function.as_cython_attribute()
714             if optname:
715                 directivetype = Options.directive_types.get(optname)
716                 if directivetype:
717                     args, kwds = node.explicit_args_kwds()
718                     directives = []
719                     key_value_pairs = []
720                     if kwds is not None and directivetype is not dict:
721                         for keyvalue in kwds.key_value_pairs:
722                             key, value = keyvalue
723                             sub_optname = "%s.%s" % (optname, key.value)
724                             if Options.directive_types.get(sub_optname):
725                                 directives.append(self.try_to_parse_directive(sub_optname, [value], None, keyvalue.pos))
726                             else:
727                                 key_value_pairs.append(keyvalue)
728                         if not key_value_pairs:
729                             kwds = None
730                         else:
731                             kwds.key_value_pairs = key_value_pairs
732                         if directives and not kwds and not args:
733                             return directives
734                     directives.append(self.try_to_parse_directive(optname, args, kwds, node.function.pos))
735                     return directives
736         elif isinstance(node, (ExprNodes.AttributeNode, ExprNodes.NameNode)):
737             self.visit(node)
738             optname = node.as_cython_attribute()
739             if optname:
740                 directivetype = Options.directive_types.get(optname)
741                 if directivetype is bool:
742                     return [(optname, True)]
743                 elif directivetype is None:
744                     return [(optname, None)]
745                 else:
746                     raise PostParseError(
747                         node.pos, "The '%s' directive should be used as a function call." % optname)
748         return None
749
750     def try_to_parse_directive(self, optname, args, kwds, pos):
751         directivetype = Options.directive_types.get(optname)
752         if len(args) == 1 and isinstance(args[0], ExprNodes.NoneNode):
753             return optname, Options.directive_defaults[optname]
754         elif directivetype is bool:
755             if kwds is not None or len(args) != 1 or not isinstance(args[0], ExprNodes.BoolNode):
756                 raise PostParseError(pos,
757                     'The %s directive takes one compile-time boolean argument' % optname)
758             return (optname, args[0].value)
759         elif directivetype is str:
760             if kwds is not None or len(args) != 1 or not isinstance(args[0], (ExprNodes.StringNode,
761                                                                               ExprNodes.UnicodeNode)):
762                 raise PostParseError(pos,
763                     'The %s directive takes one compile-time string argument' % optname)
764             return (optname, str(args[0].value))
765         elif directivetype is dict:
766             if len(args) != 0:
767                 raise PostParseError(pos,
768                     'The %s directive takes no prepositional arguments' % optname)
769             return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs])
770         elif directivetype is list:
771             if kwds and len(kwds) != 0:
772                 raise PostParseError(pos,
773                     'The %s directive takes no keyword arguments' % optname)
774             return optname, [ str(arg.value) for arg in args ]
775         else:
776             assert False
777
778     def visit_with_directives(self, body, directives):
779         olddirectives = self.directives
780         newdirectives = copy.copy(olddirectives)
781         newdirectives.update(directives)
782         self.directives = newdirectives
783         assert isinstance(body, Nodes.StatListNode), body
784         retbody = self.visit_Node(body)
785         directive = Nodes.CompilerDirectivesNode(pos=retbody.pos, body=retbody,
786                                                  directives=newdirectives)
787         self.directives = olddirectives
788         return directive
789
790     # Handle decorators
791     def visit_FuncDefNode(self, node):
792         directives = self._extract_directives(node, 'function')
793         if not directives:
794             return self.visit_Node(node)
795         body = Nodes.StatListNode(node.pos, stats=[node])
796         return self.visit_with_directives(body, directives)
797
798     def visit_CVarDefNode(self, node):
799         if not node.decorators:
800             return node
801         for dec in node.decorators:
802             for directive in self.try_to_parse_directives(dec.decorator) or ():
803                 if directive is not None and directive[0] == u'locals':
804                     node.directive_locals = directive[1]
805                 else:
806                     self.context.nonfatal_error(PostParseError(dec.pos,
807                         "Cdef functions can only take cython.locals() decorator."))
808         return node
809
810     def visit_CClassDefNode(self, node):
811         directives = self._extract_directives(node, 'cclass')
812         if not directives:
813             return self.visit_Node(node)
814         body = Nodes.StatListNode(node.pos, stats=[node])
815         return self.visit_with_directives(body, directives)
816
817     def visit_PyClassDefNode(self, node):
818         directives = self._extract_directives(node, 'class')
819         if not directives:
820             return self.visit_Node(node)
821         body = Nodes.StatListNode(node.pos, stats=[node])
822         return self.visit_with_directives(body, directives)
823
824     def _extract_directives(self, node, scope_name):
825         if not node.decorators:
826             return {}
827         # Split the decorators into two lists -- real decorators and directives
828         directives = []
829         realdecs = []
830         for dec in node.decorators:
831             new_directives = self.try_to_parse_directives(dec.decorator)
832             if new_directives is not None:
833                 for directive in new_directives:
834                     if self.check_directive_scope(node.pos, directive[0], scope_name):
835                         directives.append(directive)
836             else:
837                 realdecs.append(dec)
838         if realdecs and isinstance(node, (Nodes.CFuncDefNode, Nodes.CClassDefNode)):
839             raise PostParseError(realdecs[0].pos, "Cdef functions/classes cannot take arbitrary decorators.")
840         else:
841             node.decorators = realdecs
842         # merge or override repeated directives
843         optdict = {}
844         directives.reverse() # Decorators coming first take precedence
845         for directive in directives:
846             name, value = directive
847             if name in optdict:
848                 old_value = optdict[name]
849                 # keywords and arg lists can be merged, everything
850                 # else overrides completely
851                 if isinstance(old_value, dict):
852                     old_value.update(value)
853                 elif isinstance(old_value, list):
854                     old_value.extend(value)
855                 else:
856                     optdict[name] = value
857             else:
858                 optdict[name] = value
859         return optdict
860
861     # Handle with statements
862     def visit_WithStatNode(self, node):
863         directive_dict = {}
864         for directive in self.try_to_parse_directives(node.manager) or []:
865             if directive is not None:
866                 if node.target is not None:
867                     self.context.nonfatal_error(
868                         PostParseError(node.pos, "Compiler directive with statements cannot contain 'as'"))
869                 else:
870                     name, value = directive
871                     if name == 'nogil':
872                         # special case: in pure mode, "with nogil" spells "with cython.nogil"
873                         node = Nodes.GILStatNode(node.pos, state = "nogil", body = node.body)
874                         return self.visit_Node(node)
875                     if self.check_directive_scope(node.pos, name, 'with statement'):
876                         directive_dict[name] = value
877         if directive_dict:
878             return self.visit_with_directives(node.body, directive_dict)
879         return self.visit_Node(node)
880
class WithTransform(CythonTransform, SkipDeclarations):
    """Expand with statements into explicit try/except/finally code.

    Each with statement is replaced by one of the two tree templates
    below, depending on whether it has an ``as`` target.
    """

    # EXCINFO is substituted with a name node in visit_WithStatNode, and
    # the same name is installed as the excinfo target of the template's
    # except clause, so it holds the exc_info() tuple of that clause.
    template_without_target = TreeFragment(u"""
        MGR = EXPR
        EXIT = MGR.__exit__
        MGR.__enter__()
        EXC = True
        try:
            try:
                EXCINFO = None
                BODY
            except:
                EXC = False
                if not EXIT(*EXCINFO):
                    raise
        finally:
            if EXC:
                EXIT(None, None, None)
    """,
        temps=[u'MGR', u'EXC', u"EXIT"],
        pipeline=[NormalizeTree(None)])

    template_with_target = TreeFragment(u"""
        MGR = EXPR
        EXIT = MGR.__exit__
        VALUE = MGR.__enter__()
        EXC = True
        try:
            try:
                EXCINFO = None
                TARGET = VALUE
                BODY
            except:
                EXC = False
                if not EXIT(*EXCINFO):
                    raise
        finally:
            if EXC:
                EXIT(None, None, None)
            MGR = EXIT = VALUE = EXC = None

    """,
        temps=[u'MGR', u'EXC', u"EXIT", u"VALUE"],
        pipeline=[NormalizeTree(None)])

    def visit_WithStatNode(self, node):
        """Replace ``node`` by the matching template instantiation."""
        # TODO: Cleanup badly needed
        # Draw a fresh temporary name from the shared template counter.
        TemplateTransform.temp_name_counter += 1
        handle = "__tmpvar_%d" % TemplateTransform.temp_name_counter

        # Nested with statements in the body are expanded first.
        self.visitchildren(node, ['body'])
        excinfo_temp = ExprNodes.NameNode(node.pos, name=handle)

        substitutions = {
            u'EXPR': node.manager,
            u'BODY': node.body,
            u'EXCINFO': excinfo_temp,
        }
        if node.target is None:
            template = self.template_without_target
        else:
            template = self.template_with_target
            substitutions[u'TARGET'] = node.target
        result = template.substitute(substitutions, pos=node.pos)

        # The last statement of the template is the try/finally; the last
        # statement of its body is the inner try/except whose first
        # clause receives the EXCINFO name as its excinfo target.
        try_except = result.stats[-1].body.stats[-1]
        try_except.except_clauses[0].excinfo_target = ExprNodes.NameNode(node.pos, name=handle)

        return result

    def visit_ExprNode(self, node):
        """Stop descending: with statements are never inside expressions."""
        return node
961
962
class DecoratorTransform(CythonTransform, SkipDeclarations):
    """Rewrite decorated definitions as a definition plus a reassignment.

    ``@dec`` on ``name`` becomes the undecorated definition followed by
    ``name = dec(name)``.
    """

    def visit_DefNode(self, func_node):
        """Expand the decorators of a def function, if it has any."""
        self.visitchildren(func_node)
        if func_node.decorators:
            return self._handle_decorators(func_node, func_node.name)
        return func_node

    def visit_CClassDefNode(self, class_node):
        """Reject decorators on cdef classes.

        Decorating cdef classes is disabled: assignments to cdef class
        names do not work, and they would need an additional check
        anyway because the extension type must not change its C type --
        a decorator could only alter the type and return it, not
        replace it.
        """
        self.visitchildren(class_node)
        if class_node.decorators:
            error(class_node.pos,
                  "Decorators not allowed on cdef classes (used on type '%s')" % class_node.class_name)
        return class_node

    def visit_ClassDefNode(self, class_node):
        """Expand the decorators of a class, if it has any."""
        self.visitchildren(class_node)
        if class_node.decorators:
            return self._handle_decorators(class_node, class_node.name)
        return class_node

    def _handle_decorators(self, node, name):
        """Return ``[node, assignment]`` applying ``node.decorators`` to ``name``.

        The decorators are nested into calls starting from the last one
        in the list, so the first decorator ends up as the outermost call.
        """
        wrapped = ExprNodes.NameNode(node.pos, name=name)
        innermost_first = node.decorators[::-1]
        for decorator in innermost_first:
            wrapped = ExprNodes.SimpleCallNode(
                decorator.pos,
                function=decorator.decorator,
                args=[wrapped])

        target = ExprNodes.NameNode(node.pos, name=name)
        reassignment = Nodes.SingleAssignmentNode(
            node.pos,
            lhs=target,
            rhs=wrapped)
        return [node, reassignment]
1010
1011
class AnalyseDeclarationsTransform(CythonTransform):
    """Run declaration analysis over the tree, one scope at a time.

    ``env_stack`` holds the scopes being analysed, innermost last.
    ``seen_vars_stack`` holds, per scope, the set of names referenced
    so far; it is used to warn about cdef variables that are declared
    after they are used.
    """

    # Templates for the properties generated for cdef class attributes.
    basic_property = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    def __set__(self, value):
        ATTR = value
    """, level='c_class')
    basic_pyobject_property = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    def __set__(self, value):
        ATTR = value
    def __del__(self):
        ATTR = None
    """, level='c_class')
    basic_property_ro = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    """, level='c_class')

    def __call__(self, root):
        """Reset the scope and seen-name stacks, then transform ``root``."""
        self.env_stack = [root.scope]
        # Tracks referenced names so that late cdef declarations can be
        # detected.
        self.seen_vars_stack = []
        return super(AnalyseDeclarationsTransform, self).__call__(root)

    def visit_NameNode(self, node):
        """Record that ``node.name`` was referenced in the current scope."""
        seen = self.seen_vars_stack[-1]
        seen.add(node.name)
        return node

    def visit_ModuleNode(self, node):
        """Analyse module level declarations, then descend."""
        self.seen_vars_stack.append(cython.set())
        current_env = self.env_stack[-1]
        node.analyse_declarations(current_env)
        self.visitchildren(node)
        self.seen_vars_stack.pop()
        return node

    def visit_LambdaNode(self, node):
        """Analyse a lambda's declarations in the current scope, then descend."""
        current_env = self.env_stack[-1]
        node.analyse_declarations(current_env)
        self.visitchildren(node)
        return node

    def visit_ClassDefNode(self, node):
        """Visit the class body with the class scope on top of the stack."""
        self.env_stack.append(node.scope)
        self.visitchildren(node)
        self.env_stack.pop()
        return node

    def visit_CClassDefNode(self, node):
        """Visit a cdef class and append properties for its attributes.

        For every variable entry flagged ``needs_property`` in an
        implemented class scope, a property node is created, analysed,
        visited and appended to the class body.
        """
        node = self.visit_ClassDefNode(node)
        scope = node.scope
        if not (scope and scope.implemented):
            return node
        generated = []
        for entry in scope.var_entries:
            if not entry.needs_property:
                continue
            prop = self.create_Property(entry)
            prop.analyse_declarations(scope)
            self.visit(prop)
            generated.append(prop)
        if generated:
            node.body.stats += generated
        return node

    def visit_FuncDefNode(self, node):
        """Declare arguments and directive locals, then analyse the body."""
        self.seen_vars_stack.append(cython.set())
        lenv = node.local_scope
        # Marked upstream as due for a complete refactoring.
        node.body.analyse_control_flow(lenv)
        node.declare_arguments(lenv)
        for var, type_node in node.directive_locals.items():
            if lenv.lookup_here(var):
                # Already declared (e.g. an argument): don't redeclare.
                continue
            declared_type = type_node.analyse_as_type(lenv)
            if declared_type:
                lenv.declare_var(var, declared_type, type_node.pos)
            else:
                error(type_node.pos, "Not a type")
        node.body.analyse_declarations(lenv)
        self.env_stack.append(lenv)
        self.visitchildren(node)
        self.env_stack.pop()
        self.seen_vars_stack.pop()
        return node

    def visit_ScopedExprNode(self, node):
        """Analyse an expression that may bring its own local scope."""
        env = self.env_stack[-1]
        node.analyse_declarations(env)
        if not node.has_local_scope:
            # No scope of its own: analyse in the enclosing one.
            node.analyse_scoped_declarations(env)
            self.visitchildren(node)
            return node
        # Start from a copy of the names already seen by the outer scope.
        self.seen_vars_stack.append(cython.set(self.seen_vars_stack[-1]))
        self.env_stack.append(node.expr_scope)
        node.analyse_scoped_declarations(node.expr_scope)
        self.visitchildren(node)
        self.env_stack.pop()
        self.seen_vars_stack.pop()
        return node

    def visit_TempResultFromStatNode(self, node):
        """Descend first, then analyse the node in the current scope."""
        self.visitchildren(node)
        node.analyse_declarations(self.env_stack[-1])
        return node

    # Some nodes are no longer needed after declaration analysis and
    # can be dropped (their visit method returns None).  The analysis
    # was performed on these nodes in a separate recursive process from
    # the enclosing function or module, so we can simply drop them.
    def visit_CDeclaratorNode(self, node):
        """Descend so that every CNameDeclaratorNode gets visited."""
        self.visitchildren(node)
        return node

    def visit_CTypeDefNode(self, node):
        """Keep ctypedef nodes as they are."""
        return node

    def visit_CBaseTypeNode(self, node):
        """Drop the node."""
        return None

    def visit_CEnumDefNode(self, node):
        """Keep public enums, drop all others."""
        if node.visibility == 'public':
            return node
        return None

    def visit_CStructOrUnionDefNode(self, node):
        """Drop the node."""
        return None

    def visit_CNameDeclaratorNode(self, node):
        """Warn if the declared name was already used in this scope.

        No warning is given when the looked-up entry has ``extern``
        visibility.
        """
        name = node.name
        if name in self.seen_vars_stack[-1]:
            entry = self.env_stack[-1].lookup(name)
            if entry is None or entry.visibility != 'extern':
                warning(node.pos, "cdef variable '%s' declared after it is used" % name, 2)
        self.visitchildren(node)
        return node

    def visit_CVarDefNode(self, node):
        """Visit the declarators, then drop the node."""
        # Descending ensures all CNameDeclaratorNodes are visited.
        self.visitchildren(node)
        return None

    def create_Property(self, entry):
        """Build the property node exposing the cdef attribute ``entry``.

        Public attributes get a getter and setter (plus a deleter for
        Python object attributes); readonly attributes get a getter only.
        """
        # NOTE(review): any other visibility leaves ``template`` unbound.
        visibility = entry.visibility
        if visibility == 'public':
            if entry.type.is_pyobject:
                template = self.basic_pyobject_property
            else:
                template = self.basic_property
        elif visibility == 'readonly':
            template = self.basic_property_ro
        self_attribute = ExprNodes.AttributeNode(
            pos=entry.pos,
            obj=ExprNodes.NameNode(pos=entry.pos, name="self"),
            attribute=entry.name)
        prop = template.substitute({
                u"ATTR": self_attribute,
            }, pos=entry.pos).stats[0]
        prop.name = entry.name
        # ---------------------------------------
        # XXX This should go to AutoDocTransforms
        # ---------------------------------------
        if (Options.docstrings and
            self.current_directives['embedsignature']):
            # Docstring of the form "name: type[ = default]".
            type_name = entry.type.declaration_code("", for_display=1)
            if not entry.type.is_pyobject:
                type_name = "'%s'" % type_name
            elif entry.type.is_extension_type:
                type_name = entry.type.module_name + '.' + type_name
            if entry.init is not None:
                default_value = ' = ' + entry.init
            elif entry.init_to_none:
                default_value = ' = ' + repr(None)
            else:
                default_value = ''
            docstring = entry.name + ': ' + type_name + default_value
            prop.doc = EncodedString(docstring)
        # ---------------------------------------
        return prop
1189
class AnalyseExpressionsTransform(CythonTransform):
    """Infer types and analyse expressions for each scope in the tree."""

    def visit_ModuleNode(self, node):
        """Analyse the module body in the module scope, then descend."""
        module_scope = node.scope
        module_scope.infer_types()
        node.body.analyse_expressions(module_scope)
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        """Analyse the function body in its local scope, then descend."""
        local_scope = node.local_scope
        local_scope.infer_types()
        node.body.analyse_expressions(local_scope)
        self.visitchildren(node)
        return node

    def visit_ScopedExprNode(self, node):
        """Analyse a scoped expression if it owns a local scope, then descend."""
        if node.has_local_scope:
            expr_scope = node.expr_scope
            expr_scope.infer_types()
            node.analyse_scoped_expressions(expr_scope)
        self.visitchildren(node)
        return node
1210
class ExpandInplaceOperators(EnvTransform):
    """Rewrite ``lhs op= rhs`` as the plain assignment ``lhs = lhs op rhs``.

    Sub-expressions of the target that must only be evaluated once are
    bound through LetRefNode/LetNode.
    """

    def visit_InPlaceAssignmentNode(self, node):
        """Expand one in-place assignment.

        The node is returned unchanged for C++ class targets and for
        buffer accesses.
        """
        lhs = node.lhs
        rhs = node.rhs
        if lhs.type.is_cpp_class:
            # No getting around this exact operator here.
            return node
        if isinstance(lhs, ExprNodes.IndexNode) and lhs.is_buffer_access:
            # There is code to handle this case.
            return node

        env = self.current_env()

        def side_effect_free_reference(expr, setting=False):
            # Returns (reference, temps): an expression that can be used
            # twice, plus the LetRefNodes it depends on.
            if isinstance(expr, ExprNodes.NameNode):
                return expr, []
            if expr.type.is_pyobject and not setting:
                ref = LetRefNode(expr)
                return ref, [ref]
            if isinstance(expr, ExprNodes.IndexNode):
                if expr.is_buffer_access:
                    raise ValueError("Buffer access")
                base, temps = side_effect_free_reference(expr.base)
                index = LetRefNode(expr.index)
                return ExprNodes.IndexNode(expr.pos, base=base, index=index), temps + [index]
            if isinstance(expr, ExprNodes.AttributeNode):
                obj, temps = side_effect_free_reference(expr.obj)
                return ExprNodes.AttributeNode(expr.pos, obj=obj, attribute=expr.attribute), temps
            ref = LetRefNode(expr)
            return ref, [ref]

        try:
            lhs, let_ref_nodes = side_effect_free_reference(lhs, setting=True)
        except ValueError:
            return node

        # A second node with the same attributes serves as the operand.
        dup = lhs.__class__(**lhs.__dict__)
        binop = ExprNodes.binop_node(node.pos,
                                     operator=node.operator,
                                     operand1=dup,
                                     operand2=rhs,
                                     inplace=True)
        # Manually analyse types for new node.
        lhs.analyse_target_types(env)
        dup.analyse_types(env)
        binop.analyse_operation(env)
        node = Nodes.SingleAssignmentNode(
            node.pos,
            lhs=lhs,
            rhs=binop.coerce_to(lhs.type, env))
        # Use LetRefNode to avoid side effects.
        let_ref_nodes.reverse()
        for ref in let_ref_nodes:
            node = LetNode(ref, node)
        return node

    def visit_ExprNode(self, node):
        """Stop descending: in-place assignments can't happen within an expression."""
        return node
1269
1270
class AlignFunctionDefinitions(CythonTransform):
    """
    Applies the signatures declared in a .pxd file to the matching
    def functions and classes of a .py file.
    """

    def visit_ModuleNode(self, node):
        """Remember the module scope and directives, then descend."""
        self.scope = node.scope
        self.directives = node.directives
        self.visitchildren(node)
        return node

    def visit_PyClassDefNode(self, node):
        """Convert a Python class into a cdef class if the pxd declares one.

        A pxd entry of the same name that is not a cdef class is reported
        as a redeclaration and the node is dropped.
        """
        pxd_def = self.scope.lookup(node.name)
        if not pxd_def:
            return node
        if pxd_def.is_cclass:
            return self.visit_CClassDefNode(node.as_cclass(), pxd_def)
        error(node.pos, "'%s' redeclared" % node.name)
        error(pxd_def.pos, "previous declaration here")
        return None

    def visit_CClassDefNode(self, node, pxd_def=None):
        """Visit the class body with the declared class scope active."""
        if pxd_def is None:
            pxd_def = self.scope.lookup(node.class_name)
        if not pxd_def:
            self.visitchildren(node)
            return node
        outer_scope = self.scope
        self.scope = pxd_def.type.scope
        self.visitchildren(node)
        self.scope = outer_scope
        return node

    def visit_DefNode(self, node):
        """Turn a def function into a C function where that is called for.

        This happens when the current scope declares a C function of the
        same name, or when the function is at module level and the
        ``auto_cpdef`` directive is set.  A declaration of the same name
        that is not a C function is reported as a redeclaration and the
        node is dropped.
        """
        pxd_def = self.scope.lookup(node.name)
        if pxd_def:
            if pxd_def.is_cfunction:
                return node.as_cfunction(pxd_def)
            error(node.pos, "'%s' redeclared" % node.name)
            error(pxd_def.pos, "previous declaration here")
            return None
        if self.scope.is_module_scope and self.directives['auto_cpdef']:
            return node.as_cfunction(scope=self.scope)
        # Enable this when internal def functions are allowed.
        # self.visitchildren(node)
        return node
1319
1320
class YieldNodeCollector(TreeVisitor):
    """Collect the yield expressions and return statements of a function.

    After visiting, ``yields`` and ``returns`` list the nodes in the
    order they were encountered and ``has_return_value`` tells whether
    any return statement carried a value.  Nested classes, def
    functions, lambdas and generator expressions are not descended into.
    """

    def __init__(self):
        super(YieldNodeCollector, self).__init__()
        self.yields = []
        self.returns = []
        self.has_return_value = False

    visit_Node = TreeVisitor.visitchildren

    def visit_YieldExprNode(self, node):
        # NOTE(review): this fires when a 'return <value>' was seen
        # before the yield; the message text looks unrelated to that
        # situation -- confirm before changing it.
        if self.has_return_value:
            error(node.pos, "'yield' outside function")
        self.yields.append(node)
        # The operand of a yield can itself contain yield expressions
        # (e.g. ``yield (yield x)``); descend so they are collected too.
        self.visitchildren(node)

    def visit_ReturnStatNode(self, node):
        if node.value:
            self.has_return_value = True
            if self.yields:
                error(node.pos, "'return' with argument inside generator")
        self.returns.append(node)

    # The node types below are skipped without descending into them.
    def visit_ClassDefNode(self, node):
        pass

    def visit_DefNode(self, node):
        pass

    def visit_LambdaNode(self, node):
        pass

    def visit_GeneratorExpressionNode(self, node):
        pass
1354
class MarkClosureVisitor(CythonTransform):
    """Flag functions that need a closure and wrap generator functions.

    ``self.needs_closure`` is reset before a function or lambda body is
    visited and set again afterwards, so a function ends up flagged when
    a nested function, lambda or class was visited inside it.
    """

    def visit_ModuleNode(self, node):
        """Start with a cleared flag and descend into the module."""
        self.needs_closure = False
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        """Set ``node.needs_closure``; turn functions with yields into generators.

        If the function contains yield expressions they are numbered
        from 1 via ``label_num`` and a GeneratorDefNode wrapping the
        original body is returned in place of ``node``.
        """
        self.needs_closure = False
        self.visitchildren(node)
        node.needs_closure = self.needs_closure
        # Signal the enclosing function that it contains a function.
        self.needs_closure = True

        collector = YieldNodeCollector()
        collector.visitchildren(node)
        if not collector.yields:
            return node

        label = 0
        for yield_expr in collector.yields:
            label += 1
            yield_expr.label_num = label

        gbody = Nodes.GeneratorBodyDefNode(pos=node.pos,
                                           name=node.name,
                                           body=node.body)
        return Nodes.GeneratorDefNode(pos=node.pos,
                                      name=node.name,
                                      args=node.args,
                                      star_arg=node.star_arg,
                                      starstar_arg=node.starstar_arg,
                                      doc=node.doc,
                                      decorators=node.decorators,
                                      gbody=gbody)

    def visit_CFuncDefNode(self, node):
        """Process like a def function; closures in cdef functions are an error."""
        # The result of the generic handler is not used here.
        self.visit_FuncDefNode(node)
        if node.needs_closure:
            error(node.pos, "closures inside cdef functions not yet supported")
        return node

    def visit_LambdaNode(self, node):
        """Set ``node.needs_closure`` for a lambda, like for a function."""
        self.needs_closure = False
        self.visitchildren(node)
        node.needs_closure = self.needs_closure
        self.needs_closure = True
        return node

    def visit_ClassDefNode(self, node):
        """Descend into the class and flag the enclosing function."""
        self.visitchildren(node)
        self.needs_closure = True
        return node
1406
1407 class CreateClosureClasses(CythonTransform):
1408     # Output closure classes in module scope for all functions
1409     # that really need it.
1410
1411     def __init__(self, context):
1412         super(CreateClosureClasses, self).__init__(context)
1413         self.path = []
1414         self.in_lambda = False
1415         self.generator_class = None
1416
1417     def visit_ModuleNode(self, node):
1418         self.module_scope = node.scope
1419         self.visitchildren(node)
1420         return node
1421
1422     def create_generator_class(self, target_module_scope, pos):
1423         if self.generator_class:
1424             return self.generator_class
1425         # XXX: make generator class creation cleaner
1426         entry = target_module_scope.declare_c_class(name='__pyx_Generator',
1427                     objstruct_cname='__pyx_Generator_object',
1428                     typeobj_cname='__pyx_Generator_type',
1429                     pos=pos, defining=True, implementing=True)
1430         klass = entry.type.scope
1431         klass.is_internal = True
1432         klass.directives = {'final': True}
1433
1434         body_type = PyrexTypes.create_typedef_type('generator_body',
1435                                                    PyrexTypes.c_void_ptr_type,
1436                                                    '__pyx_generator_body_t')
1437         klass.declare_var(pos=pos, name='body', cname='body',
1438                           type=body_type, is_cdef=True)
1439         klass.declare_var(pos=pos, name='is_running', cname='is_running', type=PyrexTypes.c_int_type,
1440                           is_cdef=True)
1441         klass.declare_var(pos=pos, name='resume_label', cname='resume_label', type=PyrexTypes.c_int_type,
1442                           is_cdef=True)
1443
1444         import TypeSlots
1445         e = klass.declare_pyfunction('send', pos)
1446         e.func_cname = '__Pyx_Generator_Send'
1447         e.signature = TypeSlots.binaryfunc
1448
1449         e = klass.declare_pyfunction('close', pos)
1450         e.func_cname = '__Pyx_Generator_Close'
1451         e.signature = TypeSlots.unaryfunc
1452
1453         e = klass.declare_pyfunction('throw', pos)
1454         e.func_cname = '__Pyx_Generator_Throw'
1455         e.signature = TypeSlots.pyfunction_signature
1456
1457         e = klass.declare_var('__iter__', PyrexTypes.py_object_type, pos, visibility='public')
1458         e.func_cname = 'PyObject_SelfIter'
1459
1460         e = klass.declare_var('__next__', PyrexTypes.py_object_type, pos, visibility='public')
1461         e.func_cname = '__Pyx_Generator_Next'
1462
1463         self.generator_class = entry.type
1464         return self.generator_class
1465
1466     def get_scope_use(self, node):
1467         from_closure = []
1468         in_closure = []
1469         for name, entry in node.local_scope.entries.items():
1470             if entry.from_closure:
1471                 from_closure.append((name, entry))
1472             elif entry.in_closure and not entry.from_closure:
1473                 in_closure.append((name, entry))
1474         return from_closure, in_closure
1475
def create_class_from_scope(self, node, target_module_scope, inner_node=None):
    """Declare the closure class that backs the local scope of ``node``.

    Generator bodies are skipped.  For a generator, every local entry
    that does not come from a closure is first flagged ``in_closure``.
    Depending on what ``get_scope_use`` reports, the function then
    either needs nothing, only passes the enclosing scope class through,
    or gets a new internal, final C class declared in
    ``target_module_scope`` with one field per ``in_closure`` entry
    (plus a field for the outer scope when entries come from a closure).

    node                 function node whose ``local_scope`` is inspected
    target_module_scope  scope in which the closure class is declared
    inner_node           node whose ``needs_self_code`` flag is cleared
                         when nothing is taken from an outer closure;
                         defaults to ``node.assmt.rhs``

    Sets ``node.needs_closure`` and ``node.needs_outer_scope`` as a side
    effect.  Raises InternalError when ``inner_node`` has to be derived
    from ``node`` but ``node.assmt`` is not set.
    """
    # skip generator body
    if node.is_generator_body:
        return
    # move local variables into closure
    if node.is_generator:
        for entry in node.local_scope.entries.values():
            if not entry.from_closure:
                entry.in_closure = True

    from_closure, in_closure = self.get_scope_use(node)
    in_closure.sort()

    # Now from the beginning
    node.needs_closure = False
    node.needs_outer_scope = False

    func_scope = node.local_scope
    cscope = node.entry.scope
    # Walk outwards past any enclosing Python/C class scopes.
    while cscope.is_py_class_scope or cscope.is_c_class_scope:
        cscope = cscope.outer_scope

    if not from_closure and (self.path or inner_node):
        if not inner_node:
            if not node.assmt:
                # Call form of raise: valid on both Python 2 and Python 3.
                raise InternalError("DefNode does not have assignment node")
            inner_node = node.assmt.rhs
        inner_node.needs_self_code = False
        node.needs_outer_scope = False

    if node.is_generator:
        generator_class = self.create_generator_class(target_module_scope, node.pos)
    elif not in_closure and not from_closure:
        # Nothing is shared with any other scope: no class required.
        return
    elif not in_closure:
        # Only reads from the outer closure: reuse its scope class.
        func_scope.is_passthrough = True
        func_scope.scope_class = cscope.scope_class
        node.needs_outer_scope = True
        return

    as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname)

    if node.is_generator:
        entry = target_module_scope.declare_c_class(name = as_name,
                    pos = node.pos, defining = True, implementing = True, base_type=generator_class)
    else:
        entry = target_module_scope.declare_c_class(name = as_name,
                    pos = node.pos, defining = True, implementing = True)
    func_scope.scope_class = entry
    class_scope = entry.type.scope
    class_scope.is_internal = True
    class_scope.directives = {'final': True}

    if from_closure:
        assert cscope.is_closure_scope
        class_scope.declare_var(pos=node.pos,
                                name=Naming.outer_scope_cname,
                                cname=Naming.outer_scope_cname,
                                type=cscope.scope_class.type,
                                is_cdef=True)
        node.needs_outer_scope = True
    for name, entry in in_closure:
        class_scope.declare_var(pos=entry.pos,
                                name=entry.name,
                                cname=entry.cname,
                                type=entry.type,
                                is_cdef=True)
    node.needs_closure = True
    # Do it here because other classes are already checked
    target_module_scope.check_c_class(func_scope.scope_class)
1546
def visit_LambdaNode(self, node):
    """Create the closure class for a lambda and visit its children.

    ``self.in_lambda`` is set while the lambda is processed and put
    back to its previous value afterwards.  Returns ``node``.
    """
    previous, self.in_lambda = self.in_lambda, True
    lambda_def = node.def_node
    self.create_class_from_scope(lambda_def, self.module_scope, node)
    self.visitchildren(node)
    self.in_lambda = previous
    return node
1554
def visit_FuncDefNode(self, node):
    """Create closure classes for a function definition where needed.

    Inside a lambda the children are simply visited.  Otherwise the
    function is processed only if it needs a closure or if an enclosing
    function is already on ``self.path``; in that case it is pushed on
    ``self.path`` while its children are visited.  Returns ``node``.
    """
    if not self.in_lambda:
        if not (node.needs_closure or self.path):
            # Neither a closure nor nested in a tracked function.
            return node
        self.create_class_from_scope(node, self.module_scope)
        self.path.append(node)
        self.visitchildren(node)
        self.path.pop()
        return node

    self.visitchildren(node)
    return node
1565
1566
class GilCheck(VisitorTransform):
    """
    Call `node.nogil_check(env)` on each node visited inside a `nogil`
    environment, for nodes that define such a check.

    The transform tracks a stack of scopes (``env_stack``) and whether
    the code currently being visited runs without the GIL (``nogil``).
    """
    def __call__(self, root):
        """Reset the tracking state and run the transform on ``root``."""
        self.env_stack = [root.scope]
        self.nogil = False
        return super(GilCheck, self).__call__(root)

    def visit_FuncDefNode(self, node):
        """Visit a function body with the nogil state of its local scope."""
        self.env_stack.append(node.local_scope)
        was_nogil = self.nogil
        self.nogil = node.local_scope.nogil
        if self.nogil and node.nogil_check:
            node.nogil_check(node.local_scope)
        self.visitchildren(node)
        self.env_stack.pop()
        self.nogil = was_nogil
        return node

    def visit_GILStatNode(self, node):
        """Visit a GIL statement; its body runs nogil iff its state is 'nogil'."""
        # The statement itself is checked against the surrounding state
        # before that state is switched for the body.
        if self.nogil and node.nogil_check:
            node.nogil_check()
        was_nogil = self.nogil
        self.nogil = (node.state == 'nogil')
        self.visitchildren(node)
        self.nogil = was_nogil
        return node

    def visit_Node(self, node):
        """Check any other node against the innermost scope when in nogil code."""
        if self.env_stack and self.nogil and node.nogil_check:
            node.nogil_check(self.env_stack[-1])
        self.visitchildren(node)
        return node
1603
1604
class TransformBuiltinMethods(EnvTransform):
    """
    Replace uses of ``cython.*`` attributes and functions, and calls to
    the builtin ``locals()``, by the corresponding tree nodes.
    """

    def visit_SingleAssignmentNode(self, node):
        """Drop declaration-only assignments; visit all others."""
        if node.declaration_only:
            return None
        else:
            self.visitchildren(node)
            return node

    def visit_AttributeNode(self, node):
        """Visit the children, then resolve a possible cython attribute."""
        self.visitchildren(node)
        return self.visit_cython_attribute(node)

    def visit_NameNode(self, node):
        """Resolve a name that may refer to a cython attribute."""
        return self.visit_cython_attribute(node)

    def visit_cython_attribute(self, node):
        """Return the replacement node for a cython attribute.

        ``node`` is returned unchanged when it is not a cython attribute
        or when the attribute names a basic type.  Unknown attributes
        are reported as errors.
        """
        attribute = node.as_cython_attribute()
        if attribute:
            if attribute == u'compiled':
                node = ExprNodes.BoolNode(node.pos, value=True)
            elif attribute == u'NULL':
                node = ExprNodes.NullNode(node.pos)
            elif attribute in (u'set', u'frozenset'):
                node = ExprNodes.NameNode(node.pos, name=EncodedString(attribute),
                                          entry=self.current_env().builtin_scope().lookup_here(attribute))
            elif not PyrexTypes.parse_basic_type(attribute):
                error(node.pos, u"'%s' not a valid cython attribute or is being used incorrectly" % attribute)
        return node

    def visit_SimpleCallNode(self, node):
        """Rewrite calls to ``locals()`` and to ``cython.*`` functions.

        A call to the builtin ``locals()`` becomes a dict node built from
        the entries of the current scope.  Calls to cython functions
        (operators, ``cast``, ``sizeof``, ``cmod``, ``cdiv``, ``set``)
        become the matching expression nodes; wrong argument counts and
        unknown functions are reported as errors.
        """

        # locals builtin
        if isinstance(node.function, ExprNodes.NameNode):
            if node.function.name == 'locals':
                lenv = self.current_env()
                entry = lenv.lookup_here('locals')
                if entry:
                    # not the builtin 'locals'
                    return node
                if len(node.args) > 0:
                    # Report at the call's position; the transform itself
                    # has no 'pos' attribute.
                    error(node.pos, "Builtin 'locals()' called with wrong number of args, expected 0, got %d" % len(node.args))
                    return node
                pos = node.pos
                items = [ ExprNodes.DictItemNode(pos,
                                                 key=ExprNodes.StringNode(pos, value=var),
                                                 value=ExprNodes.NameNode(pos, name=var))
                          for var in lenv.entries ]
                return ExprNodes.DictNode(pos, key_value_pairs=items)

        # cython.foo
        function = node.function.as_cython_attribute()
        if function:
            if function in InterpretCompilerDirectives.unop_method_nodes:
                if len(node.args) != 1:
                    error(node.function.pos, u"%s() takes exactly one argument" % function)
                else:
                    node = InterpretCompilerDirectives.unop_method_nodes[function](node.function.pos, operand=node.args[0])
            elif function in InterpretCompilerDirectives.binop_method_nodes:
                if len(node.args) != 2:
                    error(node.function.pos, u"%s() takes exactly two arguments" % function)
                else:
                    node = InterpretCompilerDirectives.binop_method_nodes[function](node.function.pos, operand1=node.args[0], operand2=node.args[1])
            elif function == u'cast':
                if len(node.args) != 2:
                    error(node.function.pos, u"cast() takes exactly two arguments")
                else:
                    arg_type = node.args[0].analyse_as_type(self.current_env())
                    if arg_type:
                        node = ExprNodes.TypecastNode(node.function.pos, type=arg_type, operand=node.args[1])
                    else:
                        error(node.args[0].pos, "Not a type")
            elif function == u'sizeof':
                if len(node.args) != 1:
                    error(node.function.pos, u"sizeof() takes exactly one argument")
                else:
                    arg_type = node.args[0].analyse_as_type(self.current_env())
                    if arg_type:
                        node = ExprNodes.SizeofTypeNode(node.function.pos, arg_type=arg_type)
                    else:
                        # Not a type: take the size of the expression instead.
                        node = ExprNodes.SizeofVarNode(node.function.pos, operand=node.args[0])
            elif function == 'cmod':
                if len(node.args) != 2:
                    error(node.function.pos, u"cmod() takes exactly two arguments")
                else:
                    node = ExprNodes.binop_node(node.function.pos, '%', node.args[0], node.args[1])
                    node.cdivision = True
            elif function == 'cdiv':
                if len(node.args) != 2:
                    error(node.function.pos, u"cdiv() takes exactly two arguments")
                else:
                    node = ExprNodes.binop_node(node.function.pos, '/', node.args[0], node.args[1])
                    node.cdivision = True
            elif function == u'set':
                node.function = ExprNodes.NameNode(node.pos, name=EncodedString('set'))
            else:
                error(node.function.pos, u"'%s' not a valid cython language construct" % function)

        self.visitchildren(node)
        return node
1705
1706
class DebugTransform(CythonTransform):
    """
    Write debug information about the module, its functions and its
    variables to the context's gdb debug output writer in order to
    enable debugging.
    """

    def __init__(self, context, options, result):
        """Set up the transform; ``result.c_file`` names the generated C file."""
        super(DebugTransform, self).__init__(context)
        # qualified names of the function scopes serialized so far
        self.visited = cython.set()
        # our treebuilder and debug output writer
        # (see Cython.Debugger.debug_output.CythonDebugWriter)
        self.tb = self.context.gdb_debug_outputwriter
        #self.c_output_file = options.output_file
        self.c_output_file = result.c_file

        # tells visit_NameNode whether it should register step-into functions
        self.register_stepinto = False

    def visit_ModuleNode(self, node):
        """Serialize the module: first its functions, then its globals."""
        self.tb.module_name = node.full_module_name
        attrs = dict(
            module_name=node.full_module_name,
            filename=node.pos[0].filename,
            c_filename=self.c_output_file)

        self.tb.start('Module', attrs)

        # serialize functions
        self.tb.start('Functions')
        self.visitchildren(node)
        self.tb.end('Functions')

        # 2.3 compatibility. Serialize global variables
        self.tb.start('Globals')
        entries = {}

        # items() matches the dict iteration used elsewhere in this file
        # and, unlike iteritems(), also exists on Python 3.
        for k, v in node.scope.entries.items():
            if (v.qualified_name not in self.visited and not
                v.name.startswith('__pyx_') and not
                v.type.is_cfunction and not
                v.type.is_extension_type):
                entries[k] = v

        self.serialize_local_variables(entries)
        self.tb.end('Globals')
        # self.tb.end('Module') # end Module after the line number mapping in
        # Cython.Compiler.ModuleNode.ModuleNode._serialize_lineno_map
        return node

    def visit_FuncDefNode(self, node):
        """Serialize a function: locals, arguments and step-into functions."""
        self.visited.add(node.local_scope.qualified_name)
        # node.entry.visibility = 'extern'
        if node.py_func is None:
            pf_cname = ''
        else:
            pf_cname = node.py_func.entry.func_cname

        attrs = dict(
            name=node.entry.name,
            cname=node.entry.func_cname,
            pf_cname=pf_cname,
            qualified_name=node.local_scope.qualified_name,
            lineno=str(node.pos[1]))

        self.tb.start('Function', attrs=attrs)

        self.tb.start('Locals')
        self.serialize_local_variables(node.local_scope.entries)
        self.tb.end('Locals')

        self.tb.start('Arguments')
        for arg in node.local_scope.arg_entries:
            self.tb.start(arg.name)
            self.tb.end(arg.name)
        self.tb.end('Arguments')

        # Name nodes visited in the body register the C functions they call.
        self.tb.start('StepIntoFunctions')
        self.register_stepinto = True
        self.visitchildren(node)
        self.register_stepinto = False
        self.tb.end('StepIntoFunctions')
        self.tb.end('Function')

        return node

    def visit_NameNode(self, node):
        """Register a called C function as a step-into function."""
        if (self.register_stepinto and
            node.type.is_cfunction and
            getattr(node, 'is_called', False) and
            node.entry.func_cname is not None):
            # don't check node.entry.in_cinclude, as 'cdef extern: ...'
            # declared functions are not 'in_cinclude'.
            # This means we will list called 'cdef' functions as
            # "step into functions", but this is not an issue as they will be
            # recognized as Cython functions anyway.
            attrs = dict(name=node.entry.func_cname)
            self.tb.start('StepIntoFunction', attrs=attrs)
            self.tb.end('StepIntoFunction')

        self.visitchildren(node)
        return node

    def serialize_local_variables(self, entries):
        """Write one 'LocalVar' element for each entry in ``entries``."""
        for entry in entries.values():
            if entry.type.is_pyobject:
                vartype = 'PythonObject'
            else:
                vartype = 'CObject'

            cname = entry.cname
            # if entry.type.is_extension_type:
                # cname = entry.type.typeptr_cname

            if not entry.pos:
                # this happens for variables that are not in the user's code,
                # e.g. for the global __builtins__, __doc__, etc. We can just
                # set the lineno to 0 for those.
                lineno = '0'
            else:
                lineno = str(entry.pos[1])

            attrs = dict(
                name=entry.name,
                cname=cname,
                qualified_name=entry.qualified_name,
                type=vartype,
                lineno=lineno)

            self.tb.start('LocalVar', attrs)
            self.tb.end('LocalVar')
1837