Merge https://github.com/RafeKettler/cython
[cython.git] / Cython / Compiler / ParseTreeTransforms.py
1
2 import cython
3 cython.declare(PyrexTypes=object, Naming=object, ExprNodes=object, Nodes=object,
4                Options=object, UtilNodes=object, ModuleNode=object,
5                LetNode=object, LetRefNode=object, TreeFragment=object,
6                TemplateTransform=object, EncodedString=object,
7                error=object, warning=object, copy=object)
8
9 import PyrexTypes
10 import Naming
11 import ExprNodes
12 import Nodes
13 import Options
14
15 from Cython.Compiler.Visitor import VisitorTransform, TreeVisitor
16 from Cython.Compiler.Visitor import CythonTransform, EnvTransform, ScopeTrackingTransform
17 from Cython.Compiler.ModuleNode import ModuleNode
18 from Cython.Compiler.UtilNodes import LetNode, LetRefNode
19 from Cython.Compiler.TreeFragment import TreeFragment, TemplateTransform
20 from Cython.Compiler.StringEncoding import EncodedString
21 from Cython.Compiler.Errors import error, warning, CompileError, InternalError
22
23 import copy
24
25
class NameNodeCollector(TreeVisitor):
    """Tree visitor that gathers the NameNodes found in a (sub-)tree.

    Every NameNode that gets visited is appended to the ``name_nodes``
    list, in visiting order.
    """
    def __init__(self):
        super(NameNodeCollector, self).__init__()
        self.name_nodes = []

    def visit_Node(self, node):
        # Any other node: nothing to record, just keep descending.
        self._visitchildren(node, None)

    def visit_NameNode(self, node):
        # Record the name node itself; its children are not visited.
        self.name_nodes.append(node)
39
40
class SkipDeclarations(object):
    """
    Mixin that stops a transform from descending into declarations.

    Variable and function declarations can have a deep tree structure
    that most transformations never need to look at.  Each handler
    below returns the declaration node unchanged without visiting its
    children.

    Declaration nodes are removed after AnalyseDeclarationsTransform, so
    transformations that run after that point have no use for this.
    """
    # -- type level declarations ------------------------------------

    def visit_CBaseTypeNode(self, node):
        return node

    def visit_CTypeDefNode(self, node):
        return node

    def visit_CStructOrUnionDefNode(self, node):
        return node

    def visit_CEnumDefNode(self, node):
        return node

    # -- variable level declarations --------------------------------

    def visit_CDeclaratorNode(self, node):
        return node

    def visit_CVarDefNode(self, node):
        return node
66
class NormalizeTree(CythonTransform):
    """
    Fixes up a few things right after parsing so that the parse tree
    becomes more suitable for the following transforms.

    a) After parsing, a block that holds only one statement is
    represented by that statement itself, not by a StatListNode.
    This is annoying and inconsistent for transforms, since a
    statement cannot in general be removed in a uniform way.  This
    transform therefore wraps any such single statement in a
    StatListNode that contains just that statement.

    b) The PassStatNode is a noop whose only purpose is to plug such
    one-statement blocks; once parsed, a "pass" can just as well be
    represented by an empty StatListNode.  This leaves fewer special
    cases for subsequent transforms to worry about (a block is empty
    exactly when its StatListNode has no children).
    """

    def __init__(self, context):
        super(NormalizeTree, self).__init__(context)
        # Tracking flags: is the node being visited a direct child of a
        # statement list / located somewhere inside an expression?
        self.is_in_statlist = False
        self.is_in_expr = False

    def visit_ExprNode(self, node):
        was_in_expr = self.is_in_expr
        self.is_in_expr = True
        self.visitchildren(node)
        self.is_in_expr = was_in_expr
        return node

    def visit_StatNode(self, node, is_listcontainer=False):
        # While the children are visited, the flag says whether this
        # node acts as a statement container for them.
        was_in_statlist = self.is_in_statlist
        self.is_in_statlist = is_listcontainer
        self.visitchildren(node)
        self.is_in_statlist = was_in_statlist
        if self.is_in_statlist or self.is_in_expr:
            return node
        # A lone statement outside of any statement list gets wrapped.
        return Nodes.StatListNode(pos=node.pos, stats=[node])

    def visit_StatListNode(self, node):
        # Note that the flag is reset to False afterwards rather than
        # restored to its previous value.
        self.is_in_statlist = True
        self.visitchildren(node)
        self.is_in_statlist = False
        return node

    # The following statement nodes serve as list containers for their
    # own children.

    def visit_ParallelAssignmentNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CEnumDefNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CStructOrUnionDefNode(self, node):
        return self.visit_StatNode(node, True)

    # Eliminate PassStatNode
    def visit_PassStatNode(self, node):
        if self.is_in_statlist:
            # inside a statement list: the "pass" simply disappears
            return []
        # a block that consists of "pass" only becomes an empty list
        return Nodes.StatListNode(pos=node.pos, stats=[])

    def visit_CDeclaratorNode(self, node):
        # declarators are returned as they are, without descending
        return node
134
135
class PostParseError(CompileError):
    """Compile error raised by the transforms that interpret and
    validate the tree right after parsing."""
137
# Error message strings.  The unit tests check for these exact texts,
# so they are defined once here as module-level constants.
ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions'
ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)'
ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared'
class PostParse(ScopeTrackingTransform):
    """
    Basic interpretation of the parse tree, as well as validity
    checking that can be done on a very basic level on the parse
    tree (while still not being a problem with the basic syntax,
    as such).

    Specifically:
    - Default values to cdef assignments are turned into single
    assignments following the declaration (everywhere but in class
    bodies, where they raise a compile error)

    - Interpret some node structures into Python runtime values.
    Some nodes take compile-time arguments (currently:
    TemplatedTypeNode[args] and __cythonbufferdefaults__ = {args}),
    which should be interpreted. This happens in a general way
    and other steps should be taken to ensure validity.

    Type arguments cannot be interpreted in this way.

    - For __cythonbufferdefaults__ the arguments are checked for
    validity.

    TemplatedTypeNode has its directives interpreted:
    Any first positional argument goes into the "dtype" attribute,
    any "ndim" keyword argument goes into the "ndim" attribute and
    so on. Also it is checked that the directive combination is valid.
    - __cythonbufferdefaults__ attributes are parsed and put into the
    type information.

    Note: Currently Parsing.py does a lot of interpretation and
    reorganization that can be refactored into this transform
    if a more pure Abstract Syntax Tree is wanted.
    """

    def __init__(self, context):
        super(PostParse, self).__init__(context)
        # Maps the name of a special attribute to the method that
        # interprets its declaration (see visit_CVarDefNode).
        self.specialattribute_handlers = {
            '__cythonbufferdefaults__' : self.handle_bufferdefaults
        }

    def visit_ModuleNode(self, node):
        # The counter used for naming lambdas restarts at 1 for each module.
        self.lambda_counter = 1
        return super(PostParse, self).visit_ModuleNode(node)

    def visit_LambdaNode(self, node):
        # unpack a lambda expression into the corresponding DefNode
        lambda_id = self.lambda_counter
        self.lambda_counter += 1
        node.lambda_name = EncodedString(u'lambda%d' % lambda_id)

        # The body of the generated function simply returns the value
        # of the lambda's result expression.
        body = Nodes.ReturnStatNode(
            node.result_expr.pos, value = node.result_expr)
        node.def_node = Nodes.DefNode(
            node.pos, name=node.name, lambda_name=node.lambda_name,
            args=node.args, star_arg=node.star_arg,
            starstar_arg=node.starstar_arg,
            body=body)
        self.visitchildren(node)
        return node

    # cdef variables
    def handle_bufferdefaults(self, decl):
        # Handler for the __cythonbufferdefaults__ special attribute:
        # the default must be a dict literal; it is stored, together
        # with its position, on the node of the current scope.
        # Raises PostParseError(ERR_BUF_DEFAULTS) otherwise.
        if not isinstance(decl.default, ExprNodes.DictNode):
            raise PostParseError(decl.pos, ERR_BUF_DEFAULTS)
        self.scope_node.buffer_defaults_node = decl.default
        self.scope_node.buffer_defaults_pos = decl.pos

    def visit_CVarDefNode(self, node):
        # This assumes only plain names and pointers are assignable on
        # declaration. Also, it makes use of the fact that a cdef decl
        # must appear before the first use, so we don't have to deal with
        # "i = 3; cdef int i = i" and can simply move the nodes around.
        #
        # Returns a list that holds the declaration node followed by one
        # SingleAssignmentNode per declarator that had a default value,
        # or None if a PostParseError was reported.
        try:
            self.visitchildren(node)
            stats = [node]
            newdecls = []
            for decl in node.declarators:
                # strip the pointer levels to reach the declared name
                declbase = decl
                while isinstance(declbase, Nodes.CPtrDeclaratorNode):
                    declbase = declbase.base
                if isinstance(declbase, Nodes.CNameDeclaratorNode):
                    if declbase.default is not None:
                        if self.scope_type in ('cclass', 'pyclass', 'struct'):
                            # Defaults are not allowed in these scopes,
                            # except for the special attributes of cdef
                            # classes that have a registered handler.
                            if isinstance(self.scope_node, Nodes.CClassDefNode):
                                handler = self.specialattribute_handlers.get(decl.name)
                                if handler:
                                    # a pointer declarator means that a
                                    # type was given for the attribute
                                    if decl is not declbase:
                                        raise PostParseError(decl.pos, ERR_INVALID_SPECIALATTR_TYPE)
                                    handler(decl)
                                    continue # Remove declaration
                            raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
                        # Move the default value into a separate
                        # assignment that follows the declaration.
                        first_assignment = self.scope_type != 'module'
                        stats.append(Nodes.SingleAssignmentNode(node.pos,
                            lhs=ExprNodes.NameNode(node.pos, name=declbase.name),
                            rhs=declbase.default, first=first_assignment))
                        declbase.default = None
                newdecls.append(decl)
            node.declarators = newdecls
            return stats
        except PostParseError, e:
            # An error in a cdef clause is ok, simply remove the declaration
            # and try to move on to report more errors
            self.context.nonfatal_error(e)
            return None

    # Split parallel assignments (a,b = b,a) into separate partial
    # assignments that are executed rhs-first using temps.  This
    # restructuring must be applied before type analysis so that known
    # types on rhs and lhs can be matched directly.  It is required in
    # the case that the types cannot be coerced to a Python type in
    # order to assign from a tuple.

    def visit_SingleAssignmentNode(self, node):
        self.visitchildren(node)
        return self._visit_assignment_node(node, [node.lhs, node.rhs])

    def visit_CascadedAssignmentNode(self, node):
        self.visitchildren(node)
        return self._visit_assignment_node(node, node.lhs_list + [node.rhs])

    def _visit_assignment_node(self, node, expr_list):
        """Flatten parallel assignments into separate single
        assignments or cascaded assignments.
        """
        # expr_list holds all assignment targets followed by the rhs.
        if sum([ 1 for expr in expr_list if expr.is_sequence_constructor ]) < 2:
            # no parallel assignments => nothing to do
            return node

        # Split into the individual assignments, then replace rhs items
        # that occur more than once by references to temps.
        expr_list_list = []
        flatten_parallel_assignments(expr_list, expr_list_list)
        temp_refs = []
        eliminate_rhs_duplicates(expr_list_list, temp_refs)

        # Build one assignment node per flattened [targets..., rhs] list.
        nodes = []
        for expr_list in expr_list_list:
            lhs_list = expr_list[:-1]
            rhs = expr_list[-1]
            if len(lhs_list) == 1:
                node = Nodes.SingleAssignmentNode(rhs.pos,
                    lhs = lhs_list[0], rhs = rhs)
            else:
                node = Nodes.CascadedAssignmentNode(rhs.pos,
                    lhs_list = lhs_list, rhs = rhs)
            nodes.append(node)

        if len(nodes) == 1:
            assign_node = nodes[0]
        else:
            assign_node = Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)

        if temp_refs:
            duplicates_and_temps = [ (temp.expression, temp)
                                     for temp in temp_refs ]
            sort_common_subsequences(duplicates_and_temps)
            # Wrap in reverse order, so that the first temp of the
            # sorted list ends up as the outermost LetNode.
            for _, temp_ref in duplicates_and_temps[::-1]:
                assign_node = LetNode(temp_ref, assign_node)

        return assign_node

    def _flatten_sequence(self, seq, result):
        # Appends the items of seq.args to result, recursively expanding
        # nested sequence constructors, and returns result.
        for arg in seq.args:
            if arg.is_sequence_constructor:
                self._flatten_sequence(arg, result)
            else:
                result.append(arg)
        return result

    def visit_DelStatNode(self, node):
        # Replace the arguments of the del statement by the flat list of
        # all non-sequence items that they contain.
        self.visitchildren(node)
        node.args = self._flatten_sequence(node, [])
        return node
