cython.locals(...) decorator for pure python type declarations.
authorRobert Bradshaw <robertwb@math.washington.edu>
Wed, 1 Oct 2008 11:43:21 +0000 (04:43 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Wed, 1 Oct 2008 11:43:21 +0000 (04:43 -0700)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Nodes.py
Cython/Compiler/Options.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/PyrexTypes.py
Cython/Compiler/Scanning.py
Cython/Compiler/TreeFragment.py
Cython/Plex/Scanners.py
Cython/Shadow.py [new file with mode: 0644]
Cython/__init__.py
cython.py

index d7ea1d115fba348b8f041a4e0ec0da4dba6ac418..a5685b1aa9782cde67da055ff3bcf1ac5bc5c5f1 100644 (file)
@@ -713,6 +713,15 @@ class StringNode(ConstNode):
     
     def analyse_types(self, env):
         self.entry = env.add_string_const(self.value)
+        
+    def analyse_as_type(self, env):
+        from TreeFragment import TreeFragment
+        pos = (self.pos[0], self.pos[1], self.pos[2]-7)
+        declaration = TreeFragment(u"sizeof(%s)" % self.value, name=pos[0].filename, initial_pos=pos)
+        sizeof_node = declaration.root.stats[0].expr
+        sizeof_node.analyse_types(env)
+        if isinstance(sizeof_node, SizeofTypeNode):
+            return sizeof_node.arg_type
     
     def coerce_to(self, dst_type, env):
         if dst_type.is_int:
@@ -886,6 +895,8 @@ class NameNode(AtomicExprNode):
         return None
         
     def analyse_as_type(self, env):
+        if self.name in PyrexTypes.rank_to_type_name:
+            return PyrexTypes.simple_c_type(1, 0, self.name)
         entry = self.entry
         if not entry:
             entry = env.lookup(self.name)
@@ -2767,6 +2778,9 @@ class DictItemNode(ExprNode):
     def generate_disposal_code(self, code):
         self.key.generate_disposal_code(code)
         self.value.generate_disposal_code(code)
+        
+    def __iter__(self):
+        return iter([self.key, self.value])
 
 
 class ClassNode(ExprNode):
index 448aa677cf81e71e0ccdb9904f1ba40fb39ce1ae..b815ec11f1a6ca3c578e6cea42dca79c8596939c 100644 (file)
@@ -1166,11 +1166,16 @@ class CFuncDefNode(FuncDefNode):
     #  overridable   whether or not this is a cpdef function
     
     child_attrs = ["base_type", "declarator", "body", "py_func"]
-
+    
     def unqualified_name(self):
         return self.entry.name
         
     def analyse_declarations(self, env):
+        if 'locals' in env.directives:
+            directive_locals = env.directives['locals']
+        else:
+            directive_locals = {}
+        self.directive_locals = directive_locals
         base_type = self.base_type.analyse(env)
         # The 2 here is because we need both function and argument names. 
         name_declarator, type = self.declarator.analyse(base_type, env, nonempty = 2 * (self.body is not None))
@@ -1442,11 +1447,27 @@ class DefNode(FuncDefNode):
     entry = None
     
     def analyse_declarations(self, env):
+        if 'locals' in env.directives:
+            directive_locals = env.directives['locals']
+        else:
+            directive_locals = {}
+        self.directive_locals = directive_locals
         for arg in self.args:
             base_type = arg.base_type.analyse(env)
             name_declarator, type = \
                 arg.declarator.analyse(base_type, env)
             arg.name = name_declarator.name
+            if arg.name in directive_locals:
+                type_node = directive_locals[arg.name]
+                other_type = type_node.analyse_as_type(env)
+                if other_type is None:
+                    error(type_node.pos, "Not a type")
+                elif (type is not PyrexTypes.py_object_type 
+                        and not type.same_as(other_type)):
+                    error(arg.base_type.pos, "Signature does not agree with previous declaration")
+                    error(type_node.pos, "Previous declaration here")
+                else:
+                    type = other_type
             if name_declarator.cname:
                 error(self.pos,
                     "Python function argument cannot have C name specification")
index dd26a14a60ecac9cbd3cc8ad4af9c3873aebaef3..fabd6e6dd8deb8c40a8b237736fa36f8b00e66af 100644 (file)
@@ -58,13 +58,15 @@ c_line_in_traceback = 1
 option_types = {
     'boundscheck' : bool,
     'nonecheck' : bool,
-    'embedsignature' : bool
+    'embedsignature' : bool,
+    'locals' : dict,
 }
 
 option_defaults = {
     'boundscheck' : True,
     'nonecheck' : False,
     'embedsignature' : False,
+    'locals' : {}
 }
 
 def parse_option_value(name, value):
index 825265d54ceed34db57ebaca9915a77c383923ca..ee3d77254a2ea5a8bf586c7989bd6bcb9ca06e2e 100644 (file)
@@ -308,6 +308,14 @@ class InterpretCompilerDirectives(CythonTransform):
                     newimp.append((pos, name, as_name, kind))
             node.imported_names = newimpo
         return node
+        
+    def visit_SingleAssignmentNode(self, node):
+        if (isinstance(node.rhs, ImportNode) and
+                node.rhs.module_name.value == u'cython'):
+            self.cython_module_names.add(node.lhs.name)
+        else:
+            self.visitchildren(node)
+            return node
 
     def visit_Node(self, node):
         self.visitchildren(node)
@@ -318,7 +326,7 @@ class InterpretCompilerDirectives(CythonTransform):
         # decorator), returns (optionname, value).
         # Otherwise, returns None
         optname = None
-        if isinstance(node, SimpleCallNode):
+        if isinstance(node, CallNode):
             if (isinstance(node.function, AttributeNode) and
                   isinstance(node.function.obj, NameNode) and
                   node.function.obj.name in self.cython_module_names):
@@ -330,12 +338,25 @@ class InterpretCompilerDirectives(CythonTransform):
         if optname:
             optiontype = Options.option_types.get(optname)
             if optiontype:
-                args = node.args
+                if isinstance(node, SimpleCallNode):
+                    args = node.args
+                    kwds = None
+                else:
+                    if node.starstar_arg or not isinstance(node.positional_args, TupleNode):
+                        raise PostParseError(dec.function.pos,
+                            'Compile-time keyword arguments must be explicit.' % optname)
+                    args = node.positional_args.args
+                    kwds = node.keyword_args
                 if optiontype is bool:
-                    if len(args) != 1 or not isinstance(args[0], BoolNode):
+                    if kwds is not None or len(args) != 1 or not isinstance(args[0], BoolNode):
                         raise PostParseError(dec.function.pos,
                             'The %s option takes one compile-time boolean argument' % optname)
                     return (optname, args[0].value)
+                elif optiontype is dict:
+                    if len(args) != 0:
+                        raise PostParseError(dec.function.pos,
+                            'The %s option takes no prepositional arguments' % optname)
+                    return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs])
                 else:
                     assert False
 
