merged with latest cython-devel
[cython.git] / Cython / Compiler / ParseTreeTransforms.py
1 from Cython.Compiler.Visitor import VisitorTransform, TreeVisitor
2 from Cython.Compiler.Visitor import CythonTransform, EnvTransform
3 from Cython.Compiler.ModuleNode import ModuleNode
4 from Cython.Compiler.Nodes import *
5 from Cython.Compiler.ExprNodes import *
6 from Cython.Compiler.UtilNodes import *
7 from Cython.Compiler.TreeFragment import TreeFragment, TemplateTransform
8 from Cython.Compiler.StringEncoding import EncodedString
9 from Cython.Compiler.Errors import error, CompileError
10 try:
11     set
12 except NameError:
13     from sets import Set as set
14 import copy
15
16
17 class NameNodeCollector(TreeVisitor):
18     """Collect all NameNodes of a (sub-)tree in the ``name_nodes``
19     attribute.
20     """
21     def __init__(self):
22         super(NameNodeCollector, self).__init__()
23         self.name_nodes = []
24
25     visit_Node = TreeVisitor.visitchildren
26
27     def visit_NameNode(self, node):
28         self.name_nodes.append(node)
29
30
31 class SkipDeclarations(object):
32     """
33     Variable and function declarations can often have a deep tree structure, 
34     and yet most transformations don't need to descend to this depth. 
35     
36     Declaration nodes are removed after AnalyseDeclarationsTransform, so there 
37     is no need to use this for transformations after that point. 
38     """
39     def visit_CTypeDefNode(self, node):
40         return node
41     
42     def visit_CVarDefNode(self, node):
43         return node
44     
45     def visit_CDeclaratorNode(self, node):
46         return node
47     
48     def visit_CBaseTypeNode(self, node):
49         return node
50     
51     def visit_CEnumDefNode(self, node):
52         return node
53
54     def visit_CStructOrUnionDefNode(self, node):
55         return node
56
57
58 class NormalizeTree(CythonTransform):
59     """
60     This transform fixes up a few things after parsing
61     in order to make the parse tree more suitable for
62     transforms.
63
64     a) After parsing, blocks with only one statement will
65     be represented by that statement, not by a StatListNode.
66     When doing transforms this is annoying and inconsistent,
67     as one cannot in general remove a statement in a consistent
68     way and so on. This transform wraps any single statements
69     in a StatListNode containing a single statement.
70
71     b) The PassStatNode is a noop and serves no purpose beyond
72     plugging such one-statement blocks; i.e., once parsed a
73 `    "pass" can just as well be represented using an empty
74     StatListNode. This means less special cases to worry about
75     in subsequent transforms (one always checks to see if a
76     StatListNode has no children to see if the block is empty).
77     """
78
79     def __init__(self, context):
80         super(NormalizeTree, self).__init__(context)
81         self.is_in_statlist = False
82         self.is_in_expr = False
83
84     def visit_ExprNode(self, node):
85         stacktmp = self.is_in_expr
86         self.is_in_expr = True
87         self.visitchildren(node)
88         self.is_in_expr = stacktmp
89         return node
90
91     def visit_StatNode(self, node, is_listcontainer=False):
92         stacktmp = self.is_in_statlist
93         self.is_in_statlist = is_listcontainer
94         self.visitchildren(node)
95         self.is_in_statlist = stacktmp
96         if not self.is_in_statlist and not self.is_in_expr:
97             return StatListNode(pos=node.pos, stats=[node])
98         else:
99             return node
100
101     def visit_StatListNode(self, node):
102         self.is_in_statlist = True
103         self.visitchildren(node)
104         self.is_in_statlist = False
105         return node
106
107     def visit_ParallelAssignmentNode(self, node):
108         return self.visit_StatNode(node, True)
109     
110     def visit_CEnumDefNode(self, node):
111         return self.visit_StatNode(node, True)
112
113     def visit_CStructOrUnionDefNode(self, node):
114         return self.visit_StatNode(node, True)
115
116     # Eliminate PassStatNode
117     def visit_PassStatNode(self, node):
118         if not self.is_in_statlist:
119             return StatListNode(pos=node.pos, stats=[])
120         else:
121             return []
122
123     def visit_CDeclaratorNode(self, node):
124         return node    
125
126
127 class PostParseError(CompileError): pass
128
129 # error strings checked by unit tests, so define them
130 ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions'
131 ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)'
132 ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared'
133 class PostParse(CythonTransform):
134     """
135     Basic interpretation of the parse tree, as well as validity
136     checking that can be done on a very basic level on the parse
137     tree (while still not being a problem with the basic syntax,
138     as such).
139
140     Specifically:
141     - Default values to cdef assignments are turned into single
142     assignments following the declaration (everywhere but in class
143     bodies, where they raise a compile error)
144     
145     - Interpret some node structures into Python runtime values.
146     Some nodes take compile-time arguments (currently:
147     TemplatedTypeNode[args] and __cythonbufferdefaults__ = {args}),
148     which should be interpreted. This happens in a general way
149     and other steps should be taken to ensure validity.
150
151     Type arguments cannot be interpreted in this way.
152
153     - For __cythonbufferdefaults__ the arguments are checked for
154     validity.
155
156     TemplatedTypeNode has its directives interpreted:
157     Any first positional argument goes into the "dtype" attribute,
158     any "ndim" keyword argument goes into the "ndim" attribute and
159     so on. Also it is checked that the directive combination is valid.
160     - __cythonbufferdefaults__ attributes are parsed and put into the
161     type information.
162
163     Note: Currently Parsing.py does a lot of interpretation and
164     reorganization that can be refactored into this transform
165     if a more pure Abstract Syntax Tree is wanted.
166     """
167
168     # Track our context.
169     scope_type = None # can be either of 'module', 'function', 'class'
170
171     def __init__(self, context):
172         super(PostParse, self).__init__(context)
173         self.specialattribute_handlers = {
174             '__cythonbufferdefaults__' : self.handle_bufferdefaults
175         }
176
177     def visit_ModuleNode(self, node):
178         self.scope_type = 'module'
179         self.scope_node = node
180         self.lambda_counter = 1
181         self.visitchildren(node)
182         return node
183
184     def visit_scope(self, node, scope_type):
185         prev = self.scope_type, self.scope_node
186         self.scope_type = scope_type
187         self.scope_node = node
188         self.visitchildren(node)
189         self.scope_type, self.scope_node = prev
190         return node
191     
192     def visit_ClassDefNode(self, node):
193         return self.visit_scope(node, 'class')
194
195     def visit_FuncDefNode(self, node):
196         return self.visit_scope(node, 'function')
197
198     def visit_CStructOrUnionDefNode(self, node):
199         return self.visit_scope(node, 'struct')
200
201     def visit_LambdaNode(self, node):
202         # unpack a lambda expression into the corresponding DefNode
203         if self.scope_type != 'function':
204             error(node.pos,
205                   "lambda functions are currently only supported in functions")
206         lambda_id = self.lambda_counter
207         self.lambda_counter += 1
208         node.lambda_name = EncodedString(u'lambda%d' % lambda_id)
209
210         body = Nodes.ReturnStatNode(
211             node.result_expr.pos, value = node.result_expr)
212         node.def_node = Nodes.DefNode(
213             node.pos, name=node.name, lambda_name=node.lambda_name,
214             args=node.args, star_arg=node.star_arg,
215             starstar_arg=node.starstar_arg,
216             body=body)
217         self.visitchildren(node)
218         return node
219
220     # cdef variables
221     def handle_bufferdefaults(self, decl):
222         if not isinstance(decl.default, DictNode):
223             raise PostParseError(decl.pos, ERR_BUF_DEFAULTS)
224         self.scope_node.buffer_defaults_node = decl.default
225         self.scope_node.buffer_defaults_pos = decl.pos
226
227     def visit_CVarDefNode(self, node):
228         # This assumes only plain names and pointers are assignable on
229         # declaration. Also, it makes use of the fact that a cdef decl
230         # must appear before the first use, so we don't have to deal with
231         # "i = 3; cdef int i = i" and can simply move the nodes around.
232         try:
233             self.visitchildren(node)
234             stats = [node]
235             newdecls = []
236             for decl in node.declarators:
237                 declbase = decl
238                 while isinstance(declbase, CPtrDeclaratorNode):
239                     declbase = declbase.base
240                 if isinstance(declbase, CNameDeclaratorNode):
241                     if declbase.default is not None:
242                         if self.scope_type in ('class', 'struct'):
243                             if isinstance(self.scope_node, CClassDefNode):
244                                 handler = self.specialattribute_handlers.get(decl.name)
245                                 if handler:
246                                     if decl is not declbase:
247                                         raise PostParseError(decl.pos, ERR_INVALID_SPECIALATTR_TYPE)
248                                     handler(decl)
249                                     continue # Remove declaration
250                             raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
251                         first_assignment = self.scope_type != 'module'
252                         stats.append(SingleAssignmentNode(node.pos,
253                             lhs=NameNode(node.pos, name=declbase.name),
254                             rhs=declbase.default, first=first_assignment))
255                         declbase.default = None
256                 newdecls.append(decl)
257             node.declarators = newdecls
258             return stats
259         except PostParseError, e:
260             # An error in a cdef clause is ok, simply remove the declaration
261             # and try to move on to report more errors
262             self.context.nonfatal_error(e)
263             return None
264
265     # Split parallel assignments (a,b = b,a) into separate partial
266     # assignments that are executed rhs-first using temps.  This
267     # optimisation is best applied before type analysis so that known
268     # types on rhs and lhs can be matched directly.
269
270     def visit_SingleAssignmentNode(self, node):
271         self.visitchildren(node)
272         return self._visit_assignment_node(node, [node.lhs, node.rhs])
273
274     def visit_CascadedAssignmentNode(self, node):
275         self.visitchildren(node)
276         return self._visit_assignment_node(node, node.lhs_list + [node.rhs])
277
278     def _visit_assignment_node(self, node, expr_list):
279         """Flatten parallel assignments into separate single
280         assignments or cascaded assignments.
281         """
282         if sum([ 1 for expr in expr_list if expr.is_sequence_constructor ]) < 2:
283             # no parallel assignments => nothing to do
284             return node
285
286         expr_list_list = []
287         flatten_parallel_assignments(expr_list, expr_list_list)
288         nodes = []
289         for expr_list in expr_list_list:
290             lhs_list = expr_list[:-1]
291             rhs = expr_list[-1]
292             if len(lhs_list) == 1:
293                 node = Nodes.SingleAssignmentNode(rhs.pos, 
294                     lhs = lhs_list[0], rhs = rhs)
295             else:
296                 node = Nodes.CascadedAssignmentNode(rhs.pos,
297                     lhs_list = lhs_list, rhs = rhs)
298             nodes.append(node)
299         if len(nodes) == 1:
300             return nodes[0]
301         else:
302             return Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)
303
304
305 def flatten_parallel_assignments(input, output):
306     #  The input is a list of expression nodes, representing the LHSs
307     #  and RHS of one (possibly cascaded) assignment statement.  For
308     #  sequence constructors, rearranges the matching parts of both
309     #  sides into a list of equivalent assignments between the
310     #  individual elements.  This transformation is applied
311     #  recursively, so that nested structures get matched as well.
312     rhs = input[-1]
313     if not rhs.is_sequence_constructor or not sum([lhs.is_sequence_constructor for lhs in input[:-1]]):
314         output.append(input)
315         return
316
317     complete_assignments = []
318
319     rhs_size = len(rhs.args)
320     lhs_targets = [ [] for _ in xrange(rhs_size) ]
321     starred_assignments = []
322     for lhs in input[:-1]:
323         if not lhs.is_sequence_constructor:
324             if lhs.is_starred:
325                 error(lhs.pos, "starred assignment target must be in a list or tuple")
326             complete_assignments.append(lhs)
327             continue
328         lhs_size = len(lhs.args)
329         starred_targets = sum([1 for expr in lhs.args if expr.is_starred])
330         if starred_targets > 1:
331             error(lhs.pos, "more than 1 starred expression in assignment")
332             output.append([lhs,rhs])
333             continue
334         elif lhs_size - starred_targets > rhs_size:
335             error(lhs.pos, "need more than %d value%s to unpack"
336                   % (rhs_size, (rhs_size != 1) and 's' or ''))
337             output.append([lhs,rhs])
338             continue
339         elif starred_targets == 1:
340             map_starred_assignment(lhs_targets, starred_assignments,
341                                    lhs.args, rhs.args)
342         elif lhs_size < rhs_size:
343             error(lhs.pos, "too many values to unpack (expected %d, got %d)"
344                   % (lhs_size, rhs_size))
345             output.append([lhs,rhs])
346             continue
347         else:
348             for targets, expr in zip(lhs_targets, lhs.args):
349                 targets.append(expr)
350
351     if complete_assignments:
352         complete_assignments.append(rhs)
353         output.append(complete_assignments)
354
355     # recursively flatten partial assignments
356     for cascade, rhs in zip(lhs_targets, rhs.args):
357         if cascade:
358             cascade.append(rhs)
359             flatten_parallel_assignments(cascade, output)
360
361     # recursively flatten starred assignments
362     for cascade in starred_assignments:
363         if cascade[0].is_sequence_constructor:
364             flatten_parallel_assignments(cascade, output)
365         else:
366             output.append(cascade)
367
368 def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args):
369     # Appends the fixed-position LHS targets to the target list that
370     # appear left and right of the starred argument.
371     #
372     # The starred_assignments list receives a new tuple
373     # (lhs_target, rhs_values_list) that maps the remaining arguments
374     # (those that match the starred target) to a list.
375
376     # left side of the starred target
377     for i, (targets, expr) in enumerate(zip(lhs_targets, lhs_args)):
378         if expr.is_starred:
379             starred = i
380             lhs_remaining = len(lhs_args) - i - 1
381             break
382         targets.append(expr)
383     else:
384         raise InternalError("no starred arg found when splitting starred assignment")
385
386     # right side of the starred target
387     for i, (targets, expr) in enumerate(zip(lhs_targets[-lhs_remaining:],
388                                             lhs_args[-lhs_remaining:])):
389         targets.append(expr)
390
391     # the starred target itself, must be assigned a (potentially empty) list
392     target = lhs_args[starred].target # unpack starred node
393     starred_rhs = rhs_args[starred:]
394     if lhs_remaining:
395         starred_rhs = starred_rhs[:-lhs_remaining]
396     if starred_rhs:
397         pos = starred_rhs[0].pos
398     else:
399         pos = target.pos
400     starred_assignments.append([
401         target, ExprNodes.ListNode(pos=pos, args=starred_rhs)])
402
403
404 class PxdPostParse(CythonTransform, SkipDeclarations):
405     """
406     Basic interpretation/validity checking that should only be
407     done on pxd trees.
408
409     A lot of this checking currently happens in the parser; but
410     what is listed below happens here.
411
412     - "def" functions are let through only if they fill the
413     getbuffer/releasebuffer slots
414     
415     - cdef functions are let through only if they are on the
416     top level and are declared "inline"
417     """
418     ERR_INLINE_ONLY = "function definition in pxd file must be declared 'cdef inline'"
419     ERR_NOGO_WITH_INLINE = "inline function definition in pxd file cannot be '%s'"
420
421     def __call__(self, node):
422         self.scope_type = 'pxd'
423         return super(PxdPostParse, self).__call__(node)
424
425     def visit_CClassDefNode(self, node):
426         old = self.scope_type
427         self.scope_type = 'cclass'
428         self.visitchildren(node)
429         self.scope_type = old
430         return node
431
432     def visit_FuncDefNode(self, node):
433         # FuncDefNode always come with an implementation (without
434         # an imp they are CVarDefNodes..)
435         err = self.ERR_INLINE_ONLY
436
437         if (isinstance(node, DefNode) and self.scope_type == 'cclass'
438             and node.name in ('__getbuffer__', '__releasebuffer__')):
439             err = None # allow these slots
440             
441         if isinstance(node, CFuncDefNode):
442             if u'inline' in node.modifiers and self.scope_type == 'pxd':
443                 node.inline_in_pxd = True
444                 if node.visibility != 'private':
445                     err = self.ERR_NOGO_WITH_INLINE % node.visibility
446                 elif node.api:
447                     err = self.ERR_NOGO_WITH_INLINE % 'api'
448                 else:
449                     err = None # allow inline function
450             else:
451                 err = self.ERR_INLINE_ONLY
452
453         if err:
454             self.context.nonfatal_error(PostParseError(node.pos, err))
455             return None
456         else:
457             return node
458     
459 class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
460     """
461     After parsing, directives can be stored in a number of places:
462     - #cython-comments at the top of the file (stored in ModuleNode)
463     - Command-line arguments overriding these
464     - @cython.directivename decorators
465     - with cython.directivename: statements
466
467     This transform is responsible for interpreting these various sources
468     and store the directive in two ways:
469     - Set the directives attribute of the ModuleNode for global directives.
470     - Use a CompilerDirectivesNode to override directives for a subtree.
471
472     (The first one is primarily to not have to modify with the tree
473     structure, so that ModuleNode stay on top.)
474
475     The directives are stored in dictionaries from name to value in effect.
476     Each such dictionary is always filled in for all possible directives,
477     using default values where no value is given by the user.
478
479     The available directives are controlled in Options.py.
480
481     Note that we have to run this prior to analysis, and so some minor
482     duplication of functionality has to occur: We manually track cimports
483     and which names the "cython" module may have been imported to.
484     """
485     unop_method_nodes = {
486         'typeof': TypeofNode,
487         
488         'operator.address': AmpersandNode,
489         'operator.dereference': DereferenceNode,
490         'operator.preincrement' : inc_dec_constructor(True, '++'),
491         'operator.predecrement' : inc_dec_constructor(True, '--'),
492         'operator.postincrement': inc_dec_constructor(False, '++'),
493         'operator.postdecrement': inc_dec_constructor(False, '--'),
494
495         # For backwards compatability.
496         'address': AmpersandNode,
497     }
498     
499     special_methods = set(['declare', 'union', 'struct', 'typedef', 'sizeof',
500                            'cast', 'pointer', 'compiled', 'NULL']
501                           + unop_method_nodes.keys())
502
503     def __init__(self, context, compilation_directive_defaults):
504         super(InterpretCompilerDirectives, self).__init__(context)
505         self.compilation_directive_defaults = {}
506         for key, value in compilation_directive_defaults.iteritems():
507             self.compilation_directive_defaults[unicode(key)] = value
508         self.cython_module_names = set()
509         self.directive_names = {}
510
511     def check_directive_scope(self, pos, directive, scope):
512         legal_scopes = Options.directive_scopes.get(directive, None)
513         if legal_scopes and scope not in legal_scopes:
514             self.context.nonfatal_error(PostParseError(pos, 'The %s compiler directive '
515                                         'is not allowed in %s scope' % (directive, scope)))
516             return False
517         else:
518             return True
519         
520     # Set up processing and handle the cython: comments.
521     def visit_ModuleNode(self, node):
522         for key, value in node.directive_comments.iteritems():
523             if not self.check_directive_scope(node.pos, key, 'module'):
524                 self.wrong_scope_error(node.pos, key, 'module')
525                 del node.directive_comments[key]
526
527         directives = copy.copy(Options.directive_defaults)
528         directives.update(self.compilation_directive_defaults)
529         directives.update(node.directive_comments)
530         self.directives = directives
531         node.directives = directives
532         self.visitchildren(node)
533         node.cython_module_names = self.cython_module_names
534         return node
535
536     # The following four functions track imports and cimports that
537     # begin with "cython"
538     def is_cython_directive(self, name):
539         return (name in Options.directive_types or
540                 name in self.special_methods or
541                 PyrexTypes.parse_basic_type(name))
542
543     def visit_CImportStatNode(self, node):
544         if node.module_name == u"cython":
545             self.cython_module_names.add(node.as_name or u"cython")
546         elif node.module_name.startswith(u"cython."):
547             if node.as_name:
548                 self.directive_names[node.as_name] = node.module_name[7:]
549             else:
550                 self.cython_module_names.add(u"cython")
551             # if this cimport was a compiler directive, we don't
552             # want to leave the cimport node sitting in the tree
553             return None
554         return node
555     
556     def visit_FromCImportStatNode(self, node):
557         if (node.module_name == u"cython") or \
558                node.module_name.startswith(u"cython."):
559             submodule = (node.module_name + u".")[7:]
560             newimp = []
561             for pos, name, as_name, kind in node.imported_names:
562                 full_name = submodule + name
563                 if self.is_cython_directive(full_name):
564                     if as_name is None:
565                         as_name = full_name
566                     self.directive_names[as_name] = full_name
567                     if kind is not None:
568                         self.context.nonfatal_error(PostParseError(pos,
569                             "Compiler directive imports must be plain imports"))
570                 else:
571                     newimp.append((pos, name, as_name, kind))
572             if not newimp:
573                 return None
574             node.imported_names = newimp
575         return node
576         
577     def visit_FromImportStatNode(self, node):
578         if (node.module.module_name.value == u"cython") or \
579                node.module.module_name.value.startswith(u"cython."):
580             submodule = (node.module.module_name.value + u".")[7:]
581             newimp = []
582             for name, name_node in node.items:
583                 full_name = submodule + name
584                 if self.is_cython_directive(full_name):
585                     self.directive_names[name_node.name] = full_name
586                 else:
587                     newimp.append((name, name_node))
588             if not newimp:
589                 return None
590             node.items = newimp
591         return node
592
593     def visit_SingleAssignmentNode(self, node):
594         if (isinstance(node.rhs, ImportNode) and
595                 node.rhs.module_name.value == u'cython'):
596             node = CImportStatNode(node.pos, 
597                                    module_name = u'cython',
598                                    as_name = node.lhs.name)
599             self.visit_CImportStatNode(node)
600         else:
601             self.visitchildren(node)
602         return node
603             
604     def visit_NameNode(self, node):
605         if node.name in self.cython_module_names:
606             node.is_cython_module = True
607         else:
608             node.cython_attribute = self.directive_names.get(node.name)
609         return node
610
611     def try_to_parse_directives(self, node):
612         # If node is the contents of an directive (in a with statement or
613         # decorator), returns a list of (directivename, value) pairs.
614         # Otherwise, returns None
615         if isinstance(node, CallNode):
616             self.visit(node.function)
617             optname = node.function.as_cython_attribute()
618             if optname:
619                 directivetype = Options.directive_types.get(optname)
620                 if directivetype:
621                     args, kwds = node.explicit_args_kwds()
622                     directives = []
623                     key_value_pairs = []
624                     if kwds is not None and directivetype is not dict:
625                         for keyvalue in kwds.key_value_pairs:
626                             key, value = keyvalue
627                             sub_optname = "%s.%s" % (optname, key.value)
628                             if Options.directive_types.get(sub_optname):
629                                 directives.append(self.try_to_parse_directive(sub_optname, [value], None, keyvalue.pos))
630                             else:
631                                 key_value_pairs.append(keyvalue)
632                         if not key_value_pairs:
633                             kwds = None
634                         else:
635                             kwds.key_value_pairs = key_value_pairs
636                         if directives and not kwds and not args:
637                             return directives
638                     directives.append(self.try_to_parse_directive(optname, args, kwds, node.function.pos))
639                     return directives
640                 
641         return None
642
643     def try_to_parse_directive(self, optname, args, kwds, pos):
644         directivetype = Options.directive_types.get(optname)
645         if len(args) == 1 and isinstance(args[0], NoneNode):
646             return optname, Options.directive_defaults[optname]
647         elif directivetype is bool:
648             if kwds is not None or len(args) != 1 or not isinstance(args[0], BoolNode):
649                 raise PostParseError(pos,
650                     'The %s directive takes one compile-time boolean argument' % optname)
651             return (optname, args[0].value)
652         elif directivetype is str:
653             if kwds is not None or len(args) != 1 or not isinstance(args[0], (StringNode, UnicodeNode)):
654                 raise PostParseError(pos,
655                     'The %s directive takes one compile-time string argument' % optname)
656             return (optname, str(args[0].value))
657         elif directivetype is dict:
658             if len(args) != 0:
659                 raise PostParseError(pos,
660                     'The %s directive takes no prepositional arguments' % optname)
661             return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs])
662         elif directivetype is list:
663             if kwds and len(kwds) != 0:
664                 raise PostParseError(pos,
665                     'The %s directive takes no keyword arguments' % optname)
666             return optname, [ str(arg.value) for arg in args ]
667         else:
668             assert False
669
670     def visit_with_directives(self, body, directives):
671         olddirectives = self.directives
672         newdirectives = copy.copy(olddirectives)
673         newdirectives.update(directives)
674         self.directives = newdirectives
675         assert isinstance(body, StatListNode), body
676         retbody = self.visit_Node(body)
677         directive = CompilerDirectivesNode(pos=retbody.pos, body=retbody,
678                                            directives=newdirectives)
679         self.directives = olddirectives
680         return directive
681  
682     # Handle decorators
683     def visit_FuncDefNode(self, node):
684         directives = []
685         if node.decorators:
686             # Split the decorators into two lists -- real decorators and directives
687             realdecs = []
688             for dec in node.decorators:
689                 new_directives = self.try_to_parse_directives(dec.decorator)
690                 if new_directives is not None:
691                     directives.extend(new_directives)
692                 else:
693                     realdecs.append(dec)
694             if realdecs and isinstance(node, CFuncDefNode):
695                 raise PostParseError(realdecs[0].pos, "Cdef functions cannot take arbitrary decorators.")
696             else:
697                 node.decorators = realdecs
698         
699         if directives:
700             optdict = {}
701             directives.reverse() # Decorators coming first take precedence
702             for directive in directives:
703                 name, value = directive
704                 legal_scopes = Options.directive_scopes.get(name, None)
705                 if not self.check_directive_scope(node.pos, name, 'function'):
706                     continue
707                 if name in optdict:
708                     old_value = optdict[name]
709                     # keywords and arg lists can be merged, everything
710                     # else overrides completely
711                     if isinstance(old_value, dict):
712                         old_value.update(value)
713                     elif isinstance(old_value, list):
714                         old_value.extend(value)
715                     else:
716                         optdict[name] = value
717                 else:
718                     optdict[name] = value
719             body = StatListNode(node.pos, stats=[node])
720             return self.visit_with_directives(body, optdict)
721         else:
722             return self.visit_Node(node)
723     
724     def visit_CVarDefNode(self, node):
725         if node.decorators:
726             for dec in node.decorators:
727                 for directive in self.try_to_parse_directives(dec.decorator) or []:
728                     if directive is not None and directive[0] == u'locals':
729                         node.directive_locals = directive[1]
730                     else:
731                         self.context.nonfatal_error(PostParseError(dec.pos,
732                             "Cdef functions can only take cython.locals() decorator."))
733         return node
734                                    
735     # Handle with statements
736     def visit_WithStatNode(self, node):
737         directive_dict = {}
738         for directive in self.try_to_parse_directives(node.manager) or []:
739             if directive is not None:
740                 if node.target is not None:
741                     self.context.nonfatal_error(
742                         PostParseError(node.pos, "Compiler directive with statements cannot contain 'as'"))
743                 else:
744                     name, value = directive
745                     if self.check_directive_scope(node.pos, name, 'with statement'):
746                         directive_dict[name] = value
747         if directive_dict:
748             return self.visit_with_directives(node.body, directive_dict)
749         return self.visit_Node(node)
750
751 class WithTransform(CythonTransform, SkipDeclarations):
752
753     # EXCINFO is manually set to a variable that contains
754     # the exc_info() tuple that can be generated by the enclosing except
755     # statement.
756     template_without_target = TreeFragment(u"""
757         MGR = EXPR
758         EXIT = MGR.__exit__
759         MGR.__enter__()
760         EXC = True
761         try:
762             try:
763                 EXCINFO = None
764                 BODY
765             except:
766                 EXC = False
767                 if not EXIT(*EXCINFO):
768                     raise
769         finally:
770             if EXC:
771                 EXIT(None, None, None)
772     """, temps=[u'MGR', u'EXC', u"EXIT"],
773     pipeline=[NormalizeTree(None)])
774
775     template_with_target = TreeFragment(u"""
776         MGR = EXPR
777         EXIT = MGR.__exit__
778         VALUE = MGR.__enter__()
779         EXC = True
780         try:
781             try:
782                 EXCINFO = None
783                 TARGET = VALUE
784                 BODY
785             except:
786                 EXC = False
787                 if not EXIT(*EXCINFO):
788                     raise
789         finally:
790             if EXC:
791                 EXIT(None, None, None)
792             MGR = EXIT = VALUE = EXC = None
793             
794     """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE"],
795     pipeline=[NormalizeTree(None)])
796
797     def visit_WithStatNode(self, node):
798         # TODO: Cleanup badly needed
799         TemplateTransform.temp_name_counter += 1
800         handle = "__tmpvar_%d" % TemplateTransform.temp_name_counter
801         
802         self.visitchildren(node, ['body'])
803         excinfo_temp = NameNode(node.pos, name=handle)#TempHandle(Builtin.tuple_type)
804         if node.target is not None:
805             result = self.template_with_target.substitute({
806                 u'EXPR' : node.manager,
807                 u'BODY' : node.body,
808                 u'TARGET' : node.target,
809                 u'EXCINFO' : excinfo_temp
810                 }, pos=node.pos)
811         else:
812             result = self.template_without_target.substitute({
813                 u'EXPR' : node.manager,
814                 u'BODY' : node.body,
815                 u'EXCINFO' : excinfo_temp
816                 }, pos=node.pos)
817
818         # Set except excinfo target to EXCINFO
819         try_except = result.stats[-1].body.stats[-1]
820         try_except.except_clauses[0].excinfo_target = NameNode(node.pos, name=handle)
821 #            excinfo_temp.ref(node.pos))
822
823 #        result.stats[-1].body.stats[-1] = TempsBlockNode(
824 #            node.pos, temps=[excinfo_temp], body=try_except)
825
826         return result
827         
828     def visit_ExprNode(self, node):
829         # With statements are never inside expressions.
830         return node
831         
832
833 class DecoratorTransform(CythonTransform, SkipDeclarations):
834
835     def visit_DefNode(self, func_node):
836         self.visitchildren(func_node)
837         if not func_node.decorators:
838             return func_node
839         return self._handle_decorators(
840             func_node, func_node.name)
841
842     def _visit_CClassDefNode(self, class_node):
843         # This doesn't currently work, so it's disabled (also in the
844         # parser).
845         #
846         # Problem: assignments to cdef class names do not work.  They
847         # would require an additional check anyway, as the extension
848         # type must not change its C type, so decorators cannot
849         # replace an extension type, just alter it and return it.
850
851         self.visitchildren(class_node)
852         if not class_node.decorators:
853             return class_node
854         return self._handle_decorators(
855             class_node, class_node.class_name)
856
857     def visit_ClassDefNode(self, class_node):
858         self.visitchildren(class_node)
859         if not class_node.decorators:
860             return class_node
861         return self._handle_decorators(
862             class_node, class_node.name)
863
864     def _handle_decorators(self, node, name):
865         decorator_result = NameNode(node.pos, name = name)
866         for decorator in node.decorators[::-1]:
867             decorator_result = SimpleCallNode(
868                 decorator.pos,
869                 function = decorator.decorator,
870                 args = [decorator_result])
871
872         name_node = NameNode(node.pos, name = name)
873         reassignment = SingleAssignmentNode(
874             node.pos,
875             lhs = name_node,
876             rhs = decorator_result)
877         return [node, reassignment]
878
879
880 class AnalyseDeclarationsTransform(CythonTransform):
881
882     basic_property = TreeFragment(u"""
883 property NAME:
884     def __get__(self):
885         return ATTR
886     def __set__(self, value):
887         ATTR = value
888     """, level='c_class')
889
890     def __call__(self, root):
891         self.env_stack = [root.scope]
892         # needed to determine if a cdef var is declared after it's used.
893         self.seen_vars_stack = []
894         return super(AnalyseDeclarationsTransform, self).__call__(root)        
895     
896     def visit_NameNode(self, node):
897         self.seen_vars_stack[-1].add(node.name)
898         return node
899
900     def visit_ModuleNode(self, node):
901         self.seen_vars_stack.append(set())
902         node.analyse_declarations(self.env_stack[-1])
903         self.visitchildren(node)
904         self.seen_vars_stack.pop()
905         return node
906
907     def visit_LambdaNode(self, node):
908         node.analyse_declarations(self.env_stack[-1])
909         self.visitchildren(node)
910         return node
911
912     def visit_ClassDefNode(self, node):
913         self.env_stack.append(node.scope)
914         self.visitchildren(node)
915         self.env_stack.pop()
916         return node
917         
918     def visit_FuncDefNode(self, node):
919         self.seen_vars_stack.append(set())
920         lenv = node.local_scope
921         node.body.analyse_control_flow(lenv) # this will be totally refactored
922         node.declare_arguments(lenv)
923         for var, type_node in node.directive_locals.items():
924             if not lenv.lookup_here(var):   # don't redeclare args
925                 type = type_node.analyse_as_type(lenv)
926                 if type:
927                     lenv.declare_var(var, type, type_node.pos)
928                 else:
929                     error(type_node.pos, "Not a type")
930         node.body.analyse_declarations(lenv)
931         self.env_stack.append(lenv)
932         self.visitchildren(node)
933         self.env_stack.pop()
934         self.seen_vars_stack.pop()
935         return node
936
937     def visit_ComprehensionNode(self, node):
938         self.visitchildren(node)
939         node.analyse_declarations(self.env_stack[-1])
940         return node
941
942     # Some nodes are no longer needed after declaration
943     # analysis and can be dropped. The analysis was performed
944     # on these nodes in a seperate recursive process from the
945     # enclosing function or module, so we can simply drop them.
946     def visit_CDeclaratorNode(self, node):
947         # necessary to ensure that all CNameDeclaratorNodes are visited.
948         self.visitchildren(node)
949         return node
950     
951     def visit_CTypeDefNode(self, node):
952         return node
953
954     def visit_CBaseTypeNode(self, node):
955         return None
956     
957     def visit_CEnumDefNode(self, node):
958         if node.visibility == 'public':
959             return node
960         else:
961             return None
962
963     def visit_CStructOrUnionDefNode(self, node):
964         return None
965
966     def visit_CNameDeclaratorNode(self, node):
967         if node.name in self.seen_vars_stack[-1]:
968             entry = self.env_stack[-1].lookup(node.name)
969             if entry is None or entry.visibility != 'extern':
970                 warning(node.pos, "cdef variable '%s' declared after it is used" % node.name, 2)
971         self.visitchildren(node)
972         return node
973
974     def visit_CVarDefNode(self, node):
975
976         # to ensure all CNameDeclaratorNodes are visited.
977         self.visitchildren(node)
978
979         if node.need_properties:
980             # cdef public attributes may need type testing on 
981             # assignment, so we create a property accesss
982             # mechanism for them. 
983             stats = []
984             for entry in node.need_properties:
985                 property = self.create_Property(entry)
986                 property.analyse_declarations(node.dest_scope)
987                 self.visit(property)
988                 stats.append(property)
989             return StatListNode(pos=node.pos, stats=stats)
990         else:
991             return None
992             
993     def create_Property(self, entry):
994         template = self.basic_property
995         property = template.substitute({
996                 u"ATTR": AttributeNode(pos=entry.pos,
997                                        obj=NameNode(pos=entry.pos, name="self"), 
998                                        attribute=entry.name),
999             }, pos=entry.pos).stats[0]
1000         property.name = entry.name
1001         return property
1002
1003 class AnalyseExpressionsTransform(CythonTransform):
1004
1005     def visit_ModuleNode(self, node):
1006         node.scope.infer_types()
1007         node.body.analyse_expressions(node.scope)
1008         self.visitchildren(node)
1009         return node
1010         
1011     def visit_FuncDefNode(self, node):
1012         node.local_scope.infer_types()
1013         node.body.analyse_expressions(node.local_scope)
1014         self.visitchildren(node)
1015         return node
1016         
1017 class AlignFunctionDefinitions(CythonTransform):
1018     """
1019     This class takes the signatures from a .pxd file and applies them to 
1020     the def methods in a .py file. 
1021     """
1022     
1023     def visit_ModuleNode(self, node):
1024         self.scope = node.scope
1025         self.directives = node.directives
1026         self.visitchildren(node)
1027         return node
1028     
1029     def visit_PyClassDefNode(self, node):
1030         pxd_def = self.scope.lookup(node.name)
1031         if pxd_def:
1032             if pxd_def.is_cclass:
1033                 return self.visit_CClassDefNode(node.as_cclass(), pxd_def)
1034             else:
1035                 error(node.pos, "'%s' redeclared" % node.name)
1036                 error(pxd_def.pos, "previous declaration here")
1037                 return None
1038         else:
1039             return node
1040         
1041     def visit_CClassDefNode(self, node, pxd_def=None):
1042         if pxd_def is None:
1043             pxd_def = self.scope.lookup(node.class_name)
1044         if pxd_def:
1045             outer_scope = self.scope
1046             self.scope = pxd_def.type.scope
1047         self.visitchildren(node)
1048         if pxd_def:
1049             self.scope = outer_scope
1050         return node
1051         
1052     def visit_DefNode(self, node):
1053         pxd_def = self.scope.lookup(node.name)
1054         if pxd_def:
1055             if self.scope.is_c_class_scope and len(pxd_def.type.args) > 0:
1056                 # The self parameter type needs adjusting.
1057                 pxd_def.type.args[0].type = self.scope.parent_type
1058             if pxd_def.is_cfunction:
1059                 node = node.as_cfunction(pxd_def)
1060             else:
1061                 error(node.pos, "'%s' redeclared" % node.name)
1062                 error(pxd_def.pos, "previous declaration here")
1063                 return None
1064         elif self.scope.is_module_scope and self.directives['auto_cpdef']:
1065             node = node.as_cfunction(scope=self.scope)
1066         # Enable this when internal def functions are allowed. 
1067         # self.visitchildren(node)
1068         return node
1069         
1070
1071 class MarkClosureVisitor(CythonTransform):
1072     
1073     needs_closure = False
1074     
1075     def visit_FuncDefNode(self, node):
1076         self.needs_closure = False
1077         self.visitchildren(node)
1078         node.needs_closure = self.needs_closure
1079         self.needs_closure = True
1080         return node
1081
1082     def visit_LambdaNode(self, node):
1083         self.needs_closure = False
1084         self.visitchildren(node)
1085         node.needs_closure = self.needs_closure
1086         self.needs_closure = True
1087         return node
1088
1089     def visit_ClassDefNode(self, node):
1090         self.visitchildren(node)
1091         self.needs_closure = True
1092         return node
1093         
1094     def visit_YieldNode(self, node):
1095         self.needs_closure = True
1096         
1097 class CreateClosureClasses(CythonTransform):
1098     # Output closure classes in module scope for all functions
1099     # that need it. 
1100     
1101     def visit_ModuleNode(self, node):
1102         self.module_scope = node.scope
1103         self.visitchildren(node)
1104         return node
1105
1106     def create_class_from_scope(self, node, target_module_scope):
1107         as_name = "%s%s" % (Naming.closure_class_prefix, node.entry.cname)
1108         func_scope = node.local_scope
1109
1110         entry = target_module_scope.declare_c_class(name = as_name,
1111             pos = node.pos, defining = True, implementing = True)
1112         func_scope.scope_class = entry
1113         class_scope = entry.type.scope
1114         class_scope.is_internal = True
1115         if node.entry.scope.is_closure_scope:
1116             class_scope.declare_var(pos=node.pos,
1117                                     name=Naming.outer_scope_cname, # this could conflict?
1118                                     cname=Naming.outer_scope_cname,
1119                                     type=node.entry.scope.scope_class.type,
1120                                     is_cdef=True)
1121         for entry in func_scope.entries.values():
1122             # This is wasteful--we should do this later when we know
1123             # which vars are actually being used inside...
1124             cname = entry.cname
1125             class_scope.declare_var(pos=entry.pos,
1126                                     name=entry.name,
1127                                     cname=cname,
1128                                     type=entry.type,
1129                                     is_cdef=True)
1130             
1131     def visit_FuncDefNode(self, node):
1132         if node.needs_closure:
1133             self.create_class_from_scope(node, self.module_scope)
1134             self.visitchildren(node)
1135         return node
1136
1137
1138 class GilCheck(VisitorTransform):
1139     """
1140     Call `node.gil_check(env)` on each node to make sure we hold the
1141     GIL when we need it.  Raise an error when on Python operations
1142     inside a `nogil` environment.
1143     """
1144     def __call__(self, root):
1145         self.env_stack = [root.scope]
1146         self.nogil = False
1147         return super(GilCheck, self).__call__(root)
1148
1149     def visit_FuncDefNode(self, node):
1150         self.env_stack.append(node.local_scope)
1151         was_nogil = self.nogil
1152         self.nogil = node.local_scope.nogil
1153         if self.nogil and node.nogil_check:
1154             node.nogil_check(node.local_scope)
1155         self.visitchildren(node)
1156         self.env_stack.pop()
1157         self.nogil = was_nogil
1158         return node
1159
1160     def visit_GILStatNode(self, node):
1161         env = self.env_stack[-1]
1162         if self.nogil and node.nogil_check: node.nogil_check()
1163         was_nogil = self.nogil
1164         self.nogil = (node.state == 'nogil')
1165         self.visitchildren(node)
1166         self.nogil = was_nogil
1167         return node
1168
1169     def visit_Node(self, node):
1170         if self.env_stack and self.nogil and node.nogil_check:
1171             node.nogil_check(self.env_stack[-1])
1172         self.visitchildren(node)
1173         return node
1174
1175
1176 class TransformBuiltinMethods(EnvTransform):
1177
1178     def visit_SingleAssignmentNode(self, node):
1179         if node.declaration_only:
1180             return None
1181         else:
1182             self.visitchildren(node)
1183             return node
1184     
1185     def visit_AttributeNode(self, node):
1186         self.visitchildren(node)
1187         return self.visit_cython_attribute(node)
1188
1189     def visit_NameNode(self, node):
1190         return self.visit_cython_attribute(node)
1191         
1192     def visit_cython_attribute(self, node):
1193         attribute = node.as_cython_attribute()
1194         if attribute:
1195             if attribute == u'compiled':
1196                 node = BoolNode(node.pos, value=True)
1197             elif attribute == u'NULL':
1198                 node = NullNode(node.pos)
1199             elif not PyrexTypes.parse_basic_type(attribute):
1200                 error(node.pos, u"'%s' not a valid cython attribute or is being used incorrectly" % attribute)
1201         return node
1202
1203     def visit_SimpleCallNode(self, node):
1204
1205         # locals builtin
1206         if isinstance(node.function, ExprNodes.NameNode):
1207             if node.function.name == 'locals':
1208                 lenv = self.env_stack[-1]
1209                 entry = lenv.lookup_here('locals')
1210                 if entry:
1211                     # not the builtin 'locals'
1212                     return node
1213                 if len(node.args) > 0:
1214                     error(self.pos, "Builtin 'locals()' called with wrong number of args, expected 0, got %d" % len(node.args))
1215                     return node
1216                 pos = node.pos
1217                 items = [ExprNodes.DictItemNode(pos, 
1218                                                 key=ExprNodes.StringNode(pos, value=var),
1219                                                 value=ExprNodes.NameNode(pos, name=var)) for var in lenv.entries]
1220                 return ExprNodes.DictNode(pos, key_value_pairs=items)
1221
1222         # cython.foo
1223         function = node.function.as_cython_attribute()
1224         if function:
1225             if function in InterpretCompilerDirectives.unop_method_nodes:
1226                 if len(node.args) != 1:
1227                     error(node.function.pos, u"%s() takes exactly one argument" % function)
1228                 else:
1229                     node = InterpretCompilerDirectives.unop_method_nodes[function](node.function.pos, operand=node.args[0])
1230             elif function == u'cast':
1231                 if len(node.args) != 2:
1232                     error(node.function.pos, u"cast() takes exactly two arguments")
1233                 else:
1234                     type = node.args[0].analyse_as_type(self.env_stack[-1])
1235                     if type:
1236                         node = TypecastNode(node.function.pos, type=type, operand=node.args[1])
1237                     else:
1238                         error(node.args[0].pos, "Not a type")
1239             elif function == u'sizeof':
1240                 if len(node.args) != 1:
1241                     error(node.function.pos, u"sizeof() takes exactly one argument")
1242                 else:
1243                     type = node.args[0].analyse_as_type(self.env_stack[-1])
1244                     if type:
1245                         node = SizeofTypeNode(node.function.pos, arg_type=type)
1246                     else:
1247                         node = SizeofVarNode(node.function.pos, operand=node.args[0])
1248             elif function == 'cmod':
1249                 if len(node.args) != 2:
1250                     error(node.function.pos, u"cmod() takes exactly two arguments")
1251                 else:
1252                     node = binop_node(node.function.pos, '%', node.args[0], node.args[1])
1253                     node.cdivision = True
1254             elif function == 'cdiv':
1255                 if len(node.args) != 2:
1256                     error(node.function.pos, u"cdiv() takes exactly two arguments")
1257                 else:
1258                     node = binop_node(node.function.pos, '/', node.args[0], node.args[1])
1259                     node.cdivision = True
1260             else:
1261                 error(node.function.pos, u"'%s' not a valid cython language construct" % function)
1262         
1263         self.visitchildren(node)
1264         return node