List and tuple types.
authorRobert Bradshaw <robertwb@math.washington.edu>
Sun, 27 Apr 2008 08:29:00 +0000 (01:29 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Sun, 27 Apr 2008 08:29:00 +0000 (01:29 -0700)
Cython/Compiler/Builtin.py
Cython/Compiler/ExprNodes.py
Cython/Compiler/Nodes.py
Cython/Compiler/PyrexTypes.py
Cython/Compiler/Symtab.py

index f2750cd1311dad4e4c9916dee6de2f9ad0cc2cbd..68f6815b9e0f49e18358cd5785d9b01b305d5be4 100644 (file)
@@ -78,6 +78,28 @@ builtin_function_table = [
 #  type
 #  xrange
 
+builtin_types_table = [
+
+    ("type",    "PyType_Type",     []),
+#    ("str",     "PyString_Type",   []),
+    ("unicode", "PyUnicode_Type",  []),
+    ("file",    "PyFile_Type",     []),
+#    ("slice",   "PySlice_Type",    []),
+#    ("set",     "PySet_Type",      []),
+    ("frozenset", "PyFrozenSet_Type",   []),
+
+    ("tuple",   "PyTuple_Type",    []),
+    
+    ("list",    "PyList_Type",     [("append", "OO",   "i", "PyList_Append"),
+                                    ("insert", "OiO",  "i", "PyList_Insert"),
+                                    ("sort",   "O",    "i", "PyList_Sort"),
+                                    ("reverse","O",    "i", "PyList_Reverse")]),
+                                    
+    ("dict",    "PyDict_Type",     [("items", "O",   "O", "PyDict_Items"),
+                                    ("keys",  "O",   "O", "PyDict_Keys"),
+                                    ("values","O",   "O", "PyDict_Values")]),
+]
+
 getattr3_utility_code = ["""
 static PyObject *__Pyx_GetAttr3(PyObject *, PyObject *, PyObject *); /*proto*/
 ""","""
@@ -112,7 +134,19 @@ def init_builtin_funcs():
     for desc in builtin_function_table:
         declare_builtin_func(*desc)
 
+def init_builtin_types():
+    for name, cname, funcs in builtin_types_table:
+        the_type = builtin_scope.declare_builtin_type(name, cname)
+        for name, args, ret, cname in funcs:
+            sig = Signature(args, ret)
+            the_type.scope.declare_cfunction(name, sig.function_type(), None, cname)
+
 def init_builtins():
     init_builtin_funcs()
-
+    init_builtin_types()
+    global list_type, tuple_type, dict_type
+    list_type  = builtin_scope.lookup('list').type
+    tuple_type = builtin_scope.lookup('tuple').type
+    dict_type  = builtin_scope.lookup('dict').type
+    
 init_builtins()
index 5745f700f72f3de83b32eed2fce0775f7b8dced9..63110a1744269862d1010a37b5cbe3bf394518d9 100644 (file)
@@ -10,6 +10,7 @@ import Naming
 from Nodes import Node
 import PyrexTypes
 from PyrexTypes import py_object_type, c_long_type, typecast, error_type
+from Builtin import list_type, tuple_type, dict_type
 import Symtab
 import Options
 from Annotate import AnnotationItem
@@ -2052,9 +2053,12 @@ class AttributeNode(ExprNode):
         obj_code = obj.result_as(obj.type)
         #print "...obj_code =", obj_code ###
         if self.entry and self.entry.is_cmethod:
-            return "((struct %s *)%s%s%s)->%s" % (
-                obj.type.vtabstruct_cname, obj_code, self.op, 
-                obj.type.vtabslot_cname, self.member)
+            if obj.type.is_extension_type:
+                return "((struct %s *)%s%s%s)->%s" % (
+                    obj.type.vtabstruct_cname, obj_code, self.op, 
+                    obj.type.vtabslot_cname, self.member)
+            else:
+                return self.member
         else:
             return "%s%s%s" % (obj_code, self.op, self.member)
     
@@ -2261,11 +2265,11 @@ class TupleNode(SequenceNode):
     
     def analyse_types(self, env):
         if len(self.args) == 0:
-            self.type = py_object_type
             self.is_temp = 0
             self.is_literal = 1
         else:
             SequenceNode.analyse_types(self, env)
+        self.type = tuple_type
             
     def calculate_result_code(self):
         if len(self.args) > 0:
@@ -2310,6 +2314,10 @@ class TupleNode(SequenceNode):
 class ListNode(SequenceNode):
     #  List constructor.
     
+    def analyse_types(self, env):
+        SequenceNode.analyse_types(self, env)
+        self.type = list_type
+
     def compile_time_value(self, denv):
         return self.compile_time_value_list(denv)
 
@@ -2342,7 +2350,7 @@ class ListComprehensionNode(SequenceNode):
     is_sequence_constructor = 0 # not unpackable
 
     def analyse_types(self, env): 
-        self.type = py_object_type
+        self.type = list_type
         self.is_temp = 1
         self.append.target = self # this is a CloneNode used in the PyList_Append in the inner loop
         
@@ -2401,7 +2409,7 @@ class DictNode(ExprNode):
     def analyse_types(self, env):
         for item in self.key_value_pairs:
             item.analyse_types(env)
-        self.type = py_object_type
+        self.type = dict_type
         self.is_temp = 1
     
     def allocate_temps(self, env, result = None):
@@ -3719,7 +3727,7 @@ class PyTypeTestNode(CoercionNode):
     def __init__(self, arg, dst_type, env):
         #  The arg is know to be a Python object, and
         #  the dst_type is known to be an extension type.
-        assert dst_type.is_extension_type, "PyTypeTest on non extension type"
+        assert dst_type.is_extension_type or dst_type.is_builtin_type, "PyTypeTest on non extension type"
         CoercionNode.__init__(self, arg)
         self.type = dst_type
         self.result_ctype = arg.ctype()
@@ -3740,9 +3748,8 @@ class PyTypeTestNode(CoercionNode):
     def generate_result_code(self, code):
         if self.type.typeobj_is_available():
             code.putln(
-                "if (!__Pyx_TypeTest(%s, %s)) %s" % (
-                    self.arg.py_result(), 
-                    self.type.typeptr_cname,
+                "if (!(%s)) %s" % (
+                    self.type.type_test_code(self.arg.py_result()),
                     code.error_goto(self.pos)))
         else:
             error(self.pos, "Cannot test type of extern C class "
index 6d25751d14b2daf5833b93038786fa13cb1ff438..2b6927f082a5eaf8fd98fd12d31206885e03d399 100644 (file)
@@ -1176,11 +1176,12 @@ class CFuncDefNode(FuncDefNode):
             typeptr_cname = arg.type.typeptr_cname
             arg_code = "((PyObject *)%s)" % arg.cname
             code.putln(
-                'if (unlikely(!__Pyx_ArgTypeTest(%s, %s, %d, "%s"))) %s' % (
+                'if (unlikely(!__Pyx_ArgTypeTest(%s, %s, %d, "%s", %s))) %s' % (
                     arg_code, 
                     typeptr_cname,
                     not arg.not_none,
                     arg.name,
+                    type.is_builtin_type,
                     code.error_goto(arg.pos)))
         else:
             error(arg.pos, "Cannot test type of extern C class "
@@ -1337,7 +1338,8 @@ class DefNode(FuncDefNode):
             if not sig.has_generic_args:
                 self.bad_signature()
             for arg in self.args:
-                if arg.is_generic and arg.type.is_extension_type:
+                if arg.is_generic and \
+                        (arg.type.is_extension_type or arg.type.is_builtin_type):
                     arg.needs_type_test = 1
                     any_type_tests_needed = 1
                 elif arg.type is PyrexTypes.c_py_ssize_t_type:
@@ -1843,11 +1845,12 @@ class DefNode(FuncDefNode):
             typeptr_cname = arg.type.typeptr_cname
             arg_code = "((PyObject *)%s)" % arg.entry.cname
             code.putln(
-                'if (unlikely(!__Pyx_ArgTypeTest(%s, %s, %d, "%s"))) %s' % (
+                'if (unlikely(!__Pyx_ArgTypeTest(%s, %s, %d, "%s", %s))) %s' % (
                     arg_code, 
                     typeptr_cname,
                     not arg.not_none,
                     arg.name,
+                    arg.type.is_builtin_type,
                     code.error_goto(arg.pos)))
         else:
             error(arg.pos, "Cannot test type of extern C class "
@@ -2990,7 +2993,7 @@ class ForInStatNode(LoopNode, StatNode):
             if isinstance(sequence, ExprNodes.SimpleCallNode) \
                   and sequence.self is None \
                   and isinstance(sequence.function, ExprNodes.NameNode) \
-                  and sequence.function.name == 'range':
+                  and (sequence.function.name == 'range' or sequence.function.name == 'xrange'):
                 args = sequence.args
                 # Make sure we can determine direction from step
                 if self.analyse_range_step(args):
@@ -3890,15 +3893,20 @@ static void __Pyx_ReRaise(void) {
 
 arg_type_test_utility_code = [
 """
-static int __Pyx_ArgTypeTest(PyObject *obj, PyTypeObject *type, int none_allowed, char *name); /*proto*/
+static int __Pyx_ArgTypeTest(PyObject *obj, PyTypeObject *type, int none_allowed, char *name, int exact); /*proto*/
 ""","""
-static int __Pyx_ArgTypeTest(PyObject *obj, PyTypeObject *type, int none_allowed, char *name) {
+static int __Pyx_ArgTypeTest(PyObject *obj, PyTypeObject *type, int none_allowed, char *name, int exact) {
     if (!type) {
         PyErr_Format(PyExc_SystemError, "Missing type object");
         return 0;
     }
-    if ((none_allowed && obj == Py_None) || PyObject_TypeCheck(obj, type))
-        return 1;
+    if (none_allowed && obj == Py_None) return 1;
+    else if (exact) {
+        if (PyObject_TypeCheck(obj, type)) return 1;
+    }
+    else {
+        if (obj->ob_type == type) return 1;
+    }
     PyErr_Format(PyExc_TypeError,
         "Argument '%s' has incorrect type (expected %s, got %s)",
         name, type->tp_name, obj->ob_type->tp_name);
index d7427d4cbfc603b2be780f809ee46993cef2c5c8..9019f4a9fe764f2a7d2bf4656cf7073706628759 100644 (file)
@@ -72,6 +72,7 @@ class PyrexType(BaseType):
         
     is_pyobject = 0
     is_extension_type = 0
+    is_builtin_type = 0
     is_numeric = 0
     is_int = 0
     is_float = 0
@@ -210,6 +211,48 @@ class PyObjectType(PyrexType):
             return "%s *%s" % (public_decl("PyObject", dll_linkage), entity_code)
 
 
+class BuiltinObjectType(PyObjectType):
+
+    is_builtin_type = 1
+    has_attributes = 1
+    base_type = None
+    module_name = '__builtin__'
+
+    def __init__(self, name, cname):
+        self.name = name
+        self.cname = cname
+        self.typeptr_cname = "&" + cname
+                                 
+    def set_scope(self, scope):
+        self.scope = scope
+        if scope:
+            scope.parent_type = self
+        
+    def __str__(self):
+        return "%s object" % self.name
+    
+    def __repr__(self):
+        return "<%s>"% self.cname
+        
+    def assignable_from(self, src_type):
+        if isinstance(src_type, BuiltinObjectType):
+            return src_type.name == self.name
+        else:
+            return not src_type.is_extension_type
+            
+    def typeobj_is_available(self):
+        return True
+        
+    def attributes_known(self):
+        return True
+        
+    def subtype_of(self, type):
+        return type.is_pyobject and self.assignable_from(type)
+        
+    def type_test_code(self, arg):
+        return 'likely(Py%s_CheckExact(%s)) || (%s) == Py_None || (PyErr_Format(PyExc_TypeError, "Expected %s, got %%s", %s->ob_type->tp_name), 0)' % (self.name[0].upper() + self.name[1:], arg, arg, self.name, arg)
+
+
 class PyExtensionType(PyObjectType):
     #
     #  A Python extension type.
@@ -281,6 +324,9 @@ class PyExtensionType(PyObjectType):
             else:
                 return "%s *%s" % (base,  entity_code)
 
+    def type_test_code(self, py_arg):
+        return "__Pyx_TypeTest(%s, %s)" % (py_arg, self.typeptr_cname)
+    
     def attributes_known(self):
         return self.scope is not None
     
index 510d8df50735b866a7bb28664022bc279581e63b..456445ec152a3f2b1446195fe9f3d8d17034ab74 100644 (file)
@@ -588,6 +588,7 @@ class BuiltinScope(Scope):
             Scope.__init__(self, "__builtin__", None, None)
         else:
             Scope.__init__(self, "__builtin__", PreImportScope(), None)
+        self.type_names = {}
         
         for name, definition in self.builtin_entries.iteritems():
             cname, type = definition
@@ -614,6 +615,23 @@ class BuiltinScope(Scope):
             var_entry.is_builtin = 1
             entry.as_variable = var_entry
         return entry
+        
+    def declare_builtin_type(self, name, cname):
+        type = PyrexTypes.BuiltinObjectType(name, cname)
+        type.set_scope(CClassScope(name, outer_scope=None, visibility='extern'))
+        self.type_names[name] = 1
+        entry = self.declare_type(name, type, None, visibility='extern')
+
+        var_entry = Entry(name = entry.name,
+            type = py_object_type,
+            pos = entry.pos,
+            cname = "((PyObject*)%s)" % entry.type.typeptr_cname)
+        var_entry.is_variable = 1
+        var_entry.is_cglobal = 1
+        var_entry.is_readonly = 1
+        entry.as_variable = var_entry
+
+        return type
 
     def builtin_scope(self):
         return self
@@ -684,7 +702,7 @@ class ModuleScope(Scope):
         self.module_entries = {}
         self.python_include_files = ["Python.h", "structmember.h"]
         self.include_files = []
-        self.type_names = {}
+        self.type_names = dict(outer_scope.type_names)
         self.pxd_file_loaded = 0
         self.cimported_modules = []
         self.intern_map = {}