@@ -367,7 +388,7 @@ class InterpretCompilerDirectives(CythonTransform):
                 else:
                     realdecs.append(dec)
             node.decorators = realdecs
-
+        
         if options:
             optdict = {}
             options.reverse() # Decorators coming first take precedence
@@ -499,12 +520,19 @@ property NAME:
         lenv = node.create_local_scope(self.env_stack[-1])
         node.body.analyse_control_flow(lenv) # this will be totally refactored
         node.declare_arguments(lenv)
+        for var, type_node in node.directive_locals.items():
+            if not lenv.lookup_here(var):   # don't redeclare args
+                type = type_node.analyse_as_type(lenv)
+                if type:
+                    lenv.declare_var(var, type, type_node.pos)
+                else:
+                    error(type_node.pos, "Not a type")
         node.body.analyse_declarations(lenv)
         self.env_stack.append(lenv)
         self.visitchildren(node)
         self.env_stack.pop()
         return node
-        
+    
     # Some nodes are no longer needed after declaration
     # analysis and can be dropped. The analysis was performed
     # on these nodes in a seperate recursive process from the
index ef8d13cd65ad9a184e70d40168c8ac0ce2eb526b..8175b7441211378e35de9f4bed109231cfe8a3e4 100644 (file)
@@ -1168,6 +1168,7 @@ modifiers_and_name_to_type = {
     (1, 0, "int"): c_int_type, 
     (1, 1, "int"): c_long_type,
     (1, 2, "int"): c_longlong_type,
+    (1, 0, "long"): c_long_type,
     (1, 0, "Py_ssize_t"): c_py_ssize_t_type,
     (1, 0, "float"): c_float_type, 
     (1, 0, "double"): c_double_type,
@@ -1216,6 +1217,19 @@ def c_ptr_type(base_type):
         return c_char_ptr_type
     else:
         return CPtrType(base_type)
+        
+def Node_to_type(node, env):
+    from ExprNodes import NameNode, AttributeNode, StringNode, error
+    if isinstance(node, StringNode):
+        node = NameNode(node.pos, name=node.value)
+    if isinstance(node, NameNode) and node.name in rank_to_type_name:
+        return simple_c_type(1, 0, node.name)
+    elif isinstance(node, (AttributeNode, NameNode)):
+        node.analyze_types(env)
+        if not node.entry.is_type:
+            pass
+    else:
+        error(node.pos, "Bad type")
 
 def public_decl(base, dll_linkage):
     if dll_linkage:
index 810ba07e04b8ea6057e6bf1f022cb8c125f42e98..004a8921010ba4753eb0883659167bf53c5ddf79 100644 (file)
@@ -289,8 +289,8 @@ class PyrexScanner(Scanner):
     resword_dict = build_resword_dict()
 
     def __init__(self, file, filename, parent_scanner = None, 
-                 scope = None, context = None, source_encoding=None, parse_comments=True):
-        Scanner.__init__(self, get_lexicon(), file, filename)
+                 scope = None, context = None, source_encoding=None, parse_comments=True, initial_pos=None):
+        Scanner.__init__(self, get_lexicon(), file, filename, initial_pos)
         if parent_scanner:
             self.context = parent_scanner.context
             self.included_files = parent_scanner.included_files
