More complete code writing.
[cython.git] / Cython / Compiler / ParseTreeTransforms.py
1
2 import cython
3 cython.declare(PyrexTypes=object, Naming=object, ExprNodes=object, Nodes=object,
4                Options=object, UtilNodes=object, ModuleNode=object,
5                LetNode=object, LetRefNode=object, TreeFragment=object,
6                TemplateTransform=object, EncodedString=object,
7                error=object, warning=object, copy=object)
8
9 import PyrexTypes
10 import Naming
11 import ExprNodes
12 import Nodes
13 import Options
14
15 from Cython.Compiler.Visitor import VisitorTransform, TreeVisitor
16 from Cython.Compiler.Visitor import CythonTransform, EnvTransform, ScopeTrackingTransform
17 from Cython.Compiler.ModuleNode import ModuleNode
18 from Cython.Compiler.UtilNodes import LetNode, LetRefNode
19 from Cython.Compiler.TreeFragment import TreeFragment, TemplateTransform
20 from Cython.Compiler.StringEncoding import EncodedString
21 from Cython.Compiler.Errors import error, warning, CompileError, InternalError
22
23 import copy
24
25
class NameNodeCollector(TreeVisitor):
    """Tree visitor that gathers the NameNodes of a (sub-)tree.

    After visiting, the ``name_nodes`` attribute holds every NameNode
    encountered, in the order the visitor reached them.  The children
    of a NameNode are not visited.
    """

    def __init__(self):
        super(NameNodeCollector, self).__init__()
        self.name_nodes = []

    def visit_Node(self, node):
        # Generic fallback: keep walking down into the children.
        self._visitchildren(node, None)

    def visit_NameNode(self, node):
        # Record the name node; do not descend any further.
        self.name_nodes.append(node)
39
40
class SkipDeclarations(object):
    """
    Mixin that keeps a transform from descending into declarations.

    Variable and function declarations can often have a deep tree
    structure, and yet most transformations don't need to descend to
    this depth, so each declaration node listed below is handed back
    unchanged without visiting its children.

    Declaration nodes are removed after AnalyseDeclarationsTransform, so
    there is no need to use this for transformations after that point.
    """

    def _return_declaration_unvisited(self, node):
        # Return the node as-is; its subtree is deliberately not visited.
        return node

    visit_CTypeDefNode = _return_declaration_unvisited
    visit_CVarDefNode = _return_declaration_unvisited
    visit_CDeclaratorNode = _return_declaration_unvisited
    visit_CBaseTypeNode = _return_declaration_unvisited
    visit_CEnumDefNode = _return_declaration_unvisited
    visit_CStructOrUnionDefNode = _return_declaration_unvisited
66
class NormalizeTree(CythonTransform):
    """
    This transform fixes up a few things after parsing
    in order to make the parse tree more suitable for
    transforms.

    a) After parsing, blocks with only one statement will
    be represented by that statement, not by a StatListNode.
    When doing transforms this is annoying and inconsistent,
    as one cannot in general remove a statement in a consistent
    way and so on. This transform wraps any single statements
    in a StatListNode containing a single statement.

    b) The PassStatNode is a noop and serves no purpose beyond
    plugging such one-statement blocks; i.e., once parsed a
    "pass" can just as well be represented using an empty
    StatListNode. This means fewer special cases to worry about
    in subsequent transforms (one always checks to see if a
    StatListNode has no children to see if the block is empty).
    """

    def __init__(self, context):
        super(NormalizeTree, self).__init__(context)
        # True while visiting the direct children of a StatListNode (or of
        # a node treated as a list container via visit_StatNode).
        self.is_in_statlist = False
        # True while visiting anything below an expression node.
        self.is_in_expr = False

    def visit_ExprNode(self, node):
        stacktmp = self.is_in_expr
        self.is_in_expr = True
        self.visitchildren(node)
        self.is_in_expr = stacktmp
        return node

    def visit_StatNode(self, node, is_listcontainer=False):
        # is_listcontainer=True makes the children of *node* count as
        # statement-list members, so they are not wrapped individually.
        stacktmp = self.is_in_statlist
        self.is_in_statlist = is_listcontainer
        self.visitchildren(node)
        self.is_in_statlist = stacktmp
        if not self.is_in_statlist and not self.is_in_expr:
            return Nodes.StatListNode(pos=node.pos, stats=[node])
        else:
            return node

    def visit_StatListNode(self, node):
        # Save and restore the flag, as visit_StatNode and visit_ExprNode
        # do.  Resetting it to False here would make a StatListNode that
        # is nested directly in another statement list cause its
        # following siblings to be wrapped (and their "pass" statements
        # to be kept) needlessly.
        stacktmp = self.is_in_statlist
        self.is_in_statlist = True
        self.visitchildren(node)
        self.is_in_statlist = stacktmp
        return node

    def visit_ParallelAssignmentNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CEnumDefNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CStructOrUnionDefNode(self, node):
        return self.visit_StatNode(node, True)

    # Eliminate PassStatNode
    def visit_PassStatNode(self, node):
        # Inside a statement list the "pass" simply disappears; elsewhere
        # it becomes an empty block.
        if not self.is_in_statlist:
            return Nodes.StatListNode(pos=node.pos, stats=[])
        else:
            return []

    def visit_CDeclaratorNode(self, node):
        # Declarators are returned without visiting their children.
        return node
134
135
class PostParseError(CompileError):
    """Compile error reported by the post-parse transforms of this module."""

# The unit tests check these error strings verbatim, so they are kept
# as named module-level constants.
ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions'
ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)'
ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared'
142 class PostParse(ScopeTrackingTransform):
143     """
144     Basic interpretation of the parse tree, as well as validity
145     checking that can be done on a very basic level on the parse
146     tree (while still not being a problem with the basic syntax,
147     as such).
148
149     Specifically:
150     - Default values to cdef assignments are turned into single
151     assignments following the declaration (everywhere but in class
152     bodies, where they raise a compile error)
153
154     - Interpret some node structures into Python runtime values.
155     Some nodes take compile-time arguments (currently:
156     TemplatedTypeNode[args] and __cythonbufferdefaults__ = {args}),
157     which should be interpreted. This happens in a general way
158     and other steps should be taken to ensure validity.
159
160     Type arguments cannot be interpreted in this way.
161
162     - For __cythonbufferdefaults__ the arguments are checked for
163     validity.
164
165     TemplatedTypeNode has its directives interpreted:
166     Any first positional argument goes into the "dtype" attribute,
167     any "ndim" keyword argument goes into the "ndim" attribute and
168     so on. Also it is checked that the directive combination is valid.
169     - __cythonbufferdefaults__ attributes are parsed and put into the
170     type information.
171
172     Note: Currently Parsing.py does a lot of interpretation and
173     reorganization that can be refactored into this transform
174     if a more pure Abstract Syntax Tree is wanted.
175     """
176
177     def __init__(self, context):
178         super(PostParse, self).__init__(context)
179         self.specialattribute_handlers = {
180             '__cythonbufferdefaults__' : self.handle_bufferdefaults
181         }
182
183     def visit_ModuleNode(self, node):
184         self.lambda_counter = 1
185         return super(PostParse, self).visit_ModuleNode(node)
186
187     def visit_LambdaNode(self, node):
188         # unpack a lambda expression into the corresponding DefNode
189         lambda_id = self.lambda_counter
190         self.lambda_counter += 1
191         node.lambda_name = EncodedString(u'lambda%d' % lambda_id)
192
193         body = Nodes.ReturnStatNode(
194             node.result_expr.pos, value = node.result_expr)
195         node.def_node = Nodes.DefNode(
196             node.pos, name=node.name, lambda_name=node.lambda_name,
197             args=node.args, star_arg=node.star_arg,
198             starstar_arg=node.starstar_arg,
199             body=body)
200         self.visitchildren(node)
201         return node
202
203     # cdef variables
204     def handle_bufferdefaults(self, decl):
205         if not isinstance(decl.default, ExprNodes.DictNode):
206             raise PostParseError(decl.pos, ERR_BUF_DEFAULTS)
207         self.scope_node.buffer_defaults_node = decl.default
208         self.scope_node.buffer_defaults_pos = decl.pos
209
210     def visit_CVarDefNode(self, node):
211         # This assumes only plain names and pointers are assignable on
212         # declaration. Also, it makes use of the fact that a cdef decl
213         # must appear before the first use, so we don't have to deal with
214         # "i = 3; cdef int i = i" and can simply move the nodes around.
215         try:
216             self.visitchildren(node)
217             stats = [node]
218             newdecls = []
219             for decl in node.declarators:
220                 declbase = decl
221                 while isinstance(declbase, Nodes.CPtrDeclaratorNode):
222                     declbase = declbase.base
223                 if isinstance(declbase, Nodes.CNameDeclaratorNode):
224                     if declbase.default is not None:
225                         if self.scope_type in ('cclass', 'pyclass', 'struct'):
226                             if isinstance(self.scope_node, Nodes.CClassDefNode):
227                                 handler = self.specialattribute_handlers.get(decl.name)
228                                 if handler:
229                                     if decl is not declbase:
230                                         raise PostParseError(decl.pos, ERR_INVALID_SPECIALATTR_TYPE)
231                                     handler(decl)
232                                     continue # Remove declaration
233                             raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
234                         first_assignment = self.scope_type != 'module'
235                         stats.append(Nodes.SingleAssignmentNode(node.pos,
236                             lhs=ExprNodes.NameNode(node.pos, name=declbase.name),
237                             rhs=declbase.default, first=first_assignment))
238                         declbase.default = None
239                 newdecls.append(decl)
240             node.declarators = newdecls
241             return stats
242         except PostParseError, e:
243             # An error in a cdef clause is ok, simply remove the declaration
244             # and try to move on to report more errors
245             self.context.nonfatal_error(e)
246             return None
247
248     # Split parallel assignments (a,b = b,a) into separate partial
249     # assignments that are executed rhs-first using temps.  This
250     # restructuring must be applied before type analysis so that known
251     # types on rhs and lhs can be matched directly.  It is required in
252     # the case that the types cannot be coerced to a Python type in
253     # order to assign from a tuple.
254
255     def visit_SingleAssignmentNode(self, node):
256         self.visitchildren(node)
257         return self._visit_assignment_node(node, [node.lhs, node.rhs])
258
259     def visit_CascadedAssignmentNode(self, node):
260         self.visitchildren(node)
261         return self._visit_assignment_node(node, node.lhs_list + [node.rhs])
262
263     def _visit_assignment_node(self, node, expr_list):
264         """Flatten parallel assignments into separate single
265         assignments or cascaded assignments.
266         """
267         if sum([ 1 for expr in expr_list if expr.is_sequence_constructor ]) < 2:
268             # no parallel assignments => nothing to do
269             return node
270
271         expr_list_list = []
272         flatten_parallel_assignments(expr_list, expr_list_list)
273         temp_refs = []
274         eliminate_rhs_duplicates(expr_list_list, temp_refs)
275
276         nodes = []
277         for expr_list in expr_list_list:
278             lhs_list = expr_list[:-1]
279             rhs = expr_list[-1]
280             if len(lhs_list) == 1:
281                 node = Nodes.SingleAssignmentNode(rhs.pos,
282                     lhs = lhs_list[0], rhs = rhs)
283             else:
284                 node = Nodes.CascadedAssignmentNode(rhs.pos,
285                     lhs_list = lhs_list, rhs = rhs)
286             nodes.append(node)
287
288         if len(nodes) == 1:
289             assign_node = nodes[0]
290         else:
291             assign_node = Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)
292
293         if temp_refs:
294             duplicates_and_temps = [ (temp.expression, temp)
295                                      for temp in temp_refs ]
296             sort_common_subsequences(duplicates_and_temps)
297             for _, temp_ref in duplicates_and_temps[::-1]:
298                 assign_node = LetNode(temp_ref, assign_node)
299
300         return assign_node
301
def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
    """Replace rhs items by LetRefNodes if they appear more than once.

    Each entry of ``expr_list_list`` is a list whose last element is an
    rhs expression.  Every node (other than literals and plain names)
    that is reachable more than once from these rhs expressions, looking
    through nested sequence constructors, gets one LetRefNode.  The
    LetRefNodes are appended to ``ref_node_sequence`` in the order the
    duplicates are discovered.  All occurrences, including those nested
    inside other duplicated sequences, are replaced in place.
    """
    already_seen = cython.set()
    replacements = {}

    def collect(expr):
        if expr.is_literal or expr.is_name:
            # no need to replace those; can't include attributes here
            # as their access is not necessarily side-effect free
            return
        if expr not in already_seen:
            already_seen.add(expr)
            if expr.is_sequence_constructor:
                for child in expr.args:
                    collect(child)
        elif expr not in replacements:
            let_ref = LetRefNode(expr)
            replacements[expr] = let_ref
            ref_node_sequence.append(let_ref)

    for exprs in expr_list_list:
        collect(exprs[-1])
    if not replacements:
        return

    def substitute(expr):
        if expr in replacements:
            return replacements[expr]
        if expr.is_sequence_constructor:
            expr.args = [substitute(child) for child in expr.args]
        return expr

    # replace nodes inside of the common subexpressions
    for duplicate in replacements:
        if duplicate.is_sequence_constructor:
            duplicate.args = [substitute(child) for child in duplicate.args]

    # replace common subexpressions on all rhs items
    for exprs in expr_list_list:
        exprs[-1] = substitute(exprs[-1])
