Partial merge of trunk progress. Some tests still fail.
[cython.git] / Cython / Compiler / Optimize.py
index 28ed19ff47d27498c9204f274d041c31b066b360..8552645777555d48887ed6a2c6fb2fc92d336b07 100644 (file)
@@ -74,7 +74,7 @@ class IterationTransform(Visitor.VisitorTransform):
     PyDict_Next_name = EncodedString("PyDict_Next")
 
     PyDict_Next_entry = Symtab.Entry(
-        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)
+        name = PyDict_Next_name, cname = PyDict_Next_name, type = PyDict_Next_func_type)
 
     visit_Node = Visitor.VisitorTransform.recurse_to_children
 
@@ -375,6 +375,9 @@ class IterationTransform(Visitor.VisitorTransform):
                 base=counter_temp,
                 type=Builtin.bytes_type,
                 is_temp=1)
+        elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type):
+            # Allow iteration with pointer target to avoid copy.
+            target_value = counter_temp
         else:
             target_value = ExprNodes.IndexNode(
                 node.target.pos,
@@ -927,7 +930,15 @@ class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
         conds = []
         temps = []
         for arg in args:
-            if not arg.is_simple():
+            try:
+                # Trial optimisation to avoid redundant temp
+                # assignments.  However, since is_simple() is meant to
+                # be called after type analysis, we ignore any errors
+                # and just play safe in that case.
+                is_simple_arg = arg.is_simple()
+            except Exception:
+                is_simple_arg = False
+            if not is_simple_arg:
                 # must evaluate all non-simple RHS before doing the comparisons
                 arg = UtilNodes.LetRefNode(arg)
                 temps.append(arg)
@@ -1167,11 +1178,12 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
             self.yield_nodes = []
 
         visit_Node = Visitor.TreeVisitor.visitchildren
-        def visit_YieldExprNode(self, node):
+        # XXX: disable inlining while it's not back supported
+        def __visit_YieldExprNode(self, node):
             self.yield_nodes.append(node)
             self.visitchildren(node)
 
-        def visit_ExprStatNode(self, node):
+        def __visit_ExprStatNode(self, node):
             self.visitchildren(node)
             if node.expr in self.yield_nodes:
                 self.yield_stat_nodes[node.expr] = node
@@ -1600,6 +1612,15 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
             return node.operand
         return node
 
+    def visit_ExprStatNode(self, node):
+        """
+        Drop useless coercions.
+        """
+        self.visitchildren(node)
+        if isinstance(node.expr, ExprNodes.CoerceToPyTypeNode):
+            node.expr = node.expr.arg
+        return node
+
     def visit_CoerceToBooleanNode(self, node):
         """Drop redundant conversion nodes after tree changes.
         """
@@ -1936,7 +1957,7 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
                 node.pos, cfunc_name, self.PyObject_Size_func_type,
                 args = [arg],
                 is_temp = node.is_temp)
-        elif arg.type is PyrexTypes.c_py_unicode_type:
+        elif arg.type.is_unicode_char:
             return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
                                      type=node.type)
         else:
@@ -1986,22 +2007,29 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
         test_nodes = []
         env = self.current_env()
         for test_type_node in types:
-            if not test_type_node.entry:
-                return node
-            entry = env.lookup(test_type_node.entry.name)
-            if not entry or not entry.type or not entry.type.is_builtin_type:
-                return node
-            type_check_function = entry.type.type_check_function(exact=False)
-            if not type_check_function:
-                return node
-            if type_check_function not in tests:
+            builtin_type = None
+            if isinstance(test_type_node, ExprNodes.NameNode):
+                if test_type_node.entry:
+                    entry = env.lookup(test_type_node.entry.name)
+                    if entry and entry.type and entry.type.is_builtin_type:
+                        builtin_type = entry.type
+            if builtin_type and builtin_type is not Builtin.type_type:
+                type_check_function = entry.type.type_check_function(exact=False)
+                if type_check_function in tests:
+                    continue
                 tests.append(type_check_function)