314
315
def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
    """Substitute LetRefNodes for rhs items that occur more than once.

    The last entry of each list in ``expr_list_list`` is taken as its
    rhs.  For every rhs item that is found a second time, a LetRefNode
    is created and appended to ``ref_node_sequence``; afterwards, the
    occurrences are replaced by that LetRefNode.  Literals and names
    are never replaced.  The lists (and the ``args`` of the sequence
    constructors in them) are modified in-place.
    """
    visited = cython.set()
    replacements = {}

    def find_duplicates(node):
        if node.is_literal or node.is_name:
            # no need to replace those; can't include attributes here
            # as their access is not necessarily side-effect free
            return
        if node not in visited:
            # first occurrence: remember it and look at its items
            visited.add(node)
            if node.is_sequence_constructor:
                for item in node.args:
                    find_duplicates(item)
        elif node not in replacements:
            # second occurrence: set up the reference node, once
            let_ref = LetRefNode(node)
            replacements[node] = let_ref
            ref_node_sequence.append(let_ref)

    for expr_list in expr_list_list:
        find_duplicates(expr_list[-1])
    if not replacements:
        return

    def substitute_nodes(node):
        if node in replacements:
            return replacements[node]
        if node.is_sequence_constructor:
            node.args = [substitute_nodes(arg) for arg in node.args]
        return node

    # replace nodes inside of the common subexpressions
    for duplicate in replacements:
        if duplicate.is_sequence_constructor:
            duplicate.args = [substitute_nodes(arg) for arg in duplicate.args]

    # replace common subexpressions on all rhs items
    for expr_list in expr_list_list:
        expr_list[-1] = substitute_nodes(expr_list[-1])
361
def sort_common_subsequences(items):
    """Reorder ``items`` in place so that everything an item contains
    appears before the item itself.

    ``items`` is a list of pairs.  The second entry of a pair is the
    reference node that was injected into the sequences, the first
    entry is the expression that is searched for such references.  An
    item moves in front of the first earlier item whose expression is
    a sequence constructor that (recursively) contains its reference.

    Each rhs item must only be evaluated once, so its value must be
    available before any sequence that contains it gets packed.  This
    is only a partial order and the original order should be kept as
    much as possible, hence a simple, stable insertion sort is used
    (which is very fast for the short sequences that are the normal
    case in practice).
    """
    def contains(seq, x):
        # identity based search that descends into nested sequences
        for item in seq:
            if item is x:
                return True
            if item.is_sequence_constructor and contains(item.args, x):
                return True
        return False

    def lower_than(a, b):
        return b.is_sequence_constructor and contains(b.args, a)

    for pos, item in enumerate(items):
        ref = item[1] # the ResultRefNode which has already been injected into the sequences
        new_pos = pos
        # find the first preceding item that depends on this one
        for earlier in range(pos):
            if lower_than(ref, items[earlier][0]):
                new_pos = earlier
                break
        if new_pos != pos:
            # shift the items in between one slot to the right
            items[new_pos + 1:pos + 1] = items[new_pos:pos]
            items[new_pos] = item
394
def flatten_parallel_assignments(input, output):
    #  'input' lists the expression nodes of one (possibly cascaded)
    #  assignment statement: all LHS targets, followed by the RHS.
    #  Where both sides are sequence constructors, the matching parts
    #  are paired up and rearranged into equivalent assignments
    #  between the individual elements.  Each resulting assignment is
    #  appended to 'output' as a list [targets..., rhs].  The
    #  transformation is applied recursively, so that nested
    #  structures get matched as well.
    rhs = input[-1]
    lhs_nodes = input[:-1]

    splittable = rhs.is_sequence_constructor
    if splittable:
        splittable = False
        for lhs in lhs_nodes:
            if lhs.is_sequence_constructor:
                splittable = True
                break
    if not splittable:
        # nothing to match up: keep the assignment as it is
        output.append(input)
        return

    rhs_size = len(rhs.args)
    lhs_targets = [[] for _ in range(rhs_size)]  # one target list per rhs item
    starred_assignments = []
    complete_assignments = []  # targets that receive the complete rhs

    for lhs in lhs_nodes:
        if not lhs.is_sequence_constructor:
            if lhs.is_starred:
                error(lhs.pos, "starred assignment target must be in a list or tuple")
            complete_assignments.append(lhs)
            continue

        lhs_size = len(lhs.args)
        starred_targets = len([expr for expr in lhs.args if expr.is_starred])
        if starred_targets > 1:
            error(lhs.pos, "more than 1 starred expression in assignment")
        elif lhs_size - starred_targets > rhs_size:
            plural = ''
            if rhs_size != 1:
                plural = 's'
            error(lhs.pos, "need more than %d value%s to unpack"
                  % (rhs_size, plural))
        elif starred_targets:
            map_starred_assignment(lhs_targets, starred_assignments,
                                   lhs.args, rhs.args)
            continue
        elif lhs_size < rhs_size:
            error(lhs.pos, "too many values to unpack (expected %d, got %d)"
                  % (lhs_size, rhs_size))
        else:
            for targets, expr in zip(lhs_targets, lhs.args):
                targets.append(expr)
            continue
        # only the error cases get here: pass the target on unsplit
        output.append([lhs, rhs])

    if complete_assignments:
        complete_assignments.append(rhs)
        output.append(complete_assignments)

    # recursively flatten partial assignments
    for cascade, rhs_item in zip(lhs_targets, rhs.args):
        if cascade:
            cascade.append(rhs_item)
            flatten_parallel_assignments(cascade, output)

    # recursively flatten starred assignments
    for cascade in starred_assignments:
        if cascade[0].is_sequence_constructor:
            flatten_parallel_assignments(cascade, output)
        else:
            output.append(cascade)
457
def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args):
    # Appends the fixed-position LHS targets to the target list that
    # appear left and right of the starred argument.
    #
    # lhs_targets holds one target list per RHS item, lhs_args are the
    # items of the LHS sequence (exactly one of them starred) and
    # rhs_args are the items of the RHS sequence.
    #
    # The starred_assignments list receives a new list
    # [lhs_target, rhs_values_list] that maps the remaining arguments
    # (those that match the starred target) to a list.
    #
    # Raises InternalError if lhs_args has no starred item.

    # Locate the starred target in lhs_args itself.  Searching through
    # zip(lhs_targets, lhs_args) would miss a trailing starred target
    # when there are exactly as many RHS items as fixed targets (as in
    # "a, *b = (x,)"), because zip() stops at the shorter sequence.
    starred = None
    for i, expr in enumerate(lhs_args):
        if expr.is_starred:
            starred = i
            break
    if starred is None:
        raise InternalError("no starred arg found when splitting starred assignment")
    lhs_remaining = len(lhs_args) - starred - 1

    # left side of the starred target
    for targets, expr in zip(lhs_targets, lhs_args[:starred]):
        targets.append(expr)

    # right side of the starred target (aligned with the end of the RHS)
    if lhs_remaining:
        for targets, expr in zip(lhs_targets[-lhs_remaining:],
                                 lhs_args[starred + 1:]):
            targets.append(expr)

    # the starred target itself, must be assigned a (potentially empty) list
    target = lhs_args[starred].target # unpack starred node
    starred_rhs = rhs_args[starred:]
    if lhs_remaining:
        starred_rhs = starred_rhs[:-lhs_remaining]
    if starred_rhs:
        pos = starred_rhs[0].pos
    else:
        pos = target.pos
    starred_assignments.append([
        target, ExprNodes.ListNode(pos=pos, args=starred_rhs)])