347
def sort_common_subsequences(items):
    """Sort items/subsequences so that all items and subsequences that
    an item contains appear before the item itself.  This is needed
    because each rhs item must only be evaluated once, so its value
    must be evaluated first and then reused when packing sequences
    that contain it.

    ``items`` is a list of ``(expression, ref_node)`` pairs and is sorted
    in place.  An item depends on another item if its expression contains
    the other item's ref node (directly or inside nested sequence
    constructors).

    This implies a partial order, and the sort must be stable to
    preserve the original order as much as possible, so we use a
    simple insertion sort (which is very fast for short sequences, the
    normal case in practice).  When an item is moved in front of the
    first item that contains it, the items it depends on itself are
    moved along with it so that they still precede it.
    """
    def contains(seq, x):
        for item in seq:
            if item is x:
                return True
            elif item.is_sequence_constructor and contains(item.args, x):
                return True
        return False

    def lower_than(a, b):
        # True if expression b contains a, i.e. a must be set up before b.
        return b.is_sequence_constructor and contains(b.args, a)

    def needed_by(candidate, group):
        # True if any item in group contains the ref node of candidate.
        for member in group:
            if lower_than(candidate[1], member[0]):
                return True
        return False

    for pos, item in enumerate(items):
        # the ref node (created by eliminate_rhs_duplicates) which has
        # already been injected into the sequences
        key = item[1]
        new_pos = pos
        for i in range(pos-1, -1, -1):
            if lower_than(key, items[i][0]):
                new_pos = i
        if new_pos == pos:
            continue
        # items[:pos] is already in a valid order.  Split items[new_pos:pos]
        # into the items that 'item' depends on (transitively), which must
        # stay in front of it, and the rest, which go behind it.  Scanning
        # backwards finds all transitive dependencies, because within the
        # valid prefix every dependency precedes its dependents.
        group = [item]
        dependencies = []
        others = []
        for i in range(pos-1, new_pos-1, -1):
            candidate = items[i]
            if needed_by(candidate, group):
                group.append(candidate)
                dependencies.append(candidate)
            else:
                others.append(candidate)
        dependencies.reverse()
        others.reverse()
        # Same length as the replaced slice, so enumerate() is unaffected.
        items[new_pos:pos+1] = dependencies + [item] + others
380
def flatten_parallel_assignments(input, output):
    """Split one (possibly cascaded) assignment into element assignments.

    ``input`` is a list of expression nodes: the LHSs followed by the
    RHS of one assignment statement.  For sequence constructors, the
    matching parts of both sides are rearranged into a list of
    equivalent assignments between the individual elements.  This is
    applied recursively, so that nested structures get matched as well.
    Each resulting assignment is appended to ``output`` as a list of
    lhs nodes followed by the rhs node.  Size mismatches are reported
    via error() and the offending lhs/rhs pair is emitted unsplit.
    """
    rhs = input[-1]
    if not rhs.is_sequence_constructor:
        output.append(input)
        return
    has_sequence_target = False
    for candidate in input[:-1]:
        if candidate.is_sequence_constructor:
            has_sequence_target = True
            break
    if not has_sequence_target:
        output.append(input)
        return

    rhs_args = rhs.args
    rhs_size = len(rhs_args)
    # one target list per rhs item
    lhs_targets = [[] for slot_index in range(rhs_size)]
    complete_assignments = []
    starred_assignments = []

    for lhs in input[:-1]:
        if not lhs.is_sequence_constructor:
            if lhs.is_starred:
                error(lhs.pos, "starred assignment target must be in a list or tuple")
            # a plain target receives the whole rhs
            complete_assignments.append(lhs)
            continue

        lhs_args = lhs.args
        lhs_size = len(lhs_args)
        star_count = len([arg for arg in lhs_args if arg.is_starred])

        message = None
        if star_count > 1:
            message = "more than 1 starred expression in assignment"
        elif lhs_size - star_count > rhs_size:
            if rhs_size == 1:
                plural = ''
            else:
                plural = 's'
            message = "need more than %d value%s to unpack" % (rhs_size, plural)
        elif not star_count and lhs_size < rhs_size:
            message = "too many values to unpack (expected %d, got %d)" % (lhs_size, rhs_size)

        if message is not None:
            error(lhs.pos, message)
            output.append([lhs, rhs])
        elif star_count:
            map_starred_assignment(lhs_targets, starred_assignments,
                                   lhs_args, rhs_args)
        else:
            for slot, expr in zip(lhs_targets, lhs_args):
                slot.append(expr)

    if complete_assignments:
        complete_assignments.append(rhs)
        output.append(complete_assignments)

    # recursively flatten partial assignments
    for cascade, value in zip(lhs_targets, rhs_args):
        if cascade:
            cascade.append(value)
            flatten_parallel_assignments(cascade, output)

    # recursively flatten starred assignments
    for cascade in starred_assignments:
        if cascade[0].is_sequence_constructor:
            flatten_parallel_assignments(cascade, output)
        else:
            output.append(cascade)