index d1b1d574e8220f4dab70dffec3f691b3bb19d991..4e7d575cd88d99d2a927da1b9bcd8b9ce5622335 100644 (file)
@@ -29,7 +29,7 @@ class StringParseContext(Main.Context):
             raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
         return ModuleScope(module_name, parent_module = None, context = self)
         
-def parse_from_strings(name, code, pxds={}, level=None):
+def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None):
     """
     Utility method to parse a (unicode) string of code. This is mostly
     used for internal Cython compiler purposes (creating code snippets
@@ -47,7 +47,8 @@ def parse_from_strings(name, code, pxds={}, level=None):
     encoding = "UTF-8"
 
     module_name = name
-    initial_pos = (name, 1, 0)
+    if initial_pos is None:
+        initial_pos = (name, 1, 0)
     code_source = StringSourceDescriptor(name, code)
 
     context = StringParseContext([], name)
@@ -56,7 +57,7 @@ def parse_from_strings(name, code, pxds={}, level=None):
     buf = StringIO(code.encode(encoding))
 
     scanner = PyrexScanner(buf, code_source, source_encoding = encoding,
-                     scope = scope, context = context)
+                     scope = scope, context = context, initial_pos = initial_pos)
     if level is None:
         tree = Parsing.p_module(scanner, 0, module_name)
     else:
@@ -181,7 +182,7 @@ def strip_common_indent(lines):
     return lines
     
 class TreeFragment(object):
-    def __init__(self, code, name="(tree fragment)", pxds={}, temps=[], pipeline=[], level=None):
+    def __init__(self, code, name="(tree fragment)", pxds={}, temps=[], pipeline=[], level=None, initial_pos=None):
         if isinstance(code, unicode):
             def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n"))) 
             
@@ -189,8 +190,7 @@ class TreeFragment(object):
             fmt_pxds = {}
             for key, value in pxds.iteritems():
                 fmt_pxds[key] = fmt(value)
-                
-            mod = t = parse_from_strings(name, fmt_code, fmt_pxds, level=level)
+            mod = t = parse_from_strings(name, fmt_code, fmt_pxds, level=level, initial_pos=initial_pos)
             if level is None:
                 t = t.body # Make sure a StatListNode is at the top
             if not isinstance(t, StatListNode):
index acf4b83bc13973d698c0176cf91f8af3a630060e..3a7e8ddc0f06df637ecdb72db2d567d54e4a4e4b 100644 (file)
@@ -60,7 +60,7 @@ class Scanner:
   queue = None          # list of tokens to be returned
   trace = 0
 
-  def __init__(self, lexicon, stream, name = ''):
+  def __init__(self, lexicon, stream, name = '', initial_pos = None):
     """
     Scanner(lexicon, stream, name = '')
 
@@ -84,6 +84,8 @@ class Scanner:
     self.cur_line_start = 0
     self.cur_char = BOL
     self.input_state = 1
+    if initial_pos is not None:
+        self.cur_line, self.cur_line_start = initial_pos[1], -initial_pos[2]
 
   def read(self):
     """
diff --git a/Cython/Shadow.py b/Cython/Shadow.py
new file mode 100644 (file)
index 0000000..976e3f8
--- /dev/null
@@ -0,0 +1,16 @@
+def empty_decorator(x):
+    return x
+
+def locals(**arg_types):
+    return empty_decorator
+
+def cast(type, arg):
+    # can/should we emulate anything here?
+    return arg
+
+py_int = int
+py_long = long
+py_float = float
+
+# They just have to exist...
+int = long = char = bint = uint = ulong = longlong = ulonglong = Py_ssize_t = float = double = None
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..5ef285c09faad7113958951caa39efc6d5d697b3 100644 (file)
@@ -0,0 +1,2 @@
+# Void cython.* directives (for case insensitive operating systems). 
+from Shadow import *
index f96c577906bdc6557f3c99d54a15a88085add1e9..c84a8003c83f7184dc761f1a30f3e871a4b58f42 100644 (file)
--- a/cython.py
+++ b/cython.py
@@ -2,5 +2,11 @@
 #   Cython -- Main Program, generic
 #
 
-from Cython.Compiler.Main import main
-main(command_line = 1)
+if __name__ == '__main__':
+
+    from Cython.Compiler.Main import main
+    main(command_line = 1)
+
+else:
+    # Void cython.* directives. 
+    from Cython.Shadow import *