492
493
class PxdPostParse(CythonTransform, SkipDeclarations):
    """
    Basic interpretation/validity checking that only applies to
    pxd trees.

    Much of this checking currently happens in the parser already;
    what is handled here is the following:

    - "def" functions are only let through if they fill the
    getbuffer/releasebuffer slots

    - cdef functions are only let through if they are on the
    top level and are declared "inline"

    Rejected function nodes are reported as non-fatal errors and
    removed from the tree.
    """
    ERR_INLINE_ONLY = "function definition in pxd file must be declared 'cdef inline'"
    ERR_NOGO_WITH_INLINE = "inline function definition in pxd file cannot be '%s'"

    def __call__(self, node):
        # the transform starts out at the top level of the pxd file
        self.scope_type = 'pxd'
        return super(PxdPostParse, self).__call__(node)

    def visit_CClassDefNode(self, node):
        enclosing_scope_type = self.scope_type
        self.scope_type = 'cclass'
        self.visitchildren(node)
        self.scope_type = enclosing_scope_type
        return node

    def visit_FuncDefNode(self, node):
        # FuncDefNode always come with an implementation (without
        # an imp they are CVarDefNodes..)
        err = self.ERR_INLINE_ONLY

        if isinstance(node, Nodes.DefNode):
            if (self.scope_type == 'cclass'
                    and node.name in ('__getbuffer__', '__releasebuffer__')):
                err = None # allow these slots

        if isinstance(node, Nodes.CFuncDefNode):
            if u'inline' not in node.modifiers or self.scope_type != 'pxd':
                err = self.ERR_INLINE_ONLY
            else:
                node.inline_in_pxd = True
                if node.visibility != 'private':
                    err = self.ERR_NOGO_WITH_INLINE % node.visibility
                elif node.api:
                    err = self.ERR_NOGO_WITH_INLINE % 'api'
                else:
                    err = None # allow inline function

        if not err:
            return node
        self.context.nonfatal_error(PostParseError(node.pos, err))
        return None
548
549 class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
550     """
551     After parsing, directives can be stored in a number of places:
552     - #cython-comments at the top of the file (stored in ModuleNode)
553     - Command-line arguments overriding these
554     - @cython.directivename decorators
555     - with cython.directivename: statements
556
557     This transform is responsible for interpreting these various sources
558     and store the directive in two ways:
559     - Set the directives attribute of the ModuleNode for global directives.
560     - Use a CompilerDirectivesNode to override directives for a subtree.
561
562     (The first one is primarily to not have to modify with the tree
563     structure, so that ModuleNode stay on top.)
564
565     The directives are stored in dictionaries from name to value in effect.
566     Each such dictionary is always filled in for all possible directives,
567     using default values where no value is given by the user.
568
569     The available directives are controlled in Options.py.
570
571     Note that we have to run this prior to analysis, and so some minor
572     duplication of functionality has to occur: We manually track cimports
573     and which names the "cython" module may have been imported to.
574     """
    # Maps the names of the unary special methods (relative to the
    # "cython" module) to the ExprNodes class or constructor that is
    # associated with them.
    unop_method_nodes = {
        'typeof': ExprNodes.TypeofNode,

        'operator.address': ExprNodes.AmpersandNode,
        'operator.dereference': ExprNodes.DereferenceNode,
        'operator.preincrement' : ExprNodes.inc_dec_constructor(True, '++'),
        'operator.predecrement' : ExprNodes.inc_dec_constructor(True, '--'),
        'operator.postincrement': ExprNodes.inc_dec_constructor(False, '++'),
        'operator.postdecrement': ExprNodes.inc_dec_constructor(False, '--'),

        # For backwards compatibility.
        'address': ExprNodes.AmpersandNode,
    }

    # The same kind of mapping for the binary special methods.
    binop_method_nodes = {
        'operator.comma'        : ExprNodes.c_binop_constructor(','),
    }

    # All names that is_cython_directive() accepts as special methods.
    # Only the unary operator names are added to the literal list; the
    # keys of binop_method_nodes are not.
    special_methods = cython.set(['declare', 'union', 'struct', 'typedef', 'sizeof',
                                  'cast', 'pointer', 'compiled', 'NULL'])
    special_methods.update(unop_method_nodes.keys())
596
597     def __init__(self, context, compilation_directive_defaults):
598         super(InterpretCompilerDirectives, self).__init__(context)
599         self.compilation_directive_defaults = {}
600         for key, value in compilation_directive_defaults.items():
601             self.compilation_directive_defaults[unicode(key)] = copy.deepcopy(value)
602         self.cython_module_names = cython.set()
603         self.directive_names = {}
604
605     def check_directive_scope(self, pos, directive, scope):
606         legal_scopes = Options.directive_scopes.get(directive, None)
607         if legal_scopes and scope not in legal_scopes:
608             self.context.nonfatal_error(PostParseError(pos, 'The %s compiler directive '
609                                         'is not allowed in %s scope' % (directive, scope)))
610             return False
611         else:
612             return True
613
614     # Set up processing and handle the cython: comments.
615     def visit_ModuleNode(self, node):
616         for key, value in node.directive_comments.items():
617             if not self.check_directive_scope(node.pos, key, 'module'):
618                 self.wrong_scope_error(node.pos, key, 'module')
619                 del node.directive_comments[key]
620
621         directives = copy.deepcopy(Options.directive_defaults)
622         directives.update(copy.deepcopy(self.compilation_directive_defaults))
623         directives.update(node.directive_comments)
624         self.directives = directives
625         node.directives = directives
626         self.visitchildren(node)
627         node.cython_module_names = self.cython_module_names
628         return node
629
630     # The following four functions track imports and cimports that
631     # begin with "cython"
632     def is_cython_directive(self, name):
633         return (name in Options.directive_types or
634                 name in self.special_methods or
635                 PyrexTypes.parse_basic_type(name))
636
637     def visit_CImportStatNode(self, node):
638         if node.module_name == u"cython":
639             self.cython_module_names.add(node.as_name or u"cython")
640         elif node.module_name.startswith(u"cython."):
641             if node.as_name:
642                 self.directive_names[node.as_name] = node.module_name[7:]
643             else:
644                 self.cython_module_names.add(u"cython")
645             # if this cimport was a compiler directive, we don't
646             # want to leave the cimport node sitting in the tree
647             return None
648         return node
649
650     def visit_FromCImportStatNode(self, node):
651         if (node.module_name == u"cython") or \
652                node.module_name.startswith(u"cython."):
653             submodule = (node.module_name + u".")[7:]
654             newimp = []
655             for pos, name, as_name, kind in node.imported_names:
656                 full_name = submodule + name
657                 if self.is_cython_directive(full_name):
658                     if as_name is None:
659                         as_name = full_name
660                     self.directive_names[as_name] = full_name
661                     if kind is not None:
662                         self.context.nonfatal_error(PostParseError(pos,
663                             "Compiler directive imports must be plain imports"))
664                 else:
665                     newimp.append((pos, name, as_name, kind))
666             if not newimp:
667                 return None
668             node.imported_names = newimp
669         return node
670
671     def visit_FromImportStatNode(self, node):
672         if (node.module.module_name.value == u"cython") or \
673                node.module.module_name.value.startswith(u"cython."):
674             submodule = (node.module.module_name.value + u".")[7:]
675             newimp = []
676             for name, name_node in node.items:
677                 full_name = submodule + name
678                 if self.is_cython_directive(full_name):
679                     self.directive_names[name_node.name] = full_name
680                 else:
681                     newimp.append((name, name_node))
682             if not newimp:
683                 return None
684             node.items = newimp
685         return node
686
687     def visit_SingleAssignmentNode(self, node):
688         if (isinstance(node.rhs, ExprNodes.ImportNode) and
689                 node.rhs.module_name.value == u'cython'):
690             node = Nodes.CImportStatNode(node.pos,
691                                          module_name = u'cython',
692                                          as_name = node.lhs.name)
693             self.visit_CImportStatNode(node)
694         else:
695             self.visitchildren(node)
696         return node
697
698     def visit_NameNode(self, node):
699         if node.name in self.cython_module_names:
700             node.is_cython_module = True
701         else:
702             node.cython_attribute = self.directive_names.get(node.name)
703         return node
704
    def try_to_parse_directives(self, node):
        # If node is the contents of a directive (in a with statement or
        # decorator), returns a list of (directivename, value) pairs.
        # Otherwise, returns None.
        #
        # Raises PostParseError if a directive that must be called with
        # arguments is used as a plain name or attribute.
        if isinstance(node, ExprNodes.CallNode):
            # directive used as a function call
            self.visit(node.function)
            optname = node.function.as_cython_attribute()
            if optname:
                directivetype = Options.directive_types.get(optname)
                if directivetype:
                    args, kwds = node.explicit_args_kwds()
                    directives = []
                    key_value_pairs = []
                    if kwds is not None and directivetype is not dict:
                        # Keyword arguments that name a known
                        # "directive.key" sub-directive are parsed as
                        # directives of their own; all others are kept
                        # as keyword arguments.
                        for keyvalue in kwds.key_value_pairs:
                            key, value = keyvalue
                            sub_optname = "%s.%s" % (optname, key.value)
                            if Options.directive_types.get(sub_optname):
                                directives.append(self.try_to_parse_directive(sub_optname, [value], None, keyvalue.pos))
                            else:
                                key_value_pairs.append(keyvalue)
                        if not key_value_pairs:
                            kwds = None
                        else:
                            # note: this modifies the kwds node in place
                            kwds.key_value_pairs = key_value_pairs
                        # only sub-directives were given, so there is
                        # nothing left to parse for the directive itself
                        if directives and not kwds and not args:
                            return directives
                    directives.append(self.try_to_parse_directive(optname, args, kwds, node.function.pos))
                    return directives
        elif isinstance(node, (ExprNodes.AttributeNode, ExprNodes.NameNode)):
            # directive used as a plain name or attribute, without a call
            self.visit(node)
            optname = node.as_cython_attribute()
            if optname:
                directivetype = Options.directive_types.get(optname)
                if directivetype is bool:
                    # a boolean directive gets switched on this way
                    return [(optname, True)]
                elif directivetype is None:
                    # no type is registered for this name
                    return [(optname, None)]
                else:
                    raise PostParseError(
                        node.pos, "The '%s' directive should be used as a function call." % optname)
        return None