443
def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args):
    """Distribute the targets of a starred sequence assignment.

    The fixed-position LHS targets that appear left and right of the
    starred argument are appended to the target list of the rhs item
    they match (``lhs_targets`` holds one list per entry of
    ``rhs_args``).

    The ``starred_assignments`` list receives a new list
    ``[lhs_target, ListNode]`` that maps the remaining rhs values (those
    that match the starred target) to a (potentially empty) list.

    flatten_parallel_assignments only calls this with exactly one
    starred target and at most ``len(rhs_args)`` non-starred targets.
    Raises InternalError if no starred target is present.
    """
    # Locate the starred target by scanning lhs_args alone.  Zipping with
    # lhs_targets (one slot per rhs value) would stop too early when the
    # starred target is last and matches no value, e.g. "a, *b = 1,".
    for i, expr in enumerate(lhs_args):
        if expr.is_starred:
            starred = i
            break
    else:
        raise InternalError("no starred arg found when splitting starred assignment")
    lhs_remaining = len(lhs_args) - starred - 1

    # left side of the starred target
    for targets, expr in zip(lhs_targets[:starred], lhs_args[:starred]):
        targets.append(expr)

    # right side of the starred target
    if lhs_remaining:
        for targets, expr in zip(lhs_targets[-lhs_remaining:],
                                 lhs_args[starred + 1:]):
            targets.append(expr)

    # the starred target itself, must be assigned a (potentially empty) list
    target = lhs_args[starred].target # unpack starred node
    starred_rhs = rhs_args[starred:]
    if lhs_remaining:
        starred_rhs = starred_rhs[:-lhs_remaining]
    if starred_rhs:
        pos = starred_rhs[0].pos
    else:
        pos = target.pos
    starred_assignments.append([
        target, ExprNodes.ListNode(pos=pos, args=starred_rhs)])
478
479
class PxdPostParse(CythonTransform, SkipDeclarations):
    """
    Basic interpretation/validity checking that should only be
    done on pxd trees.

    A lot of this checking currently happens in the parser; but
    what is listed below happens here.

    - "def" functions are let through only if they fill the
    getbuffer/releasebuffer slots

    - cdef functions are let through only if they are on the
    top level and are declared "inline"

    Rejected function definitions are reported as non-fatal errors
    and removed from the tree.
    """
    ERR_INLINE_ONLY = "function definition in pxd file must be declared 'cdef inline'"
    ERR_NOGO_WITH_INLINE = "inline function definition in pxd file cannot be '%s'"

    def __call__(self, node):
        # Everything outside a cdef class body is at pxd top level.
        self.scope_type = 'pxd'
        return super(PxdPostParse, self).__call__(node)

    def visit_CClassDefNode(self, node):
        # Visit the class body in 'cclass' scope, then restore the scope.
        enclosing_scope = self.scope_type
        self.scope_type = 'cclass'
        self.visitchildren(node)
        self.scope_type = enclosing_scope
        return node

    def visit_FuncDefNode(self, node):
        # FuncDefNode always come with an implementation (without
        # an imp they are CVarDefNodes..)
        message = self._pxd_function_error(node)
        if message is None:
            return node
        self.context.nonfatal_error(PostParseError(node.pos, message))
        return None

    def _pxd_function_error(self, node):
        """Return the error message for a function implementation in a
        pxd file, or None if *node* is allowed there.

        A cdef inline function at pxd top level gets ``inline_in_pxd``
        set, whether or not it is then rejected for its visibility or
        api flag.
        """
        if isinstance(node, Nodes.CFuncDefNode):
            if u'inline' not in node.modifiers or self.scope_type != 'pxd':
                return self.ERR_INLINE_ONLY
            node.inline_in_pxd = True
            if node.visibility != 'private':
                return self.ERR_NOGO_WITH_INLINE % node.visibility
            if node.api:
                return self.ERR_NOGO_WITH_INLINE % 'api'
            return None # allow inline function
        if (isinstance(node, Nodes.DefNode) and self.scope_type == 'cclass'
                and node.name in ('__getbuffer__', '__releasebuffer__')):
            return None # allow these slots
        return self.ERR_INLINE_ONLY
