Implement "from module [c]import *", some more work on sequence indexing.
authorRobert Bradshaw <robertwb@math.washington.edu>
Thu, 29 May 2008 01:09:49 +0000 (18:09 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Thu, 29 May 2008 01:09:49 +0000 (18:09 -0700)
Cython/Compiler/ExprNodes.py
Cython/Compiler/ModuleNode.py
Cython/Compiler/Naming.py
Cython/Compiler/Nodes.py
Cython/Compiler/Parsing.py
Cython/Compiler/Symtab.py

index c10c1fa6f0f34a469096b4375b0677bc46aeb3bc..a614262b0eb4876a4fcf066562d88f5bb7495756 100644 (file)
@@ -1290,12 +1290,10 @@ class IndexNode(ExprNode):
             self.base.result_code, self.index.result_code)
 
     def generate_subexpr_evaluation_code(self, code):
-        # do not evaluate self.py_index in case we don't need it
         self.base.generate_evaluation_code(code)
         self.index.generate_evaluation_code(code)
         
     def generate_subexpr_disposal_code(self, code):
-        # if we used self.py_index, it will be disposed of manually
         self.base.generate_disposal_code(code)
         self.index.generate_disposal_code(code)
 
@@ -1344,13 +1342,20 @@ class IndexNode(ExprNode):
     
     def generate_deletion_code(self, code):
         self.generate_subexpr_evaluation_code(code)
-        self.py_index.generate_evaluation_code(code)
-        code.put_error_if_neg(self.pos, 
-            "PyObject_DelItem(%s, %s)" % (
+        #if self.type.is_pyobject:
+        if self.index.type.is_int:
+            function = "PySequence_DelItem"
+            index_code = self.index.result_code
+        else:
+            function = "PyObject_DelItem"
+            index_code = self.index.py_result()
+        code.putln(
+            "if (%s(%s, %s) < 0) %s" % (
+                function,
                 self.base.py_result(),
-                self.py_index.py_result()))
+                index_code,
+                code.error_goto(self.pos)))
         self.generate_subexpr_disposal_code(code)
-        self.py_index.generate_disposal_code(code)
 
 
 class SliceIndexNode(ExprNode):
index c86204fb5902f5226eb49e1acdeae5eacf003d4e..af059c09da4138a5d56fb2389f202ee1899349cb 100644 (file)
@@ -230,6 +230,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         self.generate_typeobj_definitions(env, code)
         self.generate_method_table(env, code)
         self.generate_filename_init_prototype(code)
+        if env.has_import_star:
+            self.generate_import_star(env, code)
         self.generate_module_init_func(modules[:-1], env, code)
         code.mark_pos(None)
         self.generate_module_cleanup_func(env, code)
@@ -1429,6 +1431,66 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
     def generate_filename_init_prototype(self, code):
         code.putln("");
         code.putln("static void %s(void); /*proto*/" % Naming.fileinit_cname)
+        
+    def generate_import_star(self, env, code):
+        code.putln()
+        code.putln("char* %s_type_names[] = {" % Naming.import_star)
+        for name, entry in env.entries.items():
+            if entry.is_type:
+                code.putln('"%s",' % name)
+        code.putln("0")
+        code.putln("};")
+        code.putln()
+        code.putln("static int %s(PyObject *o, PyObject* py_name, char *name) {" % Naming.import_star_set)
+        code.putln("char** type_name = %s_type_names;" % Naming.import_star)
+        code.putln("while (*type_name) {")
+        code.putln("if (!strcmp(name, *type_name)) {")
+        code.putln('PyErr_Format(PyExc_TypeError, "Cannot overwrite C type %s", name);')
+        code.putln('goto bad;')
+        code.putln("}")
+        code.putln("type_name++;")
+        code.putln("}")
+        old_error_label = code.new_error_label()
+        code.putln("if (0);") # so the first one can be "else if"
+        for name, entry in env.entries.items():
+            if entry.is_cglobal and entry.used:
+                code.putln('else if (!strcmp(name, "%s")) {' % name)
+                if entry.type.is_pyobject:
+                    if entry.type.is_extension_type or entry.type.is_builtin_type:
+                        code.putln("if (!(%s)) %s;" % (
+                            entry.type.type_test_code("o"),
+                            code.error_goto(entry.pos)))
+                    code.put_var_decref(entry)
+                    code.putln("%s = %s;" % (
+                        entry.cname, 
+                        PyrexTypes.typecast(entry.type, py_object_type, "o")))
+                elif entry.type.from_py_function:
+                    rhs = "%s(o)" % entry.type.from_py_function
+                    if entry.type.is_enum:
+                        rhs = typecast(entry.type, c_long_type, rhs)
+                    code.putln("%s = %s; if (%s) %s;" % (
+                        entry.cname,
+                        rhs,
+                        entry.type.error_condition(entry.cname),
+                        code.error_goto(entry.pos)))
+                    code.putln("Py_DECREF(o);")
+                else:
+                    code.putln('PyErr_Format(PyExc_TypeError, "Cannot convert Python object %s to %s");' % (name, entry.type))
+                    code.putln(code.error_goto(entry.pos))
+                code.putln("}")
+        code.putln("else {")
+        code.putln("if (PyObject_SetAttr(%s, py_name, o) < 0) goto bad;" % Naming.module_cname)
+        code.putln("}")
+        code.putln("return 0;")
+        code.put_label(code.error_label)
+        # This helps locate the offending name.
+        code.putln('__Pyx_AddTraceback("%s");' % self.full_module_name);
+        code.error_label = old_error_label
+        code.putln("bad:")
+        code.putln("Py_DECREF(o);")
+        code.putln("return -1;")
+        code.putln("}")
+        code.putln(import_star_utility_code)
 
     def generate_module_init_func(self, imported_modules, env, code):
         code.putln("")