747
748     def try_to_parse_directive(self, optname, args, kwds, pos):
749         directivetype = Options.directive_types.get(optname)
750         if len(args) == 1 and isinstance(args[0], ExprNodes.NoneNode):
751             return optname, Options.directive_defaults[optname]
752         elif directivetype is bool:
753             if kwds is not None or len(args) != 1 or not isinstance(args[0], ExprNodes.BoolNode):
754                 raise PostParseError(pos,
755                     'The %s directive takes one compile-time boolean argument' % optname)
756             return (optname, args[0].value)
757         elif directivetype is str:
758             if kwds is not None or len(args) != 1 or not isinstance(args[0], (ExprNodes.StringNode,
759                                                                               ExprNodes.UnicodeNode)):
760                 raise PostParseError(pos,
761                     'The %s directive takes one compile-time string argument' % optname)
762             return (optname, str(args[0].value))
763         elif directivetype is dict:
764             if len(args) != 0:
765                 raise PostParseError(pos,
766                     'The %s directive takes no prepositional arguments' % optname)
767             return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs])
768         elif directivetype is list:
769             if kwds and len(kwds) != 0:
770                 raise PostParseError(pos,
771                     'The %s directive takes no keyword arguments' % optname)
772             return optname, [ str(arg.value) for arg in args ]
773         else:
774             assert False
775
776     def visit_with_directives(self, body, directives):
777         olddirectives = self.directives
778         newdirectives = copy.copy(olddirectives)
779         newdirectives.update(directives)
780         self.directives = newdirectives
781         assert isinstance(body, Nodes.StatListNode), body
782         retbody = self.visit_Node(body)
783         directive = Nodes.CompilerDirectivesNode(pos=retbody.pos, body=retbody,
784                                                  directives=newdirectives)
785         self.directives = olddirectives
786         return directive
787
788     # Handle decorators
789     def visit_FuncDefNode(self, node):
790         directives = self._extract_directives(node, 'function')
791         if not directives:
792             return self.visit_Node(node)
793         body = Nodes.StatListNode(node.pos, stats=[node])
794         return self.visit_with_directives(body, directives)
795
796     def visit_CVarDefNode(self, node):
797         if not node.decorators:
798             return node
799         for dec in node.decorators:
800             for directive in self.try_to_parse_directives(dec.decorator) or ():
801                 if directive is not None and directive[0] == u'locals':
802                     node.directive_locals = directive[1]
803                 else:
804                     self.context.nonfatal_error(PostParseError(dec.pos,
805                         "Cdef functions can only take cython.locals() decorator."))
806         return node
807
808     def visit_CClassDefNode(self, node):
809         directives = self._extract_directives(node, 'cclass')
810         if not directives:
811             return self.visit_Node(node)
812         body = Nodes.StatListNode(node.pos, stats=[node])
813         return self.visit_with_directives(body, directives)
814
815     def visit_PyClassDefNode(self, node):
816         directives = self._extract_directives(node, 'class')
817         if not directives:
818             return self.visit_Node(node)
819         body = Nodes.StatListNode(node.pos, stats=[node])
820         return self.visit_with_directives(body, directives)
821
822     def _extract_directives(self, node, scope_name):
823         if not node.decorators:
824             return {}
825         # Split the decorators into two lists -- real decorators and directives
826         directives = []
827         realdecs = []
828         for dec in node.decorators:
829             new_directives = self.try_to_parse_directives(dec.decorator)
830             if new_directives is not None:
831                 for directive in new_directives:
832                     if self.check_directive_scope(node.pos, directive[0], scope_name):
833                         directives.append(directive)
834             else:
835                 realdecs.append(dec)
836         if realdecs and isinstance(node, (Nodes.CFuncDefNode, Nodes.CClassDefNode)):
837             raise PostParseError(realdecs[0].pos, "Cdef functions/classes cannot take arbitrary decorators.")
838         else:
839             node.decorators = realdecs
840         # merge or override repeated directives
841         optdict = {}
842         directives.reverse() # Decorators coming first take precedence
843         for directive in directives:
844             name, value = directive
845             if name in optdict:
846                 old_value = optdict[name]
847                 # keywords and arg lists can be merged, everything
848                 # else overrides completely
849                 if isinstance(old_value, dict):
850                     old_value.update(value)
851                 elif isinstance(old_value, list):
852                     old_value.extend(value)
853                 else:
854                     optdict[name] = value
855             else:
856                 optdict[name] = value
857         return optdict
858
859     # Handle with statements
860     def visit_WithStatNode(self, node):
861         directive_dict = {}
862         for directive in self.try_to_parse_directives(node.manager) or []:
863             if directive is not None:
864                 if node.target is not None:
865                     self.context.nonfatal_error(
866                         PostParseError(node.pos, "Compiler directive with statements cannot contain 'as'"))
867                 else:
868                     name, value = directive
869                     if name == 'nogil':
870                         # special case: in pure mode, "with nogil" spells "with cython.nogil"
871                         node = Nodes.GILStatNode(node.pos, state = "nogil", body = node.body)
872                         return self.visit_Node(node)
873                     if self.check_directive_scope(node.pos, name, 'with statement'):
874                         directive_dict[name] = value
875         if directive_dict:
876             return self.visit_with_directives(node.body, directive_dict)
877         return self.visit_Node(node)
878
class WithTransform(CythonTransform, SkipDeclarations):
    """Expand with statements into explicit __enter__/__exit__ handling.

    Each WithStatNode is replaced by one of the two code templates below,
    depending on whether the statement has an ``as`` target.
    """

    # EXCINFO is manually set to a variable that contains
    # the exc_info() tuple that can be generated by the enclosing except
    # statement.
    template_without_target = TreeFragment(u"""
        MGR = EXPR
        EXIT = MGR.__exit__
        MGR.__enter__()
        EXC = True
        try:
            try:
                EXCINFO = None
                BODY
            except:
                EXC = False
                if not EXIT(*EXCINFO):
                    raise
        finally:
            if EXC:
                EXIT(None, None, None)
    """, temps=[u'MGR', u'EXC', u"EXIT"],
    pipeline=[NormalizeTree(None)])

    template_with_target = TreeFragment(u"""
        MGR = EXPR
        EXIT = MGR.__exit__
        VALUE = MGR.__enter__()
        EXC = True
        try:
            try:
                EXCINFO = None
                TARGET = VALUE
                BODY
            except:
                EXC = False
                if not EXIT(*EXCINFO):
                    raise
        finally:
            if EXC:
                EXIT(None, None, None)
            MGR = EXIT = VALUE = EXC = None

    """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE"],
    pipeline=[NormalizeTree(None)])

    def visit_WithStatNode(self, node):
        """Replace a with statement by its expanded template.

        The body is transformed first, then substituted into the matching
        template together with the manager expression (and the target, if
        any).  A fresh ``__tmpvar_N`` name stands in for EXCINFO and is
        also installed as the excinfo target of the generated except
        clause.
        """
        # TODO: Cleanup badly needed
        TemplateTransform.temp_name_counter += 1
        excinfo_name = "__tmpvar_%d" % TemplateTransform.temp_name_counter

        self.visitchildren(node, ['body'])

        substitutions = {
            u'EXPR' : node.manager,
            u'BODY' : node.body,
            u'EXCINFO' : ExprNodes.NameNode(node.pos, name=excinfo_name),
        }
        if node.target is None:
            template = self.template_without_target
        else:
            template = self.template_with_target
            substitutions[u'TARGET'] = node.target
        expanded = template.substitute(substitutions, pos=node.pos)

        # Set except excinfo target to EXCINFO: the try/except is the last
        # statement inside the body of the trailing try/finally.
        try_except = expanded.stats[-1].body.stats[-1]
        try_except.except_clauses[0].excinfo_target = ExprNodes.NameNode(
            node.pos, name=excinfo_name)
        return expanded

    def visit_ExprNode(self, node):
        """Stop descending at expressions."""
        # With statements are never inside expressions.
        return node