534
535 class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
536     """
537     After parsing, directives can be stored in a number of places:
538     - #cython-comments at the top of the file (stored in ModuleNode)
539     - Command-line arguments overriding these
540     - @cython.directivename decorators
541     - with cython.directivename: statements
542
543     This transform is responsible for interpreting these various sources
544     and store the directive in two ways:
545     - Set the directives attribute of the ModuleNode for global directives.
546     - Use a CompilerDirectivesNode to override directives for a subtree.
547
548     (The first one is primarily to not have to modify with the tree
549     structure, so that ModuleNode stay on top.)
550
551     The directives are stored in dictionaries from name to value in effect.
552     Each such dictionary is always filled in for all possible directives,
553     using default values where no value is given by the user.
554
555     The available directives are controlled in Options.py.
556
557     Note that we have to run this prior to analysis, and so some minor
558     duplication of functionality has to occur: We manually track cimports
559     and which names the "cython" module may have been imported to.
560     """
561     unop_method_nodes = {
562         'typeof': ExprNodes.TypeofNode,
563
564         'operator.address': ExprNodes.AmpersandNode,
565         'operator.dereference': ExprNodes.DereferenceNode,
566         'operator.preincrement' : ExprNodes.inc_dec_constructor(True, '++'),
567         'operator.predecrement' : ExprNodes.inc_dec_constructor(True, '--'),
568         'operator.postincrement': ExprNodes.inc_dec_constructor(False, '++'),
569         'operator.postdecrement': ExprNodes.inc_dec_constructor(False, '--'),
570
571         # For backwards compatability.
572         'address': ExprNodes.AmpersandNode,
573     }
574
575     binop_method_nodes = {
576         'operator.comma'        : ExprNodes.c_binop_constructor(','),
577     }
578
579     special_methods = cython.set(['declare', 'union', 'struct', 'typedef', 'sizeof',
580                                   'cast', 'pointer', 'compiled', 'NULL'])
581     special_methods.update(unop_method_nodes.keys())
582
583     def __init__(self, context, compilation_directive_defaults):
584         super(InterpretCompilerDirectives, self).__init__(context)
585         self.compilation_directive_defaults = {}
586         for key, value in compilation_directive_defaults.items():
587             self.compilation_directive_defaults[unicode(key)] = value
588         self.cython_module_names = cython.set()
589         self.directive_names = {}
590
591     def check_directive_scope(self, pos, directive, scope):
592         legal_scopes = Options.directive_scopes.get(directive, None)
593         if legal_scopes and scope not in legal_scopes:
594             self.context.nonfatal_error(PostParseError(pos, 'The %s compiler directive '
595                                         'is not allowed in %s scope' % (directive, scope)))
596             return False
597         else:
598             return True
599
600     # Set up processing and handle the cython: comments.
601     def visit_ModuleNode(self, node):
602         for key, value in node.directive_comments.items():
603             if not self.check_directive_scope(node.pos, key, 'module'):
604                 self.wrong_scope_error(node.pos, key, 'module')
605                 del node.directive_comments[key]
606
607         directives = copy.copy(Options.directive_defaults)
608         directives.update(self.compilation_directive_defaults)
609         directives.update(node.directive_comments)
610         self.directives = directives
611         node.directives = directives
612         self.visitchildren(node)
613         node.cython_module_names = self.cython_module_names
614         return node
615
616     # The following four functions track imports and cimports that
617     # begin with "cython"
618     def is_cython_directive(self, name):
619         return (name in Options.directive_types or
620                 name in self.special_methods or
621                 PyrexTypes.parse_basic_type(name))
622
623     def visit_CImportStatNode(self, node):
624         if node.module_name == u"cython":
625             self.cython_module_names.add(node.as_name or u"cython")
626         elif node.module_name.startswith(u"cython."):
627             if node.as_name:
628                 self.directive_names[node.as_name] = node.module_name[7:]
629             else:
630                 self.cython_module_names.add(u"cython")
631             # if this cimport was a compiler directive, we don't
632             # want to leave the cimport node sitting in the tree
633             return None
634         return node
635
636     def visit_FromCImportStatNode(self, node):
637         if (node.module_name == u"cython") or \
638                node.module_name.startswith(u"cython."):
639             submodule = (node.module_name + u".")[7:]
640             newimp = []
641             for pos, name, as_name, kind in node.imported_names:
642                 full_name = submodule + name
643                 if self.is_cython_directive(full_name):
644                     if as_name is None:
645                         as_name = full_name
646                     self.directive_names[as_name] = full_name
647                     if kind is not None:
648                         self.context.nonfatal_error(PostParseError(pos,
649                             "Compiler directive imports must be plain imports"))
650                 else:
651                     newimp.append((pos, name, as_name, kind))
652             if not newimp:
653                 return None
654             node.imported_names = newimp
655         return node
656
657     def visit_FromImportStatNode(self, node):
658         if (node.module.module_name.value == u"cython") or \
659                node.module.module_name.value.startswith(u"cython."):
660             submodule = (node.module.module_name.value + u".")[7:]
661             newimp = []
662             for name, name_node in node.items:
663                 full_name = submodule + name
664                 if self.is_cython_directive(full_name):
665                     self.directive_names[name_node.name] = full_name
666                 else:
667                     newimp.append((name, name_node))
668             if not newimp:
669                 return None
670             node.items = newimp
671         return node
672
673     def visit_SingleAssignmentNode(self, node):
674         if (isinstance(node.rhs, ExprNodes.ImportNode) and
675                 node.rhs.module_name.value == u'cython'):
676             node = Nodes.CImportStatNode(node.pos,
677                                          module_name = u'cython',
678                                          as_name = node.lhs.name)
679             self.visit_CImportStatNode(node)
680         else:
681             self.visitchildren(node)
682         return node
683
684     def visit_NameNode(self, node):
685         if node.name in self.cython_module_names:
686             node.is_cython_module = True
687         else:
688             node.cython_attribute = self.directive_names.get(node.name)
689         return node
690
691     def try_to_parse_directives(self, node):
692         # If node is the contents of an directive (in a with statement or
693         # decorator), returns a list of (directivename, value) pairs.
694         # Otherwise, returns None
695         if isinstance(node, ExprNodes.CallNode):
696             self.visit(node.function)
697             optname = node.function.as_cython_attribute()
698             if optname:
699                 directivetype = Options.directive_types.get(optname)
700                 if directivetype:
701                     args, kwds = node.explicit_args_kwds()
702                     directives = []
703                     key_value_pairs = []
704                     if kwds is not None and directivetype is not dict:
705                         for keyvalue in kwds.key_value_pairs:
706                             key, value = keyvalue
707                             sub_optname = "%s.%s" % (optname, key.value)
708                             if Options.directive_types.get(sub_optname):
709                                 directives.append(self.try_to_parse_directive(sub_optname, [value], None, keyvalue.pos))
710                             else:
711                                 key_value_pairs.append(keyvalue)
712                         if not key_value_pairs:
713                             kwds = None
714                         else:
715                             kwds.key_value_pairs = key_value_pairs
716                         if directives and not kwds and not args:
717                             return directives
718                     directives.append(self.try_to_parse_directive(optname, args, kwds, node.function.pos))
719                     return directives
720         elif isinstance(node, (ExprNodes.AttributeNode, ExprNodes.NameNode)):
721             self.visit(node)
722             optname = node.as_cython_attribute()
723             if optname:
724                 directivetype = Options.directive_types.get(optname)
725                 if directivetype is bool:
726                     return [(optname, True)]
727                 elif directivetype is None:
728                     return [(optname, None)]
729                 else:
730                     raise PostParseError(
731                         node.pos, "The '%s' directive should be used as a function call." % optname)
732         return None
733
734     def try_to_parse_directive(self, optname, args, kwds, pos):
735         directivetype = Options.directive_types.get(optname)
736         if len(args) == 1 and isinstance(args[0], ExprNodes.NoneNode):
737             return optname, Options.directive_defaults[optname]
738         elif directivetype is bool:
739             if kwds is not None or len(args) != 1 or not isinstance(args[0], ExprNodes.BoolNode):
740                 raise PostParseError(pos,
741                     'The %s directive takes one compile-time boolean argument' % optname)
742             return (optname, args[0].value)
743         elif directivetype is str:
744             if kwds is not None or len(args) != 1 or not isinstance(args[0], (ExprNodes.StringNode,
745                                                                               ExprNodes.UnicodeNode)):
746                 raise PostParseError(pos,
747                     'The %s directive takes one compile-time string argument' % optname)
748             return (optname, str(args[0].value))
749         elif directivetype is dict:
750             if len(args) != 0:
751                 raise PostParseError(pos,
752                     'The %s directive takes no prepositional arguments' % optname)
753             return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs])
754         elif directivetype is list:
755             if kwds and len(kwds) != 0:
756                 raise PostParseError(pos,
757                     'The %s directive takes no keyword arguments' % optname)
758             return optname, [ str(arg.value) for arg in args ]
759         else:
760             assert False
761
762     def visit_with_directives(self, body, directives):
763         olddirectives = self.directives
764         newdirectives = copy.copy(olddirectives)
765         newdirectives.update(directives)
766         self.directives = newdirectives
767         assert isinstance(body, Nodes.StatListNode), body
768         retbody = self.visit_Node(body)
769         directive = Nodes.CompilerDirectivesNode(pos=retbody.pos, body=retbody,
770                                                  directives=newdirectives)
771         self.directives = olddirectives
772         return directive
773
774     # Handle decorators
775     def visit_FuncDefNode(self, node):
776         directives = self._extract_directives(node, 'function')
777         if not directives:
778             return self.visit_Node(node)
779         body = Nodes.StatListNode(node.pos, stats=[node])
780         return self.visit_with_directives(body, directives)
781
782     def visit_CVarDefNode(self, node):
783         if not node.decorators:
784             return node
785         for dec in node.decorators:
786             for directive in self.try_to_parse_directives(dec.decorator) or ():
787                 if directive is not None and directive[0] == u'locals':
788                     node.directive_locals = directive[1]
789                 else:
790                     self.context.nonfatal_error(PostParseError(dec.pos,
791                         "Cdef functions can only take cython.locals() decorator."))
792         return node
793
794     def visit_CClassDefNode(self, node):
795         directives = self._extract_directives(node, 'cclass')
796         if not directives:
797             return self.visit_Node(node)
798         body = Nodes.StatListNode(node.pos, stats=[node])
799         return self.visit_with_directives(body, directives)
800
801     def visit_PyClassDefNode(self, node):
802         directives = self._extract_directives(node, 'class')
803         if not directives:
804             return self.visit_Node(node)
805         body = Nodes.StatListNode(node.pos, stats=[node])
806         return self.visit_with_directives(body, directives)
807
808     def _extract_directives(self, node, scope_name):
809         if not node.decorators:
810             return {}
811         # Split the decorators into two lists -- real decorators and directives
812         directives = []
813         realdecs = []
814         for dec in node.decorators:
815             new_directives = self.try_to_parse_directives(dec.decorator)
816             if new_directives is not None:
817                 for directive in new_directives:
818                     if self.check_directive_scope(node.pos, directive[0], scope_name):
819                         directives.append(directive)
820             else:
821                 realdecs.append(dec)
822         if realdecs and isinstance(node, (Nodes.CFuncDefNode, Nodes.CClassDefNode)):
823             raise PostParseError(realdecs[0].pos, "Cdef functions/classes cannot take arbitrary decorators.")
824         else:
825             node.decorators = realdecs
826         # merge or override repeated directives
827         optdict = {}
828         directives.reverse() # Decorators coming first take precedence
829         for directive in directives:
830             name, value = directive
831             if name in optdict:
832                 old_value = optdict[name]
833                 # keywords and arg lists can be merged, everything
834                 # else overrides completely
835                 if isinstance(old_value, dict):
836                     old_value.update(value)
837                 elif isinstance(old_value, list):
838                     old_value.extend(value)
839                 else:
840                     optdict[name] = value
841             else:
842                 optdict[name] = value
843         return optdict
844
845     # Handle with statements
846     def visit_WithStatNode(self, node):
847         directive_dict = {}
848         for directive in self.try_to_parse_directives(node.manager) or []:
849             if directive is not None:
850                 if node.target is not None:
851                     self.context.nonfatal_error(
852                         PostParseError(node.pos, "Compiler directive with statements cannot contain 'as'"))
853                 else:
854                     name, value = directive
855                     if name == 'nogil':
856                         # special case: in pure mode, "with nogil" spells "with cython.nogil"
857                         node = Nodes.GILStatNode(node.pos, state = "nogil", body = node.body)
858                         return self.visit_Node(node)
859                     if self.check_directive_scope(node.pos, name, 'with statement'):
860                         directive_dict[name] = value
861         if directive_dict:
862             return self.visit_with_directives(node.body, directive_dict)
863         return self.visit_Node(node)
864
class WithTransform(CythonTransform, SkipDeclarations):
    """Expand ``with`` statements into explicit try/except/finally code.

    The manager's ``__exit__`` is fetched before ``__enter__`` is called.
    On an exception, ``__exit__`` is called with the exception info and the
    exception is re-raised unless it returns a true value.  Otherwise
    ``__exit__(None, None, None)`` runs in the ``finally`` clause.
    """

    # EXCINFO is manually set to a variable that contains
    # the exc_info() tuple that can be generated by the enclosing except
    # statement.
    # NOTE(review): unlike template_with_target, this template does not
    # reset its temps to None in the finally clause.
    template_without_target = TreeFragment(u"""
        MGR = EXPR
        EXIT = MGR.__exit__
        MGR.__enter__()
        EXC = True
        try:
            try:
                EXCINFO = None
                BODY
            except:
                EXC = False
                if not EXIT(*EXCINFO):
                    raise
        finally:
            if EXC:
                EXIT(None, None, None)
    """, temps=[u'MGR', u'EXC', u"EXIT"],
    pipeline=[NormalizeTree(None)])

    template_with_target = TreeFragment(u"""
        MGR = EXPR
        EXIT = MGR.__exit__
        VALUE = MGR.__enter__()
        EXC = True
        try:
            try:
                EXCINFO = None
                TARGET = VALUE
                BODY
            except:
                EXC = False
                if not EXIT(*EXCINFO):
                    raise
        finally:
            if EXC:
                EXIT(None, None, None)
            MGR = EXIT = VALUE = EXC = None

    """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE"],
    pipeline=[NormalizeTree(None)])

    def visit_WithStatNode(self, node):
        # TODO: Cleanup badly needed
        # A unique variable name receives the exc_info() tuple of the except
        # clause generated from the template.
        TemplateTransform.temp_name_counter += 1
        handle = "__tmpvar_%d" % TemplateTransform.temp_name_counter

        self.visitchildren(node, ['body'])
        substitutions = {
            u'EXPR' : node.manager,
            u'BODY' : node.body,
            u'EXCINFO' : ExprNodes.NameNode(node.pos, name=handle),
            }
        if node.target is None:
            template = self.template_without_target
        else:
            template = self.template_with_target
            substitutions[u'TARGET'] = node.target
        result = template.substitute(substitutions, pos=node.pos)

        # In both templates the last statement is the outer try/finally, and
        # the last statement of its body is the inner try/except.  Its single
        # except clause must bind the exception info to the handle variable.
        try_except = result.stats[-1].body.stats[-1]
        try_except.except_clauses[0].excinfo_target = ExprNodes.NameNode(node.pos, name=handle)
        return result

    def visit_ExprNode(self, node):
        # With statements are never inside expressions.
        return node