-                test_nodes.append(
-                    ExprNodes.PythonCapiCallNode(
-                        test_type_node.pos, type_check_function, self.Py_type_check_func_type,
-                        args = [arg],
-                        is_temp = True,
-                        ))
+                type_check_args = [arg]
+            elif test_type_node.type is Builtin.type_type:
+                type_check_function = '__Pyx_TypeCheck'
+                type_check_args = [arg, test_type_node]
+            else:
+                return node
+            test_nodes.append(
+                ExprNodes.PythonCapiCallNode(
+                    test_type_node.pos, type_check_function, self.Py_type_check_func_type,
+                    args = type_check_args,
+                    is_temp = True,
+                    ))
 
         def join_with_or(a,b, make_binop_node=ExprNodes.binop_node):
             or_node = make_binop_node(node.pos, 'or', a, b)
@@ -2021,7 +2049,7 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
             return node
         arg = pos_args[0]
         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
-            if arg.arg.type is PyrexTypes.c_py_unicode_type:
+            if arg.arg.type.is_unicode_char:
                 return arg.arg.coerce_to(node.type, self.current_env())
         return node
 
@@ -2135,7 +2163,7 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
     _handle_simple_method_list_pop = _handle_simple_method_object_pop
 
     single_param_func_type = PyrexTypes.CFuncType(
-        PyrexTypes.c_int_type, [
+        PyrexTypes.c_returncode_type, [
             PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
             ],
         exception_value = "-1")
@@ -2147,7 +2175,7 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
             return node
         return self._substitute_method_call(
             node, "PyList_Sort", self.single_param_func_type,
-            'sort', is_unbound_method, args)
+            'sort', is_unbound_method, args).coerce_to(node.type, self.current_env)
 
     Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
         PyrexTypes.py_object_type, [
@@ -2176,7 +2204,7 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
 
     PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
         PyrexTypes.c_bint_type, [
-            PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_unicode_type, None),
+            PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
             ])
 
     def _inject_unicode_predicate(self, node, args, is_unbound_method):
@@ -2184,7 +2212,7 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
             return node
         ustring = args[0]
         if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
-               ustring.arg.type is not PyrexTypes.c_py_unicode_type:
+               not ustring.arg.type.is_unicode_char:
             return node
         uchar = ustring.arg
         method_name = node.function.attribute
@@ -2214,8 +2242,8 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
     _handle_simple_method_unicode_isupper   = _inject_unicode_predicate
 
     PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
-        PyrexTypes.c_py_unicode_type, [
-            PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_unicode_type, None),
+        PyrexTypes.c_py_ucs4_type, [
+            PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
             ])
 
     def _inject_unicode_character_conversion(self, node, args, is_unbound_method):
@@ -2223,7 +2251,7 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
             return node
         ustring = args[0]
         if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
-               ustring.arg.type is not PyrexTypes.c_py_unicode_type:
+               not ustring.arg.type.is_unicode_char:
             return node
         uchar = ustring.arg
         method_name = node.function.attribute
@@ -2695,10 +2723,18 @@ py_unicode_istitle_utility_code = UtilityCode(
 # Py_UNICODE_ISTITLE() doesn't match unicode.istitle() as the latter
 # additionally allows character that comply with Py_UNICODE_ISUPPER()
 proto = '''
+#if PY_VERSION_HEX < 0x030200A2
 static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar); /* proto */
+#else
+static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UCS4 uchar); /* proto */
+#endif
 ''',
 impl = '''
+#if PY_VERSION_HEX < 0x030200A2
 static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar) {
+#else
+static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UCS4 uchar) {
+#endif
     return Py_UNICODE_ISTITLE(uchar) || Py_UNICODE_ISUPPER(uchar);
 }
 ''')
@@ -3102,6 +3138,15 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
         return ExprNodes.BoolNode(node.pos, value=bool_result,
                                   constant_result=bool_result)
 
+    def visit_CondExprNode(self, node):
+        self._calculate_const(node)
+        if node.test.constant_result is ExprNodes.not_a_constant:
+            return node
+        if node.test.constant_result:
+            return node.true_val
+        else:
+            return node.false_val
+
     def visit_IfStatNode(self, node):
         self.visitchildren(node)
         # eliminate dead code based on constant condition results