959
960
class DecoratorTransform(CythonTransform, SkipDeclarations):
    """Rewrite decorated definitions as definition plus re-assignment.

    A decorated function or Python class is followed by an assignment
    that rebinds its name to the result of applying the decorators.
    """

    def visit_DefNode(self, func_node):
        """Expand the decorators of a def function, if it has any."""
        self.visitchildren(func_node)
        if func_node.decorators:
            return self._handle_decorators(func_node, func_node.name)
        return func_node

    def visit_CClassDefNode(self, class_node):
        """Reject decorators on cdef classes with a compile error."""
        # Decorating cdef classes doesn't currently work, so it's disabled.
        #
        # Problem: assignments to cdef class names do not work.  They
        # would require an additional check anyway, as the extension
        # type must not change its C type, so decorators cannot
        # replace an extension type, just alter it and return it.
        self.visitchildren(class_node)
        if class_node.decorators:
            error(class_node.pos,
                  "Decorators not allowed on cdef classes (used on type '%s')" % class_node.class_name)
        return class_node

    def visit_ClassDefNode(self, class_node):
        """Expand the decorators of a class, if it has any."""
        self.visitchildren(class_node)
        if class_node.decorators:
            return self._handle_decorators(class_node, class_node.name)
        return class_node

    def _handle_decorators(self, node, name):
        """Return ``[node, assignment]`` where the assignment rebinds
        ``name`` to the nested decorator calls applied to ``name``.
        """
        # Build the call chain inside-out: the decorator written last is
        # applied first, directly to the plain name.
        call_chain = ExprNodes.NameNode(node.pos, name=name)
        bottom_up = list(node.decorators)
        bottom_up.reverse()
        for dec in bottom_up:
            call_chain = ExprNodes.SimpleCallNode(
                dec.pos, function=dec.decorator, args=[call_chain])

        rebinding = Nodes.SingleAssignmentNode(
            node.pos,
            lhs=ExprNodes.NameNode(node.pos, name=name),
            rhs=call_chain)
        return [node, rebinding]
1008
1009
class AnalyseDeclarationsTransform(CythonTransform):
    """Run declaration analysis over the tree, scope by scope.

    ``env_stack`` holds the scope currently being analysed on top.
    ``seen_vars_stack`` holds, per scope, the names used so far; it is
    used to warn about cdef variables declared after their first use.
    Declaration nodes that are not needed afterwards are dropped from the
    tree by returning None from their visit methods.
    """

    basic_property = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    def __set__(self, value):
        ATTR = value
    """, level='c_class')
    basic_pyobject_property = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    def __set__(self, value):
        ATTR = value
    def __del__(self):
        ATTR = None
    """, level='c_class')
    basic_property_ro = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    """, level='c_class')

    struct_or_union_wrapper = TreeFragment(u"""
cdef class NAME:
    cdef TYPE value
    def __init__(self, MEMBER=None):
        cdef int count
        count = 0
        INIT_ASSIGNMENTS
        if IS_UNION and count > 1:
            raise ValueError, "At most one union member should be specified."
    def __str__(self):
        return STR_FORMAT % MEMBER_TUPLE
    def __repr__(self):
        return REPR_FORMAT % MEMBER_TUPLE
    """)

    init_assignment = TreeFragment(u"""
if VALUE is not None:
    ATTR = VALUE
    count += 1
    """)

    def __call__(self, root):
        """Reset the scope bookkeeping and transform the tree at ``root``."""
        self.env_stack = [root.scope]
        # needed to determine if a cdef var is declared after it's used.
        self.seen_vars_stack = []
        return super(AnalyseDeclarationsTransform, self).__call__(root)

    def visit_NameNode(self, node):
        """Record the name as used in the current scope."""
        self.seen_vars_stack[-1].add(node.name)
        return node

    def visit_ModuleNode(self, node):
        """Analyse the module's declarations, then descend into it."""
        self.seen_vars_stack.append(cython.set())
        node.analyse_declarations(self.env_stack[-1])
        self.visitchildren(node)
        self.seen_vars_stack.pop()
        return node

    def visit_LambdaNode(self, node):
        """Analyse a lambda's declarations in the current scope."""
        node.analyse_declarations(self.env_stack[-1])
        self.visitchildren(node)
        return node

    def visit_ClassDefNode(self, node):
        """Visit the class body with the class scope on the scope stack."""
        self.env_stack.append(node.scope)
        self.visitchildren(node)
        self.env_stack.pop()
        return node

    def visit_CClassDefNode(self, node):
        """Visit a cdef class and append properties for its attributes.

        For an implemented class, every variable entry flagged with
        ``needs_property`` gets a generated property node, which is
        analysed, visited and added to the class body.
        """
        node = self.visit_ClassDefNode(node)
        if node.scope and node.scope.implemented:
            stats = []
            for entry in node.scope.var_entries:
                if entry.needs_property:
                    prop = self.create_Property(entry)
                    prop.analyse_declarations(node.scope)
                    self.visit(prop)
                    stats.append(prop)
            if stats:
                node.body.stats += stats
        return node

    def visit_FuncDefNode(self, node):
        """Declare arguments and directive locals, then analyse the body."""
        self.seen_vars_stack.append(cython.set())
        lenv = node.local_scope
        node.body.analyse_control_flow(lenv) # this will be totally refactored
        node.declare_arguments(lenv)
        for var, type_node in node.directive_locals.items():
            if not lenv.lookup_here(var):   # don't redeclare args
                declared_type = type_node.analyse_as_type(lenv)
                if declared_type:
                    lenv.declare_var(var, declared_type, type_node.pos)
                else:
                    error(type_node.pos, "Not a type")
        node.body.analyse_declarations(lenv)
        self.env_stack.append(lenv)
        self.visitchildren(node)
        self.env_stack.pop()
        self.seen_vars_stack.pop()
        return node

    def visit_ScopedExprNode(self, node):
        """Analyse an expression that may bring its own local scope."""
        env = self.env_stack[-1]
        node.analyse_declarations(env)
        # the node may or may not have a local scope
        if node.has_local_scope:
            # Names seen in the enclosing scope remain visible inside.
            self.seen_vars_stack.append(cython.set(self.seen_vars_stack[-1]))
            self.env_stack.append(node.expr_scope)
            node.analyse_scoped_declarations(node.expr_scope)
            self.visitchildren(node)
            self.env_stack.pop()
            self.seen_vars_stack.pop()
        else:
            node.analyse_scoped_declarations(env)
            self.visitchildren(node)
        return node

    def visit_TempResultFromStatNode(self, node):
        """Visit the children first, then analyse the node itself."""
        self.visitchildren(node)
        node.analyse_declarations(self.env_stack[-1])
        return node

    def visit_CStructOrUnionDefNode(self, node):
        """Drop the struct/union definition node.

        The code after the early return would build a wrapper cdef class
        for the struct; it is currently disabled and never runs.
        """
        # Create a wrapper node if needed.
        # We want to use the struct type information (so it can't happen
        # before this phase) but also create new objects to be declared
        # (so it can't happen later).
        # Note that we don't return the original node, as it is
        # never used after this phase.
        if True: # private (default)
            return None

        self_value = ExprNodes.AttributeNode(
            pos = node.pos,
            obj = ExprNodes.NameNode(pos=node.pos, name=u"self"),
            attribute = EncodedString(u"value"))
        var_entries = node.entry.type.scope.var_entries
        attributes = []
        for entry in var_entries:
            attributes.append(ExprNodes.AttributeNode(pos = entry.pos,
                                                      obj = self_value,
                                                      attribute = entry.name))
        # __init__ assignments
        init_assignments = []
        for entry, attr in zip(var_entries, attributes):
            # TODO: branch on visibility
            init_assignments.append(self.init_assignment.substitute({
                    u"VALUE": ExprNodes.NameNode(entry.pos, name = entry.name),
                    u"ATTR": attr,
                }, pos = entry.pos))

        # create the class
        str_format = u"%s(%s)" % (node.entry.type.name, ("%s, " * len(attributes))[:-2])
        wrapper_class = self.struct_or_union_wrapper.substitute({
            u"INIT_ASSIGNMENTS": Nodes.StatListNode(node.pos, stats = init_assignments),
            u"IS_UNION": ExprNodes.BoolNode(node.pos, value = not node.entry.type.is_struct),
            u"MEMBER_TUPLE": ExprNodes.TupleNode(node.pos, args=attributes),
            u"STR_FORMAT": ExprNodes.StringNode(node.pos, value = EncodedString(str_format)),
            u"REPR_FORMAT": ExprNodes.StringNode(node.pos, value = EncodedString(str_format.replace("%s", "%r"))),
        }, pos = node.pos).stats[0]
        wrapper_class.class_name = node.name
        wrapper_class.shadow = True
        class_body = wrapper_class.body.stats

        # fix value type
        assert isinstance(class_body[0].base_type, Nodes.CSimpleBaseTypeNode)
        class_body[0].base_type.name = node.name

        # fix __init__ arguments
        init_method = class_body[1]
        assert isinstance(init_method, Nodes.DefNode) and init_method.name == '__init__'
        arg_template = init_method.args[1]
        if not node.entry.type.is_struct:
            arg_template.kw_only = True
        del init_method.args[1]
        for entry, attr in zip(var_entries, attributes):
            arg = copy.deepcopy(arg_template)
            arg.declarator.name = entry.name
            init_method.args.append(arg)

        # setters/getters
        for entry, attr in zip(var_entries, attributes):
            # TODO: branch on visibility
            if entry.type.is_pyobject:
                template = self.basic_pyobject_property
            else:
                template = self.basic_property
            prop = template.substitute({
                    u"ATTR": attr,
                }, pos = entry.pos).stats[0]
            prop.name = entry.name
            wrapper_class.body.stats.append(prop)

        wrapper_class.analyse_declarations(self.env_stack[-1])
        return self.visit_CClassDefNode(wrapper_class)

    # Some nodes are no longer needed after declaration
    # analysis and can be dropped. The analysis was performed
    # on these nodes in a separate recursive process from the
    # enclosing function or module, so we can simply drop them.
    def visit_CDeclaratorNode(self, node):
        """Keep the declarator and visit its children."""
        # necessary to ensure that all CNameDeclaratorNodes are visited.
        self.visitchildren(node)
        return node

    def visit_CTypeDefNode(self, node):
        """Keep ctypedef nodes untouched."""
        return node

    def visit_CBaseTypeNode(self, node):
        """Drop base type nodes."""
        return None

    def visit_CEnumDefNode(self, node):
        """Keep public enum definitions, drop all others."""
        if node.visibility == 'public':
            return node
        else:
            return None

    def visit_CNameDeclaratorNode(self, node):
        """Warn when a cdef variable is declared after it was used.

        Names whose entry has 'extern' visibility are exempt.
        """
        if node.name in self.seen_vars_stack[-1]:
            entry = self.env_stack[-1].lookup(node.name)
            if entry is None or entry.visibility != 'extern':
                warning(node.pos, "cdef variable '%s' declared after it is used" % node.name, 2)
        self.visitchildren(node)
        return node

    def visit_CVarDefNode(self, node):
        """Visit the declarators, then drop the cdef variable node."""
        # to ensure all CNameDeclaratorNodes are visited.
        self.visitchildren(node)
        return None

    def create_Property(self, entry):
        """Build a property node exposing the cdef attribute ``entry``.

        'public' attributes get a getter and setter (plus a deleter for
        Python object types); 'readonly' attributes get a getter only.
        Raises InternalError for any other visibility.  When docstrings
        and the ``embedsignature`` directive are enabled, the property
        also receives a "name: type [= default]" docstring.
        """
        if entry.visibility == 'public':
            if entry.type.is_pyobject:
                template = self.basic_pyobject_property
            else:
                template = self.basic_property
        elif entry.visibility == 'readonly':
            template = self.basic_property_ro
        else:
            # Without this branch ``template`` would be unbound below and
            # the failure would surface as an UnboundLocalError.
            raise InternalError(
                "Cannot create a property for attribute '%s' with visibility '%s'"
                % (entry.name, entry.visibility))
        prop = template.substitute({
                u"ATTR": ExprNodes.AttributeNode(pos=entry.pos,
                                                 obj=ExprNodes.NameNode(pos=entry.pos, name="self"),
                                                 attribute=entry.name),
            }, pos=entry.pos).stats[0]
        prop.name = entry.name
        # ---------------------------------------
        # XXX This should go to AutoDocTransforms
        # ---------------------------------------
        if (Options.docstrings and
            self.current_directives['embedsignature']):
            attr_name = entry.name
            type_name = entry.type.declaration_code("", for_display=1)
            default_value = ''
            if not entry.type.is_pyobject:
                type_name = "'%s'" % type_name
            elif entry.type.is_extension_type:
                type_name = entry.type.module_name + '.' + type_name
            if entry.init is not None:
                default_value = ' = ' + entry.init
            elif entry.init_to_none:
                default_value = ' = ' + repr(None)
            docstring = attr_name + ': ' + type_name + default_value
            prop.doc = EncodedString(docstring)
        # ---------------------------------------
        return prop