945
946
class DecoratorTransform(CythonTransform, SkipDeclarations):
    """Rewrite decorated definitions into explicit decorator calls.

    A decorated ``def``/class named ``f`` with decorators ``d1, d2`` (in
    source order) is replaced by the definition itself followed by the
    assignment ``f = d1(d2(f))``.  Decorators on cdef classes are reported
    as errors.
    """

    def visit_DefNode(self, func_node):
        self.visitchildren(func_node)
        if func_node.decorators:
            return self._handle_decorators(func_node, func_node.name)
        return func_node

    def visit_CClassDefNode(self, class_node):
        # Decorating cdef classes doesn't currently work, so it's disabled.
        #
        # Problem: assignments to cdef class names do not work.  They
        # would require an additional check anyway, as the extension
        # type must not change its C type, so decorators cannot
        # replace an extension type, just alter it and return it.
        self.visitchildren(class_node)
        if class_node.decorators:
            error(class_node.pos,
                  "Decorators not allowed on cdef classes (used on type '%s')" % class_node.class_name)
        return class_node

    def visit_ClassDefNode(self, class_node):
        self.visitchildren(class_node)
        if class_node.decorators:
            return self._handle_decorators(class_node, class_node.name)
        return class_node

    def _handle_decorators(self, node, name):
        # Build the nested calls from the innermost (last listed) decorator
        # outwards, then rebind the name to the result.
        wrapped = ExprNodes.NameNode(node.pos, name = name)
        for dec in node.decorators[::-1]:
            wrapped = ExprNodes.SimpleCallNode(
                dec.pos,
                function = dec.decorator,
                args = [wrapped])
        reassignment = Nodes.SingleAssignmentNode(
            node.pos,
            lhs = ExprNodes.NameNode(node.pos, name = name),
            rhs = wrapped)
        return [node, reassignment]
994
995
class AnalyseDeclarationsTransform(CythonTransform):
    """Run declaration analysis and prune declaration-only nodes.

    Module, lambda, function, class and scoped-expression nodes get their
    ``analyse_declarations`` step here.  C declaration nodes that are no
    longer needed are removed, because their visitors return None.

    For implemented cdef classes, property nodes are synthesised for
    attributes flagged with ``needs_property``.  The transform also warns
    when a cdef variable is declared after a name with the same spelling
    was already used in the current scope.
    """

    # Read/write property template; ATTR is substituted with an attribute
    # access node (``self.<name>`` in create_Property).
    basic_property = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    def __set__(self, value):
        ATTR = value
    """, level='c_class')
    # Like basic_property, used for Python-object attributes; it adds a
    # __del__ that resets the attribute to None.
    basic_pyobject_property = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    def __set__(self, value):
        ATTR = value
    def __del__(self):
        ATTR = None
    """, level='c_class')
    # Getter-only property template, used for 'readonly' attributes.
    basic_property_ro = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    """, level='c_class')

    # Template for a cdef class wrapping a C struct/union value.  Only used
    # by visit_CStructOrUnionDefNode, where wrapper generation is currently
    # disabled.
    struct_or_union_wrapper = TreeFragment(u"""
cdef class NAME:
    cdef TYPE value
    def __init__(self, MEMBER=None):
        cdef int count
        count = 0
        INIT_ASSIGNMENTS
        if IS_UNION and count > 1:
            raise ValueError, "At most one union member should be specified."
    def __str__(self):
        return STR_FORMAT % MEMBER_TUPLE
    def __repr__(self):
        return REPR_FORMAT % MEMBER_TUPLE
    """)

    # Per-member __init__ assignment for struct_or_union_wrapper; ``count``
    # tracks how many members were passed (checked for unions).
    init_assignment = TreeFragment(u"""