@@ -2019,3 +2081,94 @@ bad:
     return ret;
 }
 """]
+
+import_star_utility_code = """
+
+/* import_all_from is an unexposed function from ceval.c */
+
+static int
+__Pyx_import_all_from(PyObject *locals, PyObject *v)
+{
+       PyObject *all = PyObject_GetAttrString(v, "__all__");
+       PyObject *dict, *name, *value;
+       int skip_leading_underscores = 0;
+       int pos, err;
+
+       if (all == NULL) {
+               if (!PyErr_ExceptionMatches(PyExc_AttributeError))
+                       return -1; /* Unexpected error */
+               PyErr_Clear();
+               dict = PyObject_GetAttrString(v, "__dict__");
+               if (dict == NULL) {
+                       if (!PyErr_ExceptionMatches(PyExc_AttributeError))
+                               return -1;
+                       PyErr_SetString(PyExc_ImportError,
+                       "from-import-* object has no __dict__ and no __all__");
+                       return -1;
+               }
+               all = PyMapping_Keys(dict);
+               Py_DECREF(dict);
+               if (all == NULL)
+                       return -1;
+               skip_leading_underscores = 1;
+       }
+
+       for (pos = 0, err = 0; ; pos++) {
+               name = PySequence_GetItem(all, pos);
+               if (name == NULL) {
+                       if (!PyErr_ExceptionMatches(PyExc_IndexError))
+                               err = -1;
+                       else
+                               PyErr_Clear();
+                       break;
+               }
+               if (skip_leading_underscores &&
+                   PyString_Check(name) &&
+                   PyString_AS_STRING(name)[0] == '_')
+               {
+                       Py_DECREF(name);
+                       continue;
+               }
+               value = PyObject_GetAttr(v, name);
+               if (value == NULL)
+                       err = -1;
+               else if (PyDict_CheckExact(locals))
+                       err = PyDict_SetItem(locals, name, value);
+               else
+                       err = PyObject_SetItem(locals, name, value);
+               Py_DECREF(name);
+               Py_XDECREF(value);
+               if (err != 0)
+                       break;
+       }
+       Py_DECREF(all);
+       return err;
+}
+
+
+static int %s(PyObject* m) {
+
+    int i;
+    int ret = -1;
+    PyObject *locals = 0;
+    PyObject *list = 0;
+    PyObject *name;
+    PyObject *item;
+    
+    locals = PyDict_New();              if (!locals) goto bad;
+    if (__Pyx_import_all_from(locals, m) < 0) goto bad;
+    list = PyDict_Items(locals);        if (!list) goto bad;
+    
+    for(i=0; i<PyList_GET_SIZE(list); i++) {
+        name = PyTuple_GET_ITEM(PyList_GET_ITEM(list, i), 0);
+        item = PyTuple_GET_ITEM(PyList_GET_ITEM(list, i), 1);
+        if (%s(item, name, PyString_AsString(name)) < 0) goto bad;
+    }
+    ret = 0;
+    
+bad:
+    Py_XDECREF(locals);
+    Py_XDECREF(list);
+    return ret;
+}
+""" % ( Naming.import_star, Naming.import_star_set )
index 6b74e752e7f7d3804d5699547a0b87e535f64135..c24a3f7c5124b3c5117401458d2f4dfceffe2947 100644 (file)
@@ -67,6 +67,8 @@ print_function_kwargs   = pyrex_prefix + "print_kwargs"
 cleanup_cname    = pyrex_prefix + "module_cleanup"
 optional_args_cname = pyrex_prefix + "optional_args"
 no_opt_args      = pyrex_prefix + "no_opt_args"
+import_star      = pyrex_prefix + "import_star"
+import_star_set  = pyrex_prefix + "import_star_set"
 
 line_c_macro = "__LINE__"
 
index 8669729ae28c02a39e5b655eb93d708cd94e2108..ad82433ef610a7fbb020e4f0f0f0fc0eeba6cc3a 100644 (file)
@@ -3665,10 +3665,14 @@ class FromCImportStatNode(StatNode):
         module_scope = env.find_module(self.module_name, self.pos)
         env.add_imported_module(module_scope)
         for pos, name, as_name in self.imported_names:
-            entry = module_scope.find(name, pos)
-            if entry:
-                local_name = as_name or name
-                env.add_imported_entry(local_name, entry, pos)
+            if name == "*":
+                for local_name, entry in module_scope.entries.items():
+                    env.add_imported_entry(local_name, entry, pos)
+            else:
+                entry = module_scope.find(name, pos)
+                if entry:
+                    local_name = as_name or name
+                    env.add_imported_entry(local_name, entry, pos)
 
     def analyse_expressions(self, env):
         pass
@@ -3684,12 +3688,21 @@ class FromImportStatNode(StatNode):
     #  items            [(string, NameNode)]
     #  interned_items   [(string, NameNode)]
     #  item             PyTempNode            used internally
+    #  import_star      boolean               used internally
 
     child_attrs = ["module"]
+    import_star = 0
     
     def analyse_declarations(self, env):
-        for _, target in self.items:
-            target.analyse_target_declaration(env)
+        for name, target in self.items:
+            if name == "*":
+                if not env.is_module_scope:
+                    error(self.pos, "import * only allowed at module level")
+                    return
+                env.has_import_star = 1
+                self.import_star = 1
+            else:
+                target.analyse_target_declaration(env)
     
     def analyse_expressions(self, env):
         import ExprNodes
@@ -3698,15 +3711,27 @@ class FromImportStatNode(StatNode):
         self.item.allocate_temp(env)
         self.interned_items = []
         for name, target in self.items:
-            self.interned_items.append(
-                (env.intern_identifier(name), target))
-            target.analyse_target_expression(env, None)
-            #target.release_target_temp(env) # was release_temp ?!?
+            if name == '*':
+                for _, entry in env.entries.items():
+                    if not entry.is_type and entry.type.is_extension_type:
+                        env.use_utility_code(ExprNodes.type_test_utility_code)
+                        break
+            else:
+                self.interned_items.append(
+                    (env.intern_identifier(name), target))
+                target.analyse_target_expression(env, None)
+                #target.release_target_temp(env) # was release_temp ?!?
         self.module.release_temp(env)
         self.item.release_temp(env)
     
     def generate_execution_code(self, code):
         self.module.generate_evaluation_code(code)
+        if self.import_star:
+            code.putln(
+                'if (%s(%s) < 0) %s;' % (
+                    Naming.import_star,
+                    self.module.py_result(),
+                    code.error_goto(self.pos)))
         for cname, target in self.interned_items:
             code.putln(
                 '%s = PyObject_GetAttr(%s, %s); %s' % (
index b7b2a88cd7bcdb34ccce2c35f3b46c67b823d686..ce01a9de82acb7ad4eb07c392629fbb06a8ce39a 100644 (file)
@@ -944,8 +944,11 @@ def p_from_import_statement(s, first_statement = 0):
     else:
         s.error("Expected 'import' or 'cimport'")
     if s.sy == '*':
-        s.error("'import *' not supported")
-    imported_names = [p_imported_name(s)]
+#        s.error("'import *' not supported")
+        imported_names = [(s.position(), "*", None)]
+        s.next()
+    else:
+        imported_names = [p_imported_name(s)]
     while s.sy == ',':
         s.next()
         imported_names.append(p_imported_name(s))
index 08edcff6fe3984a74e55968d6b6586c5c43ccfdf..1fa1ede66ca6c74d953af350c783fab3e6cffa5f 100644 (file)
@@ -693,8 +693,10 @@ class ModuleScope(Scope):
     # interned_nums        [int/long]         Interned numeric constants
     # all_pystring_entries [Entry]            Python string consts from all scopes
     # types_imported       {PyrexType : 1}    Set of types for which import code generated
+    # has_import_star      boolean            Module contains import *
     
     is_module_scope = 1
+    has_import_star = 0
 
     def __init__(self, name, parent_module, context):
         self.parent_module = parent_module
@@ -734,7 +736,10 @@ class ModuleScope(Scope):
     
     def declare_builtin(self, name, pos):
         if not hasattr(__builtin__, name):
-            if self.outer_scope is not None:
+            if self.has_import_star:
+                entry = self.declare_var(name, py_object_type, pos)
+                return entry
+            elif self.outer_scope is not None:
                 return self.outer_scope.declare_builtin(name, pos)
             else:
                 error(pos, "undeclared name not builtin: %s"%name)