1279
class AnalyseExpressionsTransform(CythonTransform):
    """Infer types and analyse expressions, one scope at a time."""

    def visit_ModuleNode(self, node):
        """Handle the module scope, then descend into the module."""
        module_scope = node.scope
        module_scope.infer_types()
        node.body.analyse_expressions(module_scope)
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        """Handle a function's local scope, then descend into it."""
        local_scope = node.local_scope
        local_scope.infer_types()
        node.body.analyse_expressions(local_scope)
        self.visitchildren(node)
        return node

    def visit_ScopedExprNode(self, node):
        """Handle the expression's own scope, if it has one."""
        if node.has_local_scope:
            expr_scope = node.expr_scope
            expr_scope.infer_types()
            node.analyse_scoped_expressions(expr_scope)
        self.visitchildren(node)
        return node
1300
class ExpandInplaceOperators(EnvTransform):
    """Expand in-place assignments into a plain assignment of a binary
    operation, using LetRefNode temporaries for parts of the target.
    """

    def visit_InPlaceAssignmentNode(self, node):
        """Rewrite ``lhs op= rhs`` as ``lhs = lhs op rhs``.

        The node is returned unchanged when the target is a C++ class
        instance or a buffer access.  Otherwise the result is a
        SingleAssignmentNode, wrapped in one LetNode per temporary that
        was introduced for the target expression.
        """
        lhs = node.lhs
        rhs = node.rhs
        if lhs.type.is_cpp_class:
            # No getting around this exact operator here.
            return node
        if isinstance(lhs, ExprNodes.IndexNode) and lhs.is_buffer_access:
            # There is code to handle this case.
            return node

        env = self.current_env()

        def side_effect_free_reference(expr, setting=False):
            # Returns (reference, temps): an expression standing in for
            # ``expr`` plus the LetRefNodes it relies on.  Raises
            # ValueError when a buffer access is found along the way.
            if isinstance(expr, ExprNodes.NameNode):
                return expr, []
            elif expr.type.is_pyobject and not setting:
                ref = LetRefNode(expr)
                return ref, [ref]
            elif isinstance(expr, ExprNodes.IndexNode):
                if expr.is_buffer_access:
                    raise ValueError("Buffer access")
                base, temps = side_effect_free_reference(expr.base)
                index = LetRefNode(expr.index)
                return ExprNodes.IndexNode(expr.pos, base=base, index=index), temps + [index]
            elif isinstance(expr, ExprNodes.AttributeNode):
                obj, temps = side_effect_free_reference(expr.obj)
                return ExprNodes.AttributeNode(expr.pos, obj=obj, attribute=expr.attribute), temps
            else:
                ref = LetRefNode(expr)
                return ref, [ref]

        try:
            lhs, let_ref_nodes = side_effect_free_reference(lhs, setting=True)
        except ValueError:
            return node
        # Second node with the same attributes as the target, used as the
        # left operand of the binary operation.
        dup = lhs.__class__(**lhs.__dict__)
        binop = ExprNodes.binop_node(node.pos,
                                     operator = node.operator,
                                     operand1 = dup,
                                     operand2 = rhs,
                                     inplace=True)
        # Manually analyse types for new node.
        lhs.analyse_target_types(env)
        dup.analyse_types(env)
        binop.analyse_operation(env)
        node = Nodes.SingleAssignmentNode(
            node.pos,
            lhs = lhs,
            rhs=binop.coerce_to(lhs.type, env))
        # Use LetRefNode to avoid side effects.
        let_ref_nodes.reverse()
        for t in let_ref_nodes:
            node = LetNode(t, node)
        return node

    def visit_ExprNode(self, node):
        """Stop descending at expressions."""
        # In-place assignments can't happen within an expression.
        return node
1359
1360
class AlignFunctionDefinitions(CythonTransform):
    """
    This class takes the signatures from a .pxd file and applies them to
    the def methods in a .py file.
    """

    def visit_ModuleNode(self, node):
        """Remember the module scope and directives, then descend."""
        self.scope = node.scope
        self.directives = node.directives
        self.visitchildren(node)
        return node

    def visit_PyClassDefNode(self, node):
        """Turn a Python class into a cdef class if it is declared as one.

        A class without a declaration in the current scope is returned
        unchanged.  A name declared as something other than a cdef class
        is reported as an error and the node is dropped.
        """
        pxd_def = self.scope.lookup(node.name)
        if not pxd_def:
            return node
        if pxd_def.is_cclass:
            return self.visit_CClassDefNode(node.as_cclass(), pxd_def)
        error(node.pos, "'%s' redeclared" % node.name)
        error(pxd_def.pos, "previous declaration here")
        return None

    def visit_CClassDefNode(self, node, pxd_def=None):
        """Visit the class body inside the declared class scope, if any."""
        if pxd_def is None:
            pxd_def = self.scope.lookup(node.class_name)
        if not pxd_def:
            self.visitchildren(node)
            return node
        enclosing_scope = self.scope
        self.scope = pxd_def.type.scope
        self.visitchildren(node)
        self.scope = enclosing_scope
        return node

    def visit_DefNode(self, node):
        """Convert a def function to a C function where declared.

        A function declared as a C function in the current scope is
        converted with that declaration.  With the ``auto_cpdef``
        directive on, undeclared module-level functions are converted as
        well.  A name declared as something else is reported as an error
        and the node is dropped.
        """
        pxd_def = self.scope.lookup(node.name)
        if pxd_def:
            if pxd_def.is_cfunction:
                return node.as_cfunction(pxd_def)
            error(node.pos, "'%s' redeclared" % node.name)
            error(pxd_def.pos, "previous declaration here")
            return None
        if self.scope.is_module_scope and self.directives['auto_cpdef']:
            return node.as_cfunction(scope=self.scope)
        # Children are not visited here; enable
        #     self.visitchildren(node)
        # when internal def functions are allowed.
        return node
1409
1410
class MarkClosureVisitor(CythonTransform):
    """Set ``needs_closure`` on function and lambda nodes.

    ``self.needs_closure`` is cleared on entering a function or lambda
    and set again once a nested function, lambda or class has been
    visited, so each node ends up flagged when it contains one of those.
    """

    def _mark_needs_closure(self, node):
        # Shared by functions and lambdas: flag the node according to
        # what was found inside it, then signal the enclosing function.
        self.needs_closure = False
        self.visitchildren(node)
        node.needs_closure = self.needs_closure
        self.needs_closure = True
        return node

    def visit_ModuleNode(self, node):
        """Start with a cleared flag and descend into the module."""
        self.needs_closure = False
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        """Flag the function if it contains a nested scope."""
        return self._mark_needs_closure(node)

    def visit_CFuncDefNode(self, node):
        """Like visit_FuncDefNode, but closures are an error here."""
        self.visit_FuncDefNode(node)
        if node.needs_closure:
            error(node.pos, "closures inside cdef functions not yet supported")
        return node

    def visit_LambdaNode(self, node):
        """Flag the lambda if it contains a nested scope."""
        return self._mark_needs_closure(node)

    def visit_ClassDefNode(self, node):
        """A class inside a function makes that function need a closure."""
        self.visitchildren(node)
        self.needs_closure = True
        return node