if VALUE is not None:
    ATTR = VALUE
    count += 1
    """)

    def __call__(self, root):
        # Stack of scopes being analysed, innermost last.
        self.env_stack = [root.scope]
        # needed to determine if a cdef var is declared after it's used.
        # Holds one set of used names per scope, innermost last.
        self.seen_vars_stack = []
        return super(AnalyseDeclarationsTransform, self).__call__(root)

    def visit_NameNode(self, node):
        # Record every name used in the current scope; checked later by
        # visit_CNameDeclaratorNode.
        self.seen_vars_stack[-1].add(node.name)
        return node

    def visit_ModuleNode(self, node):
        self.seen_vars_stack.append(cython.set())
        node.analyse_declarations(self.env_stack[-1])
        self.visitchildren(node)
        self.seen_vars_stack.pop()
        return node

    def visit_LambdaNode(self, node):
        # NOTE(review): unlike visit_FuncDefNode, no new env / seen-vars
        # entry is pushed here; children are visited in the current scope.
        node.analyse_declarations(self.env_stack[-1])
        self.visitchildren(node)
        return node

    def visit_ClassDefNode(self, node):
        # Visit the class body with the class scope as the current env.
        self.env_stack.append(node.scope)
        self.visitchildren(node)
        self.env_stack.pop()
        return node

    def visit_CClassDefNode(self, node):
        node = self.visit_ClassDefNode(node)
        # For an implemented cdef class, append a property node to the class
        # body for every attribute entry flagged with needs_property.
        if node.scope and node.scope.implemented:
            stats = []
            for entry in node.scope.var_entries:
                if entry.needs_property:
                    property = self.create_Property(entry)
                    property.analyse_declarations(node.scope)
                    self.visit(property)
                    stats.append(property)
            if stats:
                node.body.stats += stats
        return node

    def visit_FuncDefNode(self, node):
        self.seen_vars_stack.append(cython.set())
        lenv = node.local_scope
        node.body.analyse_control_flow(lenv) # this will be totally refactored
        node.declare_arguments(lenv)
        # Declare variables typed through directive_locals (presumably filled
        # from the cython.locals() directive -- confirm).  Names already
        # declared in the local scope (e.g. arguments) are skipped.
        for var, type_node in node.directive_locals.items():
            if not lenv.lookup_here(var):   # don't redeclare args
                type = type_node.analyse_as_type(lenv)
                if type:
                    lenv.declare_var(var, type, type_node.pos)
                else:
                    error(type_node.pos, "Not a type")
        node.body.analyse_declarations(lenv)
        self.env_stack.append(lenv)
        self.visitchildren(node)
        self.env_stack.pop()
        self.seen_vars_stack.pop()
        return node

    def visit_ScopedExprNode(self, node):
        env = self.env_stack[-1]
        node.analyse_declarations(env)
        # the node may or may not have a local scope
        if node.has_local_scope:
            # The inner scope starts with a copy of the names already seen in
            # the enclosing scope.
            self.seen_vars_stack.append(cython.set(self.seen_vars_stack[-1]))
            self.env_stack.append(node.expr_scope)
            node.analyse_scoped_declarations(node.expr_scope)
            self.visitchildren(node)
            self.env_stack.pop()
            self.seen_vars_stack.pop()
        else:
            node.analyse_scoped_declarations(env)
            self.visitchildren(node)
        return node

    def visit_TempResultFromStatNode(self, node):
        # Children are visited before this node's own declarations are
        # analysed.
        self.visitchildren(node)
        node.analyse_declarations(self.env_stack[-1])
        return node

    def visit_CStructOrUnionDefNode(self, node):
        # Create a wrapper node if needed.
        # We want to use the struct type information (so it can't happen
        # before this phase) but also create new objects to be declared
        # (so it can't happen later).
        # Note that we don't return the original node, as it is
        # never used after this phase.
        # NOTE: wrapper generation is currently disabled -- the "if True"
        # below always drops the node and the rest of this method is
        # unreachable.
        if True: # private (default)
            return None

        self_value = ExprNodes.AttributeNode(
            pos = node.pos,
            obj = ExprNodes.NameNode(pos=node.pos, name=u"self"),
            attribute = EncodedString(u"value"))
        var_entries = node.entry.type.scope.var_entries
        attributes = []
        for entry in var_entries:
            attributes.append(ExprNodes.AttributeNode(pos = entry.pos,
                                                      obj = self_value,
                                                      attribute = entry.name))
        # __init__ assignments
        init_assignments = []
        for entry, attr in zip(var_entries, attributes):
            # TODO: branch on visibility
            init_assignments.append(self.init_assignment.substitute({
                    u"VALUE": ExprNodes.NameNode(entry.pos, name = entry.name),
                    u"ATTR": attr,
                }, pos = entry.pos))

        # create the class
        str_format = u"%s(%s)" % (node.entry.type.name, ("%s, " * len(attributes))[:-2])
        wrapper_class = self.struct_or_union_wrapper.substitute({
            u"INIT_ASSIGNMENTS": Nodes.StatListNode(node.pos, stats = init_assignments),
            u"IS_UNION": ExprNodes.BoolNode(node.pos, value = not node.entry.type.is_struct),
            u"MEMBER_TUPLE": ExprNodes.TupleNode(node.pos, args=attributes),
            u"STR_FORMAT": ExprNodes.StringNode(node.pos, value = EncodedString(str_format)),
            u"REPR_FORMAT": ExprNodes.StringNode(node.pos, value = EncodedString(str_format.replace("%s", "%r"))),
        }, pos = node.pos).stats[0]
        wrapper_class.class_name = node.name
        wrapper_class.shadow = True
        class_body = wrapper_class.body.stats

        # fix value type
        assert isinstance(class_body[0].base_type, Nodes.CSimpleBaseTypeNode)
        class_body[0].base_type.name = node.name

        # fix __init__ arguments
        init_method = class_body[1]
        assert isinstance(init_method, Nodes.DefNode) and init_method.name == '__init__'
        arg_template = init_method.args[1]
        if not node.entry.type.is_struct:
            arg_template.kw_only = True
        del init_method.args[1]
        for entry, attr in zip(var_entries, attributes):
            arg = copy.deepcopy(arg_template)
            arg.declarator.name = entry.name
            init_method.args.append(arg)

        # setters/getters
        for entry, attr in zip(var_entries, attributes):
            # TODO: branch on visibility
            if entry.type.is_pyobject:
                template = self.basic_pyobject_property
            else:
                template = self.basic_property
            property = template.substitute({
                    u"ATTR": attr,
                }, pos = entry.pos).stats[0]
            property.name = entry.name
            wrapper_class.body.stats.append(property)

        wrapper_class.analyse_declarations(self.env_stack[-1])
        return self.visit_CClassDefNode(wrapper_class)

    # Some nodes are no longer needed after declaration
    # analysis and can be dropped. The analysis was performed
    # on these nodes in a separate recursive process from the
    # enclosing function or module, so we can simply drop them.
    def visit_CDeclaratorNode(self, node):
        # necessary to ensure that all CNameDeclaratorNodes are visited.
        self.visitchildren(node)
        return node

    def visit_CTypeDefNode(self, node):
        return node

    def visit_CBaseTypeNode(self, node):
        return None

    def visit_CEnumDefNode(self, node):
        # Public enums are kept in the tree; all others are dropped.
        if node.visibility == 'public':
            return node
        else:
            return None

    def visit_CNameDeclaratorNode(self, node):
        # Warn (level 2) if the name was already used in this scope before
        # this declaration, unless the entry is extern.
        if node.name in self.seen_vars_stack[-1]:
            entry = self.env_stack[-1].lookup(node.name)
            if entry is None or entry.visibility != 'extern':
                warning(node.pos, "cdef variable '%s' declared after it is used" % node.name, 2)
        self.visitchildren(node)
        return node

    def visit_CVarDefNode(self, node):
        # to ensure all CNameDeclaratorNodes are visited.
        self.visitchildren(node)
        return None

    def create_Property(self, entry):
        """Build a property node for the cdef class attribute ``entry``.

        Public attributes get a read/write property (plus __del__ for
        Python-object types); readonly attributes get a getter only.  When
        docstrings are enabled and the 'embedsignature' directive is set,
        the property gets a docstring of the form ``name: type = default``.
        """
        # NOTE(review): if entry.visibility is neither 'public' nor
        # 'readonly', ``template`` is unbound below.  Presumably callers only
        # pass entries with needs_property set for those visibilities --
        # confirm.
        if entry.visibility == 'public':
            if entry.type.is_pyobject:
                template = self.basic_pyobject_property
            else:
                template = self.basic_property
        elif entry.visibility == 'readonly':
            template = self.basic_property_ro
        property = template.substitute({
                u"ATTR": ExprNodes.AttributeNode(pos=entry.pos,
                                                 obj=ExprNodes.NameNode(pos=entry.pos, name="self"),
                                                 attribute=entry.name),
            }, pos=entry.pos).stats[0]
        property.name = entry.name
        # ---------------------------------------
        # XXX This should go to AutoDocTransforms
        # ---------------------------------------
        if (Options.docstrings and
            self.current_directives['embedsignature']):
            attr_name = entry.name
            type_name = entry.type.declaration_code("", for_display=1)
            default_value = ''
            if not entry.type.is_pyobject:
                type_name = "'%s'" % type_name
            elif entry.type.is_extension_type:
                type_name = entry.type.module_name + '.' + type_name
            if entry.init is not None:
                default_value = ' = ' + entry.init
            elif entry.init_to_none:
                default_value = ' = ' + repr(None)
            docstring = attr_name + ': ' + type_name + default_value
            property.doc = EncodedString(docstring)
        # ---------------------------------------
        return property
1265
class AnalyseExpressionsTransform(CythonTransform):
    """Infer types and analyse expressions, one scope at a time.

    Each scope's types are inferred before the expressions in it are
    analysed.  Nested scopes are handled as the children are visited.
    """

    def visit_ModuleNode(self, node):
        module_scope = node.scope
        module_scope.infer_types()
        node.body.analyse_expressions(module_scope)
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        local_scope = node.local_scope
        local_scope.infer_types()
        node.body.analyse_expressions(local_scope)
        self.visitchildren(node)
        return node

    def visit_ScopedExprNode(self, node):
        # Only expression nodes that own a scope need their own pass here.
        if node.has_local_scope:
            expr_scope = node.expr_scope
            expr_scope.infer_types()
            node.analyse_scoped_expressions(expr_scope)
        self.visitchildren(node)
        return node
1286
class ExpandInplaceOperators(EnvTransform):
    """Rewrite ``lhs op= rhs`` into ``lhs = lhs op rhs``.

    Parts of the target expression that are not plain names are bound once
    through LetRefNode/LetNode, so they are evaluated a single time.  C++
    class targets and buffer accesses are left as in-place assignments.
    """

    def visit_InPlaceAssignmentNode(self, node):
        lhs = node.lhs
        rhs = node.rhs
        if lhs.type.is_cpp_class:
            # No getting around this exact operator here.
            return node
        if isinstance(lhs, ExprNodes.IndexNode) and lhs.is_buffer_access:
            # There is code to handle this case.
            return node

        env = self.current_env()

        def side_effect_free_reference(expr, setting=False):
            # Returns (reference, temps): a node that may be evaluated more
            # than once, plus the LetRefNodes that must be bound around it.
            # Raises ValueError for nested buffer accesses.
            if isinstance(expr, ExprNodes.NameNode):
                return expr, []
            if expr.type.is_pyobject and not setting:
                ref = LetRefNode(expr)
                return ref, [ref]
            if isinstance(expr, ExprNodes.IndexNode):
                if expr.is_buffer_access:
                    raise ValueError("Buffer access")
                base, temps = side_effect_free_reference(expr.base)
                index = LetRefNode(expr.index)
                return ExprNodes.IndexNode(expr.pos, base=base, index=index), temps + [index]
            if isinstance(expr, ExprNodes.AttributeNode):
                obj, temps = side_effect_free_reference(expr.obj)
                return ExprNodes.AttributeNode(expr.pos, obj=obj, attribute=expr.attribute), temps
            ref = LetRefNode(expr)
            return ref, [ref]

        try:
            lhs, let_ref_nodes = side_effect_free_reference(lhs, setting=True)
        except ValueError:
            return node
        # Shallow copy of the target, used as the left operand.
        dup = lhs.__class__(**lhs.__dict__)
        binop = ExprNodes.binop_node(node.pos,
                                     operator = node.operator,
                                     operand1 = dup,
                                     operand2 = rhs,
                                     inplace=True)
        # Manually analyse types for new node.
        lhs.analyse_target_types(env)
        dup.analyse_types(env)
        binop.analyse_operation(env)
        result = Nodes.SingleAssignmentNode(
            node.pos,
            lhs = lhs,
            rhs = binop.coerce_to(lhs.type, env))
        # Use LetRefNode to avoid side effects.  Wrap from the last temp
        # outwards, so the first temp ends up as the outermost LetNode.
        for temp in let_ref_nodes[::-1]:
            result = LetNode(temp, result)
        return result

    def visit_ExprNode(self, node):
        # In-place assignments can't happen within an expression.
        return node
1345
1346
class AlignFunctionDefinitions(CythonTransform):
    """
    This class takes the signatures from a .pxd file and applies them to
    the def methods in a .py file.

    A Python class declared as a cclass in the .pxd is converted to a cdef
    class.  A def function declared as a C function there is converted to
    a C function.  With the 'auto_cpdef' directive, undeclared module-level
    functions are converted as well.  Conflicting redeclarations are
    reported and removed from the tree.
    """

    def visit_ModuleNode(self, node):
        self.scope = node.scope
        self.directives = node.directives
        self.visitchildren(node)
        return node

    def visit_PyClassDefNode(self, node):
        pxd_def = self.scope.lookup(node.name)
        if not pxd_def:
            return node
        if pxd_def.is_cclass:
            return self.visit_CClassDefNode(node.as_cclass(), pxd_def)
        error(node.pos, "'%s' redeclared" % node.name)
        error(pxd_def.pos, "previous declaration here")
        return None

    def visit_CClassDefNode(self, node, pxd_def=None):
        if pxd_def is None:
            pxd_def = self.scope.lookup(node.class_name)
        if not pxd_def:
            self.visitchildren(node)
            return node
        # Visit the class body inside the declared type's scope.
        outer_scope = self.scope
        self.scope = pxd_def.type.scope
        self.visitchildren(node)
        self.scope = outer_scope
        return node

    def visit_DefNode(self, node):
        pxd_def = self.scope.lookup(node.name)
        if pxd_def:
            if not pxd_def.is_cfunction:
                error(node.pos, "'%s' redeclared" % node.name)
                error(pxd_def.pos, "previous declaration here")
                return None
            return node.as_cfunction(pxd_def)
        if self.scope.is_module_scope and self.directives['auto_cpdef']:
            return node.as_cfunction(scope=self.scope)
        # Enable this when internal def functions are allowed.
        # self.visitchildren(node)
        return node
1395
1396
class MarkClosureVisitor(CythonTransform):
    """Set ``needs_closure`` on functions and lambdas that contain nested
    functions, lambdas or classes.

    Closures inside cdef functions are reported as unsupported.
    """

    def visit_ModuleNode(self, node):
        self.needs_closure = False
        self.visitchildren(node)
        return node

    def _visit_closure_candidate(self, node):
        # Visit the children with a fresh flag; any nested function, lambda
        # or class sets it.  Then record the result on the node and signal
        # the enclosing scope that it contains a nested function.
        self.needs_closure = False
        self.visitchildren(node)
        node.needs_closure = self.needs_closure
        self.needs_closure = True
        return node

    def visit_FuncDefNode(self, node):
        return self._visit_closure_candidate(node)

    def visit_CFuncDefNode(self, node):
        self.visit_FuncDefNode(node)
        if node.needs_closure:
            error(node.pos, "closures inside cdef functions not yet supported")
        return node

    def visit_LambdaNode(self, node):
        return self._visit_closure_candidate(node)

    def visit_ClassDefNode(self, node):
        self.visitchildren(node)
        self.needs_closure = True
        return node
1428
1429
1430 class CreateClosureClasses(CythonTransform):
1431     # Output closure classes in module scope for all functions
1432     # that really need it.
1433
1434     def __init__(self, context):
1435         super(CreateClosureClasses, self).__init__(context)
1436         self.path = []
1437         self.in_lambda = False
1438
1439     def visit_ModuleNode(self, node):
1440         self.module_scope = node.scope
1441         self.visitchildren(node)
1442         return node
1443
1444     def get_scope_use(self, node):
1445         from_closure = []
1446         in_closure = []
1447         for name, entry in node.local_scope.entries.items():
1448             if entry.from_closure:
1449                 from_closure.append((name, entry))
1450             elif entry.in_closure and not entry.from_closure:
1451                 in_closure.append((name, entry))
1452         return from_closure, in_closure
1453
    def create_class_from_scope(self, node, target_module_scope, inner_node=None):
        """Declare a closure class for the local scope of ``node`` if needed.

        ``node`` is a function definition node (its ``local_scope``,
        ``entry`` and, for DefNodes, ``assmt`` are used).
        ``target_module_scope`` is the scope in which the closure class is
        declared.  ``inner_node``, when given, is used instead of
        ``node.assmt.rhs`` (visit_LambdaNode passes the LambdaNode here).

        Recomputes ``node.needs_closure`` and ``node.needs_outer_scope``, and
        may set ``scope_class`` / ``is_passthrough`` on the local scope.
        """
        from_closure, in_closure = self.get_scope_use(node)
        # Sort the (name, entry) pairs by name, so closure attributes are
        # declared in a deterministic order.
        in_closure.sort()

        # Now from the beginning: reset both flags, recomputed below.
        node.needs_closure = False
        node.needs_outer_scope = False

        func_scope = node.local_scope
        # Find the nearest enclosing scope that is not a class scope.
        cscope = node.entry.scope
        while cscope.is_py_class_scope or cscope.is_c_class_scope:
            cscope = cscope.outer_scope

        # Nested function/lambda that reads nothing from an outer closure.
        # NOTE(review): presumably needs_self_code=False means no outer-scope
        # reference is generated when creating the function object -- confirm.
        if not from_closure and (self.path or inner_node):
            if not inner_node:
                if not node.assmt:
                    raise InternalError, "DefNode does not have assignment node"
                inner_node = node.assmt.rhs
            inner_node.needs_self_code = False
            node.needs_outer_scope = False
        # Simple cases
        if not in_closure and not from_closure:
            # Nothing captured in either direction: no closure class.
            return
        elif not in_closure:
            # Only reads from an outer closure: reuse the enclosing scope's
            # closure class instead of declaring a new one.
            func_scope.is_passthrough = True
            func_scope.scope_class = cscope.scope_class
            node.needs_outer_scope = True
            return

        # Declare a new internal, final cdef class in the target module
        # scope to hold the variables this function keeps in its closure.
        as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname)

        entry = target_module_scope.declare_c_class(name = as_name,
            pos = node.pos, defining = True, implementing = True)
        func_scope.scope_class = entry
        class_scope = entry.type.scope
        class_scope.is_internal = True
        class_scope.directives = {'final': True}

        if from_closure:
            # Link to the enclosing closure object through an extra cdef
            # attribute typed as the outer closure class.
            assert cscope.is_closure_scope
            class_scope.declare_var(pos=node.pos,
                                    name=Naming.outer_scope_cname,
                                    cname=Naming.outer_scope_cname,
                                    type=cscope.scope_class.type,
                                    is_cdef=True)
            node.needs_outer_scope = True
        # One cdef attribute per variable stored in this closure.
        for name, entry in in_closure:
            class_scope.declare_var(pos=entry.pos,
                                    name=entry.name,
                                    cname=entry.cname,
                                    type=entry.type,
                                    is_cdef=True)
        node.needs_closure = True
        # Do it here because other classes are already checked
        target_module_scope.check_c_class(func_scope.scope_class)
1509
1510     def visit_LambdaNode(self, node):
1511         was_in_lambda = self.in_lambda
1512         self.in_lambda = True
1513         self.create_class_from_scope(node.def_node, self.module_scope, node)
1514         self.visitchildren(node)
1515         self.in_lambda = was_in_lambda
1516         return node
1517
1518     def visit_FuncDefNode(self, node):
1519         if self.in_lambda:
1520             self.visitchildren(node)
1521             return node
1522         if node.needs_closure or self.path:
1523             self.create_class_from_scope(node, self.module_scope)
1524             self.path.append(node)
1525             self.visitchildren(node)
1526             self.path.pop()
1527         return node
1528
1529
class GilCheck(VisitorTransform):
    """
    While inside a `nogil` context, call `node.nogil_check(env)` on every
    node that defines it, so that operations which need the GIL (such as
    Python operations) can raise an error there.
    """
    def __call__(self, root):
        # Scope stack: the module scope, then one entry per enclosing function.
        self.env_stack = [root.scope]
        # Whether the code currently being visited runs without the GIL.
        self.nogil = False
        return super(GilCheck, self).__call__(root)

    def visit_FuncDefNode(self, node):
        self.env_stack.append(node.local_scope)
        was_nogil = self.nogil
        # A function's GIL state comes from its own scope.
        self.nogil = node.local_scope.nogil
        if self.nogil and node.nogil_check:
            node.nogil_check(node.local_scope)
        self.visitchildren(node)
        self.env_stack.pop()
        self.nogil = was_nogil
        return node

    def visit_GILStatNode(self, node):
        # The statement itself is checked against the surrounding GIL state.
        if self.nogil and node.nogil_check:
            node.nogil_check()
        was_nogil = self.nogil
        # The body runs without the GIL only for state 'nogil'; any other
        # state (presumably 'gil') means the GIL is held.
        self.nogil = (node.state == 'nogil')
        self.visitchildren(node)
        self.nogil = was_nogil
        return node

    def visit_Node(self, node):
        if self.env_stack and self.nogil and node.nogil_check:
            node.nogil_check(self.env_stack[-1])
        self.visitchildren(node)
        return node
1566
1567
class TransformBuiltinMethods(EnvTransform):
    """Replace ``cython.*`` attributes and calls, and calls to the
    ``locals()`` builtin, with the corresponding expression nodes.
    """

    def visit_SingleAssignmentNode(self, node):
        # Declaration-only assignments are removed from the tree.
        if node.declaration_only:
            return None
        self.visitchildren(node)
        return node

    def visit_AttributeNode(self, node):
        self.visitchildren(node)
        return self.visit_cython_attribute(node)

    def visit_NameNode(self, node):
        return self.visit_cython_attribute(node)

    def visit_cython_attribute(self, node):
        """Replace a reference to ``cython.<attribute>`` by its node.

        ``compiled`` becomes a True BoolNode, ``NULL`` a NullNode, and
        ``set``/``frozenset`` a NameNode bound to the builtin entry.  Names
        accepted by ``PyrexTypes.parse_basic_type`` are left unchanged; any
        other cython attribute is reported as an error.
        """
        attribute = node.as_cython_attribute()
        if attribute:
            if attribute == u'compiled':
                node = ExprNodes.BoolNode(node.pos, value=True)
            elif attribute == u'NULL':
                node = ExprNodes.NullNode(node.pos)
            elif attribute in (u'set', u'frozenset'):
                node = ExprNodes.NameNode(node.pos, name=EncodedString(attribute),
                                          entry=self.current_env().builtin_scope().lookup_here(attribute))
            elif not PyrexTypes.parse_basic_type(attribute):
                error(node.pos, u"'%s' not a valid cython attribute or is being used incorrectly" % attribute)
        return node

    def visit_SimpleCallNode(self, node):
        """Transform calls to ``locals()`` and to ``cython.<function>()``."""
        # locals builtin
        function = node.function
        if isinstance(function, ExprNodes.NameNode) and function.name == 'locals':
            return self._inject_locals(node)

        # cython.foo
        cython_function = function.as_cython_attribute()
        if cython_function:
            node = self._transform_cython_function_call(node, cython_function)

        self.visitchildren(node)
        return node

    def _inject_locals(self, node):
        """Replace a ``locals()`` call by a dict literal of the current scope.

        The DictNode maps each entry name of the current scope to a NameNode
        for it.  ``node`` is returned unchanged when ``locals`` is declared
        in the current scope (not the builtin) or when it is called with
        arguments (an error is reported in that case).
        """
        lenv = self.current_env()
        if lenv.lookup_here('locals'):
            # not the builtin 'locals'
            return node
        if len(node.args) > 0:
            # Report at the call site; the transform object has no 'pos'.
            error(node.pos, "Builtin 'locals()' called with wrong number of args, expected 0, got %d" % len(node.args))
            return node
        pos = node.pos
        items = [ExprNodes.DictItemNode(pos,
                                        key=ExprNodes.StringNode(pos, value=var),
                                        value=ExprNodes.NameNode(pos, name=var))
                 for var in lenv.entries]
        return ExprNodes.DictNode(pos, key_value_pairs=items)

    def _transform_cython_function_call(self, node, function):
        """Return the node that replaces the call ``cython.<function>(...)``.

        On a wrong argument count or an unknown function an error is
        reported and ``node`` is returned unchanged; for ``cython.set`` only
        ``node.function`` is replaced.
        """
        args = node.args
        func_pos = node.function.pos
        if function in InterpretCompilerDirectives.unop_method_nodes:
            if len(args) != 1:
                error(func_pos, u"%s() takes exactly one argument" % function)
            else:
                node = InterpretCompilerDirectives.unop_method_nodes[function](func_pos, operand=args[0])
        elif function in InterpretCompilerDirectives.binop_method_nodes:
            if len(args) != 2:
                error(func_pos, u"%s() takes exactly two arguments" % function)
            else:
                node = InterpretCompilerDirectives.binop_method_nodes[function](func_pos, operand1=args[0], operand2=args[1])
        elif function == u'cast':
            if len(args) != 2:
                error(func_pos, u"cast() takes exactly two arguments")
            else:
                cast_type = args[0].analyse_as_type(self.current_env())
                if cast_type:
                    node = ExprNodes.TypecastNode(func_pos, type=cast_type, operand=args[1])
                else:
                    error(args[0].pos, "Not a type")
        elif function == u'sizeof':
            if len(args) != 1:
                error(func_pos, u"sizeof() takes exactly one argument")
            else:
                arg_type = args[0].analyse_as_type(self.current_env())
                if arg_type:
                    node = ExprNodes.SizeofTypeNode(func_pos, arg_type=arg_type)
                else:
                    # Not a type: take the size of the expression instead.
                    node = ExprNodes.SizeofVarNode(func_pos, operand=args[0])
        elif function == u'cmod':
            if len(args) != 2:
                error(func_pos, u"cmod() takes exactly two arguments")
            else:
                node = ExprNodes.binop_node(func_pos, '%', args[0], args[1])
                node.cdivision = True
        elif function == u'cdiv':
            if len(args) != 2:
                error(func_pos, u"cdiv() takes exactly two arguments")
            else:
                node = ExprNodes.binop_node(func_pos, '/', args[0], args[1])
                node.cdivision = True
        elif function == u'set':
            node.function = ExprNodes.NameNode(node.pos, name=EncodedString('set'))
        else:
            error(func_pos, u"'%s' not a valid cython language construct" % function)
        return node
1668
1669
class DebugTransform(CythonTransform):
    """
    Write debug information (functions with their locals, arguments and
    step-into functions, the module-level code, and global variables) to
    the context's gdb debug output writer in order to enable debugging.
    """

    def __init__(self, context, options, result):
        super(DebugTransform, self).__init__(context)
        # qualified names of the function scopes seen so far
        self.visited = cython.set()
        # our treebuilder and debug output writer
        # (see Cython.Debugger.debug_output.CythonDebugWriter)
        self.tb = self.context.gdb_debug_outputwriter
        # 'options' is currently unused; the C file name comes from 'result'.
        self.c_output_file = result.c_file

        # Closure support, basically treat nested functions as if the AST were
        # never nested
        self.nested_funcdefs = []

        # tells visit_NameNode whether it should register step-into functions
        self.register_stepinto = False

    def visit_ModuleNode(self, node):
        """Serialize the module: functions first, then global variables.

        The 'Module' element is deliberately left open here (see the note
        at the end).
        """
        self.tb.module_name = node.full_module_name
        attrs = dict(
            module_name=node.full_module_name,
            filename=node.pos[0].filename,
            c_filename=self.c_output_file)

        self.tb.start('Module', attrs)

        # serialize functions
        self.tb.start('Functions')
        # First, serialize functions normally...
        self.visitchildren(node)

        # ... then, serialize nested functions.  Serializing one of them may
        # append further nested functions to this list; the loop picks
        # those up as well.
        for nested_funcdef in self.nested_funcdefs:
            self.visit_FuncDefNode(nested_funcdef)

        self.register_stepinto = True
        self.serialize_modulenode_as_function(node)
        self.register_stepinto = False
        self.tb.end('Functions')

        # 2.3 compatibility. Serialize global variables, skipping function
        # scopes serialized above, internal '__pyx_' names, C functions and
        # extension types.
        self.tb.start('Globals')
        entries = {}

        for k, v in node.scope.entries.iteritems():
            if (v.qualified_name not in self.visited and not
                v.name.startswith('__pyx_') and not
                v.type.is_cfunction and not
                v.type.is_extension_type):
                entries[k] = v

        self.serialize_local_variables(entries)
        self.tb.end('Globals')
        # self.tb.end('Module') # end Module after the line number mapping in
        # Cython.Compiler.ModuleNode.ModuleNode._serialize_lineno_map
        return node

    def visit_FuncDefNode(self, node):
        """Serialize a function: its locals, arguments and the C functions
        called from its body ("step-into" functions).

        Wrappers (``is_wrapper``) are skipped.  While ``register_stepinto``
        is set, i.e. while already inside another function or the module
        code, the function is only queued in ``nested_funcdefs`` and gets
        serialized later from visit_ModuleNode.
        """
        self.visited.add(node.local_scope.qualified_name)

        if getattr(node, 'is_wrapper', False):
            return node

        if self.register_stepinto:
            self.nested_funcdefs.append(node)
            return node

        # node.entry.visibility = 'extern'
        if node.py_func is None:
            pf_cname = ''
        else:
            pf_cname = node.py_func.entry.func_cname

        attrs = dict(
            name=node.entry.name,
            cname=node.entry.func_cname,
            pf_cname=pf_cname,
            qualified_name=node.local_scope.qualified_name,
            lineno=str(node.pos[1]))

        self.tb.start('Function', attrs=attrs)

        self.tb.start('Locals')
        self.serialize_local_variables(node.local_scope.entries)
        self.tb.end('Locals')

        self.tb.start('Arguments')
        for arg in node.local_scope.arg_entries:
            self.tb.start(arg.name)
            self.tb.end(arg.name)
        self.tb.end('Arguments')

        self.tb.start('StepIntoFunctions')
        self.register_stepinto = True
        self.visitchildren(node)
        self.register_stepinto = False
        self.tb.end('StepIntoFunctions')
        self.tb.end('Function')

        return node

    def visit_NameNode(self, node):
        """Record called C functions as step-into functions while
        ``register_stepinto`` is set."""
        if (self.register_stepinto and
            node.type.is_cfunction and
            getattr(node, 'is_called', False) and
            node.entry.func_cname is not None):
            # don't check node.entry.in_cinclude, as 'cdef extern: ...'
            # declared functions are not 'in_cinclude'.
            # This means we will list called 'cdef' functions as
            # "step into functions", but this is not an issue as they will be
            # recognized as Cython functions anyway.
            attrs = dict(name=node.entry.func_cname)
            self.tb.start('StepIntoFunction', attrs=attrs)
            self.tb.end('StepIntoFunction')

        self.visitchildren(node)
        return node

    def serialize_modulenode_as_function(self, node):
        """
        Serialize the module-level code as a function so the debugger will know
        it's a "relevant frame" and it will know where to set the breakpoint
        for 'break modulename'.
        """
        # Last dotted component of the module name.  Same result as
        # rpartition('.')[-1], but str.rpartition only exists from Python 2.5.
        name = node.full_module_name.split('.')[-1]

        cname_py2 = 'init' + name
        cname_py3 = 'PyInit_' + name

        py2_attrs = dict(
            name=name,
            cname=cname_py2,
            pf_cname='',
            # Ignore the qualified_name, breakpoints should be set using
            # `cy break modulename:lineno` for module-level breakpoints.
            qualified_name='',
            lineno='1',
            is_initmodule_function="True",
        )

        py3_attrs = dict(py2_attrs, cname=cname_py3)

        self._serialize_modulenode_as_function(node, py2_attrs)
        self._serialize_modulenode_as_function(node, py3_attrs)

    def _serialize_modulenode_as_function(self, node, attrs):
        """Emit one 'Function' element for the module code with ``attrs``."""
        self.tb.start('Function', attrs=attrs)

        self.tb.start('Locals')
        self.serialize_local_variables(node.scope.entries)
        self.tb.end('Locals')

        self.tb.start('Arguments')
        self.tb.end('Arguments')

        self.tb.start('StepIntoFunctions')
        # NOTE(review): with register_stepinto set, module-level FuncDefNodes
        # reached here are only appended to self.nested_funcdefs (after that
        # list has already been processed), not serialized again.
        self.register_stepinto = True
        self.visitchildren(node)
        self.register_stepinto = False
        self.tb.end('StepIntoFunctions')

        self.tb.end('Function')

    def serialize_local_variables(self, entries):
        """Emit a 'LocalVar' element for every entry in the ``entries`` dict.

        Closure variables get a C name that accesses them through the
        current scope object (``Naming.cur_scope_cname``).
        """
        for entry in entries.values():
            if entry.type.is_pyobject:
                vartype = 'PythonObject'
            else:
                vartype = 'CObject'

            if entry.from_closure:
                # We're dealing with a closure where a variable from an outer
                # scope is accessed, get it from the scope object.
                cname = '%s->%s' % (Naming.cur_scope_cname,
                                    entry.outer_entry.cname)

                qname = '%s.%s.%s' % (entry.scope.outer_scope.qualified_name,
                                      entry.scope.name,
                                      entry.name)
            elif entry.in_closure:
                cname = '%s->%s' % (Naming.cur_scope_cname,
                                    entry.cname)
                qname = entry.qualified_name
            else:
                cname = entry.cname
                qname = entry.qualified_name

            if not entry.pos:
                # this happens for variables that are not in the user's code,
                # e.g. for the global __builtins__, __doc__, etc. We can just
                # set the lineno to 0 for those.
                lineno = '0'
            else:
                lineno = str(entry.pos[1])

            attrs = dict(
                name=entry.name,
                cname=cname,
                qualified_name=qname,
                type=vartype,
                lineno=lineno)

            self.tb.start('LocalVar', attrs)
            self.tb.end('LocalVar')
1879