1442
1443
class CreateClosureClasses(CythonTransform):
    """Output closure classes in module scope for all functions
    that really need it.

    For every function or lambda that is processed,
    ``create_class_from_scope()`` inspects the entries of its local scope
    and, where required, declares an internal C class in the module scope
    that holds the variables flagged ``in_closure``.
    """

    def __init__(self, context):
        super(CreateClosureClasses, self).__init__(context)
        # Stack of the function nodes whose bodies are currently being
        # visited by visit_FuncDefNode(); non-empty means "nested".
        self.path = []
        # True while the subtree of a LambdaNode is being visited.
        self.in_lambda = False

    def visit_ModuleNode(self, node):
        """Remember the module scope (closure classes are declared there)
        and visit the module's children.  Returns ``node``."""
        self.module_scope = node.scope
        self.visitchildren(node)
        return node

    def get_scope_use(self, node):
        """Classify the entries of ``node.local_scope``.

        Returns a pair of lists of ``(name, entry)`` tuples:
        ``from_closure`` holds the entries flagged ``from_closure`` and
        ``in_closure`` holds the remaining entries flagged ``in_closure``.
        """
        from_closure = []
        in_closure = []
        for name, entry in node.local_scope.entries.items():
            if entry.from_closure:
                from_closure.append((name, entry))
            elif entry.in_closure:
                # ``from_closure`` is known to be false in this branch.
                in_closure.append((name, entry))
        return from_closure, in_closure

    def create_class_from_scope(self, node, target_module_scope, inner_node=None):
        """Set up the closure bookkeeping for one function.

        node                 function node whose ``local_scope`` is examined
        target_module_scope  scope in which the closure class is declared
        inner_node           optional node whose ``needs_self_code`` flag is
                             cleared when the function uses nothing from an
                             outer closure; when omitted and the function is
                             nested, ``node.assmt.rhs`` is used instead

        Sets ``node.needs_closure`` and ``node.needs_outer_scope`` and, when
        the scope owns closure variables, declares a C class for them and
        stores its entry in ``node.local_scope.scope_class``.

        Raises InternalError if ``node.assmt`` is needed but not set.
        """
        from_closure, in_closure = self.get_scope_use(node)
        in_closure.sort()

        # Now from the beginning: reset both flags, they are switched back
        # on below where required.
        node.needs_closure = False
        node.needs_outer_scope = False

        func_scope = node.local_scope
        # Walk outwards past any class scopes surrounding the function.
        cscope = node.entry.scope
        while cscope.is_py_class_scope or cscope.is_c_class_scope:
            cscope = cscope.outer_scope

        if not from_closure and (self.path or inner_node):
            if not inner_node:
                if not node.assmt:
                    raise InternalError("DefNode does not have assignment node")
                inner_node = node.assmt.rhs
            inner_node.needs_self_code = False

        # Simple cases
        if not in_closure and not from_closure:
            return
        elif not in_closure:
            # Nothing of its own to store: reuse the scope class of the
            # enclosing scope and only pass the outer scope through.
            func_scope.is_passthrough = True
            func_scope.scope_class = cscope.scope_class
            node.needs_outer_scope = True
            return

        as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname)

        entry = target_module_scope.declare_c_class(name = as_name,
            pos = node.pos, defining = True, implementing = True)
        func_scope.scope_class = entry
        class_scope = entry.type.scope
        class_scope.is_internal = True
        class_scope.directives = {'final': True}

        if from_closure:
            assert cscope.is_closure_scope
            # Attribute that refers to the scope object of the enclosing
            # function.
            class_scope.declare_var(pos=node.pos,
                                    name=Naming.outer_scope_cname,
                                    cname=Naming.outer_scope_cname,
                                    type=cscope.scope_class.type,
                                    is_cdef=True)
            node.needs_outer_scope = True
        # One attribute per closure variable owned by this scope.  The loop
        # variable is deliberately not called ``entry`` so that the class
        # entry declared above is not shadowed.
        for name, var_entry in in_closure:
            class_scope.declare_var(pos=var_entry.pos,
                                    name=var_entry.name,
                                    cname=var_entry.cname,
                                    type=var_entry.type,
                                    is_cdef=True)
        node.needs_closure = True
        # Do it here because other classes are already checked
        target_module_scope.check_c_class(func_scope.scope_class)

    def visit_LambdaNode(self, node):
        """Create the closure class for the lambda's ``def_node`` and visit
        its children with ``in_lambda`` set.  Returns ``node``."""
        was_in_lambda = self.in_lambda
        self.in_lambda = True
        self.create_class_from_scope(node.def_node, self.module_scope, node)
        self.visitchildren(node)
        self.in_lambda = was_in_lambda
        return node

    def visit_FuncDefNode(self, node):
        """Create the closure class for a function definition if needed.

        Inside a lambda the children are only visited.  Otherwise the
        function is processed (and its children visited) only when it needs
        a closure or is nested in a function that was processed.
        Returns ``node``.
        """
        if self.in_lambda:
            self.visitchildren(node)
            return node
        if node.needs_closure or self.path:
            self.create_class_from_scope(node, self.module_scope)
            self.path.append(node)
            self.visitchildren(node)
            self.path.pop()
        return node
1542
1543
class GilCheck(VisitorTransform):
    """
    Call `node.nogil_check(env)` on each node to make sure we hold the
    GIL when we need it.  The check is only invoked while inside a
    `nogil` environment and only on nodes whose `nogil_check` attribute
    is set; reporting an error for an operation that needs the GIL is
    left to that method.
    """
    def __call__(self, root):
        """Reset the traversal state and run the transform on ``root``."""
        # Stack of scopes; the innermost one is passed to nogil_check().
        self.env_stack = [root.scope]
        # True while the code being visited runs without the GIL.
        self.nogil = False
        return super(GilCheck, self).__call__(root)

    def visit_FuncDefNode(self, node):
        """Visit a function with its own scope pushed and ``self.nogil``
        taken from ``node.local_scope.nogil``; the previous state is
        restored afterwards.  Returns ``node``."""
        self.env_stack.append(node.local_scope)
        was_nogil = self.nogil
        self.nogil = node.local_scope.nogil
        if self.nogil and node.nogil_check:
            node.nogil_check(node.local_scope)
        self.visitchildren(node)
        self.env_stack.pop()
        self.nogil = was_nogil
        return node

    def visit_GILStatNode(self, node):
        """Visit a GIL statement: the statement itself is checked against
        the surrounding state, its children are visited with ``self.nogil``
        reflecting ``node.state``.  Returns ``node``."""
        env = self.env_stack[-1]
        if self.nogil and node.nogil_check:
            # Pass the current scope, like every other nogil_check() call
            # in this transform does.
            node.nogil_check(env)
        was_nogil = self.nogil
        self.nogil = (node.state == 'nogil')
        self.visitchildren(node)
        self.nogil = was_nogil
        return node

    def visit_Node(self, node):
        """Check any other node against the innermost scope while in a
        nogil section, then visit its children.  Returns ``node``."""
        if self.env_stack and self.nogil and node.nogil_check:
            node.nogil_check(self.env_stack[-1])
        self.visitchildren(node)
        return node
1580
1581
class TransformBuiltinMethods(EnvTransform):
    """Replace special names and calls with dedicated tree nodes.

    Handles ``cython.*`` attributes (``compiled``, ``NULL``, ``set``,
    ``frozenset`` and basic type names), ``cython.*`` function calls
    (operators listed in InterpretCompilerDirectives, ``cast``, ``sizeof``,
    ``cmod``, ``cdiv``, ``set``) and calls to the builtin ``locals()``.
    """

    def visit_SingleAssignmentNode(self, node):
        """Drop assignments that are declarations only (returns None);
        visit the children of all others and return ``node``."""
        if node.declaration_only:
            return None
        else:
            self.visitchildren(node)
            return node

    def visit_AttributeNode(self, node):
        """Visit the children, then resolve a possible cython attribute."""
        self.visitchildren(node)
        return self.visit_cython_attribute(node)

    def visit_NameNode(self, node):
        """Resolve a possible cython attribute given as a plain name."""
        return self.visit_cython_attribute(node)

    def visit_cython_attribute(self, node):
        """Return the replacement for ``node`` if it is a cython attribute.

        ``node`` is returned unchanged when ``as_cython_attribute()`` yields
        nothing or names a basic type.  An error is reported for any other
        unknown attribute name.
        """
        attribute = node.as_cython_attribute()
        if attribute:
            if attribute == u'compiled':
                node = ExprNodes.BoolNode(node.pos, value=True)
            elif attribute == u'NULL':
                node = ExprNodes.NullNode(node.pos)
            elif attribute in (u'set', u'frozenset'):
                node = ExprNodes.NameNode(node.pos, name=EncodedString(attribute),
                                          entry=self.current_env().builtin_scope().lookup_here(attribute))
            elif not PyrexTypes.parse_basic_type(attribute):
                error(node.pos, u"'%s' not a valid cython attribute or is being used incorrectly" % attribute)
        return node

    def visit_SimpleCallNode(self, node):
        """Rewrite calls to ``locals()`` and to ``cython.*`` functions.

        A call to the builtin ``locals()`` becomes a DictNode mapping the
        names of the current scope's entries to NameNodes.  Recognised
        ``cython.*`` calls are replaced by the matching operator, typecast
        or sizeof node.  Wrong argument counts and unknown ``cython.*``
        functions are reported via ``error()`` and the call is left as is.
        Returns the (possibly replaced) node after visiting its children;
        the ``locals()`` branch returns early without visiting children.
        """

        # locals builtin
        if isinstance(node.function, ExprNodes.NameNode):
            if node.function.name == 'locals':
                lenv = self.current_env()
                entry = lenv.lookup_here('locals')
                if entry:
                    # not the builtin 'locals'
                    return node
                if len(node.args) > 0:
                    # Report at the call's position; the transform itself
                    # is not given a ``pos`` anywhere in view.
                    error(node.pos, "Builtin 'locals()' called with wrong number of args, expected 0, got %d" % len(node.args))
                    return node
                pos = node.pos
                items = [ ExprNodes.DictItemNode(pos,
                                                 key=ExprNodes.StringNode(pos, value=var),
                                                 value=ExprNodes.NameNode(pos, name=var))
                          for var in lenv.entries ]
                return ExprNodes.DictNode(pos, key_value_pairs=items)

        # cython.foo
        function = node.function.as_cython_attribute()
        if function:
            if function in InterpretCompilerDirectives.unop_method_nodes:
                if len(node.args) != 1:
                    error(node.function.pos, u"%s() takes exactly one argument" % function)
                else:
                    node = InterpretCompilerDirectives.unop_method_nodes[function](node.function.pos, operand=node.args[0])
            elif function in InterpretCompilerDirectives.binop_method_nodes:
                if len(node.args) != 2:
                    error(node.function.pos, u"%s() takes exactly two arguments" % function)
                else:
                    node = InterpretCompilerDirectives.binop_method_nodes[function](node.function.pos, operand1=node.args[0], operand2=node.args[1])
            elif function == u'cast':
                if len(node.args) != 2:
                    error(node.function.pos, u"cast() takes exactly two arguments")
                else:
                    target_type = node.args[0].analyse_as_type(self.current_env())
                    if target_type:
                        node = ExprNodes.TypecastNode(node.function.pos, type=target_type, operand=node.args[1])
                    else:
                        error(node.args[0].pos, "Not a type")
            elif function == u'sizeof':
                if len(node.args) != 1:
                    error(node.function.pos, u"sizeof() takes exactly one argument")
                else:
                    # The argument may be a type or an ordinary expression.
                    arg_type = node.args[0].analyse_as_type(self.current_env())
                    if arg_type:
                        node = ExprNodes.SizeofTypeNode(node.function.pos, arg_type=arg_type)
                    else:
                        node = ExprNodes.SizeofVarNode(node.function.pos, operand=node.args[0])
            elif function == 'cmod':
                if len(node.args) != 2:
                    error(node.function.pos, u"cmod() takes exactly two arguments")
                else:
                    node = ExprNodes.binop_node(node.function.pos, '%', node.args[0], node.args[1])
                    node.cdivision = True
            elif function == 'cdiv':
                if len(node.args) != 2:
                    error(node.function.pos, u"cdiv() takes exactly two arguments")
                else:
                    node = ExprNodes.binop_node(node.function.pos, '/', node.args[0], node.args[1])
                    node.cdivision = True
            elif function == u'set':
                node.function = ExprNodes.NameNode(node.pos, name=EncodedString('set'))
            else:
                error(node.function.pos, u"'%s' not a valid cython language construct" % function)

        self.visitchildren(node)
        return node
1682
1683
class DebugTransform(CythonTransform):
    """
    Write debug information about the module, its functions and its
    variables to the context's gdb debug output writer in order to
    enable debugging.

    Changing the functions' visibility to extern is currently disabled
    (see the commented-out line in visit_FuncDefNode).
    """

    def __init__(self, context, options, result):
        super(DebugTransform, self).__init__(context)
        # qualified names of all functions seen by visit_FuncDefNode
        self.visited = cython.set()
        # our treebuilder and debug output writer
        # (see Cython.Debugger.debug_output.CythonDebugWriter)
        self.tb = self.context.gdb_debug_outputwriter
        #self.c_output_file = options.output_file
        self.c_output_file = result.c_file

        # Closure support, basically treat nested functions as if the AST were
        # never nested
        self.nested_funcdefs = []

        # tells visit_NameNode whether it should register step-into functions
        self.register_stepinto = False

    def visit_ModuleNode(self, node):
        """Serialize the whole module: its functions (including nested
        ones and the module-level code as a pseudo function) followed by
        its global variables.  Returns ``node``.

        The 'Module' element is started here but intentionally left open.
        """
        self.tb.module_name = node.full_module_name
        attrs = dict(
            module_name=node.full_module_name,
            filename=node.pos[0].filename,
            c_filename=self.c_output_file)

        self.tb.start('Module', attrs)

        # serialize functions
        self.tb.start('Functions')
        # First, serialize functions normally...
        self.visitchildren(node)

        # ... then, serialize nested functions
        for nested_funcdef in self.nested_funcdefs:
            self.visit_FuncDefNode(nested_funcdef)

        self.register_stepinto = True
        self.serialize_modulenode_as_function(node)
        self.register_stepinto = False
        self.tb.end('Functions')

        # 2.3 compatibility. Serialize global variables
        self.tb.start('Globals')
        entries = {}

        # items() rather than iteritems(): available on Python 2 and 3
        # dicts alike, and already used for scope entries in this file.
        for k, v in node.scope.entries.items():
            if (v.qualified_name not in self.visited and not
                v.name.startswith('__pyx_') and not
                v.type.is_cfunction and not
                v.type.is_extension_type):
                entries[k] = v

        self.serialize_local_variables(entries)
        self.tb.end('Globals')
        # self.tb.end('Module') # end Module after the line number mapping in
        # Cython.Compiler.ModuleNode.ModuleNode._serialize_lineno_map
        return node

    def visit_FuncDefNode(self, node):
        """Serialize one function as a 'Function' element with its locals,
        arguments and step-into functions.  Returns ``node``.

        Wrapper functions are skipped.  A function met while step-into
        registration is active is only queued in ``nested_funcdefs``.
        """
        self.visited.add(node.local_scope.qualified_name)

        if getattr(node, 'is_wrapper', False):
            return node

        if self.register_stepinto:
            self.nested_funcdefs.append(node)
            return node

        # node.entry.visibility = 'extern'
        if node.py_func is None:
            pf_cname = ''
        else:
            pf_cname = node.py_func.entry.func_cname

        attrs = dict(
            name=node.entry.name,
            cname=node.entry.func_cname,
            pf_cname=pf_cname,
            qualified_name=node.local_scope.qualified_name,
            lineno=str(node.pos[1]))

        self.tb.start('Function', attrs=attrs)

        self.tb.start('Locals')
        self.serialize_local_variables(node.local_scope.entries)
        self.tb.end('Locals')

        self.tb.start('Arguments')
        for arg in node.local_scope.arg_entries:
            self.tb.start(arg.name)
            self.tb.end(arg.name)
        self.tb.end('Arguments')

        # While the body is visited, visit_NameNode() emits one
        # 'StepIntoFunction' element per called C function.
        self.tb.start('StepIntoFunctions')
        self.register_stepinto = True
        self.visitchildren(node)
        self.register_stepinto = False
        self.tb.end('StepIntoFunctions')
        self.tb.end('Function')

        return node

    def visit_NameNode(self, node):
        """Emit a 'StepIntoFunction' element for a called C function name
        while step-into registration is active.  Returns ``node``."""
        if (self.register_stepinto and
            node.type.is_cfunction and
            getattr(node, 'is_called', False) and
            node.entry.func_cname is not None):
            # don't check node.entry.in_cinclude, as 'cdef extern: ...'
            # declared functions are not 'in_cinclude'.
            # This means we will list called 'cdef' functions as
            # "step into functions", but this is not an issue as they will be
            # recognized as Cython functions anyway.
            attrs = dict(name=node.entry.func_cname)
            self.tb.start('StepIntoFunction', attrs=attrs)
            self.tb.end('StepIntoFunction')

        self.visitchildren(node)
        return node

    def serialize_modulenode_as_function(self, node):
        """
        Serialize the module-level code as a function so the debugger will know
        it's a "relevant frame" and it will know where to set the breakpoint
        for 'break modulename'.

        Two 'Function' elements are written, one with the cname
        'init<name>' and one with 'PyInit_<name>'.
        """
        name = node.full_module_name.rpartition('.')[-1]

        cname_py2 = 'init' + name
        cname_py3 = 'PyInit_' + name

        py2_attrs = dict(
            name=name,
            cname=cname_py2,
            pf_cname='',
            # Ignore the qualified_name, breakpoints should be set using
            # `cy break modulename:lineno` for module-level breakpoints.
            qualified_name='',
            lineno='1',
            is_initmodule_function="True",
        )

        # Same attributes, only the cname differs.
        py3_attrs = dict(py2_attrs, cname=cname_py3)

        self._serialize_modulenode_as_function(node, py2_attrs)
        self._serialize_modulenode_as_function(node, py3_attrs)

    def _serialize_modulenode_as_function(self, node, attrs):
        """Write one 'Function' element with the given ``attrs`` for the
        module-level code: the module scope's entries as locals, no
        arguments, and the step-into functions found in the module body."""
        self.tb.start('Function', attrs=attrs)

        self.tb.start('Locals')
        self.serialize_local_variables(node.scope.entries)
        self.tb.end('Locals')

        self.tb.start('Arguments')
        self.tb.end('Arguments')

        self.tb.start('StepIntoFunctions')
        self.register_stepinto = True
        self.visitchildren(node)
        self.register_stepinto = False
        self.tb.end('StepIntoFunctions')

        self.tb.end('Function')

    def serialize_local_variables(self, entries):
        """Write one 'LocalVar' element per entry in the ``entries`` mapping,
        with its name, cname, qualified name, type kind ('PythonObject' or
        'CObject') and line number."""
        for entry in entries.values():
            if entry.type.is_pyobject:
                vartype = 'PythonObject'
            else:
                vartype = 'CObject'

            if entry.from_closure:
                # We're dealing with a closure where a variable from an outer
                # scope is accessed, get it from the scope object.
                cname = '%s->%s' % (Naming.cur_scope_cname,
                                    entry.outer_entry.cname)

                qname = '%s.%s.%s' % (entry.scope.outer_scope.qualified_name,
                                      entry.scope.name,
                                      entry.name)
            elif entry.in_closure:
                cname = '%s->%s' % (Naming.cur_scope_cname,
                                    entry.cname)
                qname = entry.qualified_name
            else:
                cname = entry.cname
                qname = entry.qualified_name

            if not entry.pos:
                # this happens for variables that are not in the user's code,
                # e.g. for the global __builtins__, __doc__, etc. We can just
                # set the lineno to 0 for those.
                lineno = '0'
            else:
                lineno = str(entry.pos[1])

            attrs = dict(
                name=entry.name,
                cname=cname,
                qualified_name=qname,
                type=vartype,
                lineno=lineno)

            self.tb.start('LocalVar', attrs)
            self.tb.end('LocalVar')
1893