Type inference methods for expression nodes.
authorRobert Bradshaw <robertwb@math.washington.edu>
Sun, 4 Oct 2009 05:37:02 +0000 (22:37 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Sun, 4 Oct 2009 05:37:02 +0000 (22:37 -0700)
Cython/Compiler/ExprNodes.py

index 31f76d48299d1f3e02b3c32eef05ee5cc450018b..24c9287e514ab38ec3c57c70c7e068c2c35b6e0d 100644 (file)
@@ -307,6 +307,27 @@ class ExprNode(Node):
         temp_bool = bool.coerce_to_temp(env)
         return temp_bool
     
+    # --------------- Type Inference -----------------
+    
+    def type_dependencies(self):
+        # Returns the list of entries whose types must be determined
+        # before the type of self can be infered.
+        if hasattr(self, 'type') and self.type is not None:
+            return ()
+        return sum([node.type_dependencies() for node in self.subexpr_nodes()], ())
+    
+    def infer_type(self, env):
+        # Attempt to deduce the type of self. 
+        # Differs from analyse_types as it avoids unnecessary 
+        # analysis of subexpressions, but can assume everything
+        # in self.type_dependencies() has been resolved.
+        if hasattr(self, 'type') and self.type is not None:
+            return self.type
+        elif hasattr(self, 'entry') and self.entry is not None:
+            return self.entry.type
+        else:
+            self.not_implemented("infer_type")
+    
     # --------------- Type Analysis ------------------
     
     def analyse_as_module(self, env):
@@ -858,6 +879,8 @@ class LongNode(AtomicExprNode):
     #
     #  value   string
 
+    type = py_object_type
+
     def calculate_constant_result(self):
         self.constant_result = long(self.value)
     
@@ -865,7 +888,6 @@ class LongNode(AtomicExprNode):
         return long(self.value)
     
     def analyse_types(self, env):
-        self.type = py_object_type
         self.is_temp = 1
 
     gil_message = "Constructing Python long int"
@@ -954,6 +976,9 @@ class NameNode(AtomicExprNode):
     
     create_analysed_rvalue = staticmethod(create_analysed_rvalue)
     
+    def type_dependencies(self):
+        return self.entry
+    
     def compile_time_value(self, denv):
         try:
             return denv.lookup(self.name)
@@ -1298,12 +1323,13 @@ class BackquoteNode(ExprNode):
     #
     #  arg    ExprNode
     
+    type = py_object_type
+    
     subexprs = ['arg']
     
     def analyse_types(self, env):
         self.arg.analyse_types(env)
         self.arg = self.arg.coerce_to_pyobject(env)
-        self.type = py_object_type
         self.is_temp = 1
 
     gil_message = "Backquote expression"
@@ -1329,15 +1355,16 @@ class ImportNode(ExprNode):
     #  module_name   IdentifierStringNode     dotted name of module
     #  name_list     ListNode or None         list of names to be imported
     
+    type = py_object_type
+    
     subexprs = ['module_name', 'name_list']
-
+    
     def analyse_types(self, env):
         self.module_name.analyse_types(env)
         self.module_name = self.module_name.coerce_to_pyobject(env)
         if self.name_list:
             self.name_list.analyse_types(env)
             self.name_list.coerce_to_pyobject(env)
-        self.type = py_object_type
         self.is_temp = 1
         env.use_utility_code(import_utility_code)
 
@@ -1367,12 +1394,13 @@ class IteratorNode(ExprNode):
     #
     #  sequence   ExprNode
     
+    type = py_object_type
+    
     subexprs = ['sequence']
     
     def analyse_types(self, env):
         self.sequence.analyse_types(env)
         self.sequence = self.sequence.coerce_to_pyobject(env)
-        self.type = py_object_type
         self.is_temp = 1
 
     gil_message = "Iterating over Python object"
@@ -1424,10 +1452,11 @@ class NextNode(AtomicExprNode):
     #
     #  iterator   ExprNode
     
+    type = py_object_type
+    
     def __init__(self, iterator, env):
         self.pos = iterator.pos
         self.iterator = iterator
-        self.type = py_object_type
         self.is_temp = 1
     
     def generate_result_code(self, code):
@@ -1480,9 +1509,10 @@ class ExcValueNode(AtomicExprNode):
     #  of an ExceptClauseNode to fetch the current
     #  exception value.
     
+    type = py_object_type
+    
     def __init__(self, pos, env):
         ExprNode.__init__(self, pos)
-        self.type = py_object_type
 
     def set_var(self, var):
         self.var = var
@@ -1598,6 +1628,19 @@ class IndexNode(ExprNode):
             return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env)))
         return None
     
+    def type_dependencies(self):
+        return self.base.type_dependencies()
+    
+    def infer_type(self, env):
+        if isinstance(self.base, StringNode):
+            return py_object_type
+        base_type = self.base.infer_type(env)
+        if base_type.is_ptr or base_type.is_array:
+            return base_type.base_type
+        else:
+            # TODO: Handle buffers (hopefully without too much redundancy).
+            return py_object_type
+    
     def analyse_types(self, env):
         self.analyse_base_and_index_types(env, getting = 1)
     
@@ -2106,6 +2149,9 @@ class SliceNode(ExprNode):
     #  start     ExprNode
     #  stop      ExprNode
     #  step      ExprNode
+    
+    type = py_object_type
+    is_temp = 1
 
     def calculate_constant_result(self):
         self.constant_result = self.base.constant_result[
@@ -2137,8 +2183,6 @@ class SliceNode(ExprNode):
         self.start = self.start.coerce_to_pyobject(env)
         self.stop = self.stop.coerce_to_pyobject(env)
         self.step = self.step.coerce_to_pyobject(env)
-        self.type = py_object_type
-        self.is_temp = 1
 
     gil_message = "Constructing Python slice object"
 
@@ -2154,6 +2198,7 @@ class SliceNode(ExprNode):
 
 
 class CallNode(ExprNode):
+
     def analyse_as_type_constructor(self, env):
         type = self.function.analyse_as_type(env)
         if type and type.is_struct_or_union:
@@ -2206,6 +2251,20 @@ class SimpleCallNode(CallNode):
         except Exception, e:
             self.compile_time_value_error(e)
             
+    def type_dependencies(self):
+        # TODO: Update when Danilo's C++ code merged in to handle the
+        # the case of function overloading.
+        return self.function.type_dependencies()
+    
+    def infer_type(self, env):
+        func_type = self.function.infer_type(env)
+        if func_type.is_ptr:
+            func_type = func_type.base_type
+        if func_type.is_cfunction:
+            return func_type.return_type
+        else:
+            return py_object_type            
+    
     def analyse_as_type(self, env):
         attr = self.function.as_cython_attribute()
         if attr == 'pointer':
@@ -2466,6 +2525,8 @@ class GeneralCallNode(CallNode):
     #  keyword_args     ExprNode or None  Dict of keyword arguments
     #  starstar_arg     ExprNode or None  Dict of extra keyword args
     
+    type = py_object_type
+    
     subexprs = ['function', 'positional_args', 'keyword_args', 'starstar_arg']
 
     nogil_check = Node.gil_error
@@ -2643,6 +2704,15 @@ class AttributeNode(ExprNode):
             return getattr(obj, attr)
         except Exception, e:
             self.compile_time_value_error(e)
+    
+    def infer_type(self, env):
+        if self.analyse_as_cimported_attribute(env, 0):
+            return self.entry.type
+        elif self.analyse_as_unbound_cmethod(env):
+            return self.entry.type
+        else:
+            self.analyse_attribute(env)
+            return self.type
 
     def analyse_target_declaration(self, env):
         pass
@@ -3207,6 +3277,8 @@ class SequenceNode(ExprNode):
 
 class TupleNode(SequenceNode):
     #  Tuple constructor.
+    
+    type = tuple_type
 
     gil_message = "Constructing Python tuple"
 
@@ -3216,7 +3288,6 @@ class TupleNode(SequenceNode):
             self.is_literal = 1
         else:
             SequenceNode.analyse_types(self, env, skip_children)
-        self.type = tuple_type
             
     def calculate_result_code(self):
         if len(self.args) > 0:
@@ -3275,6 +3346,13 @@ class ListNode(SequenceNode):
     obj_conversion_errors = []
 
     gil_message = "Constructing Python list"
+    
+    def type_dependencies(self):
+        return ()
+    
+    def infer_type(self, env):
+        # TOOD: Infer non-object list arrays.
+        return list_type
 
     def analyse_expressions(self, env):
         SequenceNode.analyse_expressions(self, env)
@@ -3382,6 +3460,9 @@ class ComprehensionNode(ExprNode):
     subexprs = ["target"]
     child_attrs = ["loop", "append"]
 
+    def infer_type(self, env):
+        return self.target.infer_type(env)
+    
     def analyse_types(self, env):
         self.target.analyse_expressions(env)
         self.type = self.target.type
@@ -3458,10 +3539,12 @@ class DictComprehensionAppendNode(ComprehensionAppendNode):
 class SetNode(ExprNode):
     #  Set constructor.
 
+    type = set_type
+
     subexprs = ['args']
 
     gil_message = "Constructing Python set"
-
+    
     def analyse_types(self, env):
         for i in range(len(self.args)):
             arg = self.args[i]
@@ -3525,6 +3608,13 @@ class DictNode(ExprNode):
         except Exception, e:
             self.compile_time_value_error(e)
     
+    def type_dependencies(self):
+        return ()
+    
+    def infer_type(self, env):
+        # TOOD: Infer struct constructors.
+        return dict_type
+
     def analyse_types(self, env):
         hold_errors()
         for item in self.key_value_pairs:
@@ -3688,12 +3778,13 @@ class UnboundMethodNode(ExprNode):
     #
     #  function      ExprNode   Function object
     
+    type = py_object_type
+    is_temp = 1
+    
     subexprs = ['function']
     
     def analyse_types(self, env):
         self.function.analyse_types(env)
-        self.type = py_object_type
-        self.is_temp = 1
 
     gil_message = "Constructing an unbound method"
 
@@ -3714,10 +3805,12 @@ class PyCFunctionNode(AtomicExprNode):
     #
     #  pymethdef_cname   string   PyMethodDef structure
     
+    type = py_object_type
+    is_temp = 1
+    
     def analyse_types(self, env):
-        self.type = py_object_type
-        self.is_temp = 1
-
+        pass
+    
     gil_message = "Constructing Python function"
 
     def generate_result_code(self, code):
@@ -3771,6 +3864,9 @@ class UnopNode(ExprNode):
             return func(operand)
         except Exception, e:
             self.compile_time_value_error(e)
+    
+    def infer_type(self, env):
+        return self.operand.infer_type(env)
 
     def analyse_types(self, env):
         self.operand.analyse_types(env)
@@ -3819,7 +3915,11 @@ class NotNode(ExprNode):
     #  'not' operator
     #
     #  operand   ExprNode
+    
+    type = PyrexTypes.c_bint_type
 
+    subexprs = ['operand']
+    
     def calculate_constant_result(self):
         self.constant_result = not self.operand.constant_result
 
@@ -3830,12 +3930,12 @@ class NotNode(ExprNode):
         except Exception, e:
             self.compile_time_value_error(e)
 
-    subexprs = ['operand']
+    def infer_type(self, env):
+        return PyrexTypes.c_bint_type
     
     def analyse_types(self, env):
         self.operand.analyse_types(env)
         self.operand = self.operand.coerce_to_boolean(env)
-        self.type = PyrexTypes.c_bint_type
     
     def calculate_result_code(self):
         return "(!%s)" % self.operand.result()
@@ -3903,6 +4003,9 @@ class AmpersandNode(ExprNode):
     #  operand  ExprNode
     
     subexprs = ['operand']
+    
+    def infer_type(self, env):
+        return PyrexTypes.c_ptr_type(self.operand.infer_type(env))
 
     def analyse_types(self, env):
         self.operand.analyse_types(env)
@@ -3961,6 +4064,15 @@ class TypecastNode(ExprNode):
     subexprs = ['operand']
     base_type = declarator = type = None
     
+    def type_dependencies(self):
+        return ()
+    
+    def infer_types(self, env):
+        if self.type is None:
+            base_type = self.base_type.analyse(env)
+            _, self.type = self.declarator.analyse(base_type, env)
+        return self.type
+    
     def analyse_types(self, env):
         if self.type is None:
             base_type = self.base_type.analyse(env)
@@ -4182,6 +4294,10 @@ class BinopNode(ExprNode):
             return func(operand1, operand2)
         except Exception, e:
             self.compile_time_value_error(e)
+    
+    def infer_type(self, env):
+        return self.result_type(self.operand1.infer_type(env),
+                                self.operand1.infer_type(env))
 
     def analyse_types(self, env):
         self.operand1.analyse_types(env)
@@ -4196,13 +4312,21 @@ class BinopNode(ExprNode):
             self.analyse_c_operation(env)
     
     def is_py_operation(self):
-        return (self.operand1.type.is_pyobject 
-            or self.operand2.type.is_pyobject)
+        return self.is_py_operation_types(self.operand1.type, self.operand2.type)
+    
+    def is_py_operation_types(self, type1, type2):
+        return type1.is_pyobject or type2.is_pyobject
+
+    def result_type(self, type1, type2):
+        if self.is_py_operation_types(type1, type2):
+            return py_object_type
+        else:
+            return self.compute_c_result_type(type1, type2)
 
     def nogil_check(self, env):
         if self.is_py_operation():
             self.gil_error()
-    
+        
     def coerce_operands_to_pyobjects(self, env):
         self.operand1 = self.operand1.coerce_to_pyobject(env)
         self.operand2 = self.operand2.coerce_to_pyobject(env)
@@ -4321,12 +4445,11 @@ class IntBinopNode(NumBinopNode):
 class AddNode(NumBinopNode):
     #  '+' operator.
     
-    def is_py_operation(self):
-        if self.operand1.type.is_string \
-            and self.operand2.type.is_string:
-                return 1
+    def is_py_operation_types(self, type1, type2):
+        if type1.is_string and type2.is_string:
+            return 1
         else:
-            return NumBinopNode.is_py_operation(self)
+            return NumBinopNode.is_py_operation_types(self, type1, type2)
 
     def compute_c_result_type(self, type1, type2):
         #print "AddNode.compute_c_result_type:", type1, self.operator, type2 ###
@@ -4355,14 +4478,12 @@ class SubNode(NumBinopNode):
 class MulNode(NumBinopNode):
     #  '*' operator.
     
-    def is_py_operation(self):
-        type1 = self.operand1.type
-        type2 = self.operand2.type
+    def is_py_operation_types(self, type1, type2):
         if (type1.is_string and type2.is_int) \
             or (type2.is_string and type1.is_int):
                 return 1
         else:
-            return NumBinopNode.is_py_operation(self)
+            return NumBinopNode.is_py_operation_types(self, type1, type2)
 
 
 class DivNode(NumBinopNode):
@@ -4499,10 +4620,10 @@ class DivNode(NumBinopNode):
 class ModNode(DivNode):
     #  '%' operator.
 
-    def is_py_operation(self):
-        return (self.operand1.type.is_string
-            or self.operand2.type.is_string
-            or NumBinopNode.is_py_operation(self))
+    def is_py_operation_types(self, type1, type2):
+        return (type1.is_string
+            or type2.is_string
+            or NumBinopNode.is_py_operation_types(self, type1, type2))
 
     def zero_division_message(self):
         if self.type.is_int:
@@ -4578,6 +4699,13 @@ class BoolBinopNode(ExprNode):
     #  operand2     ExprNode
     
     subexprs = ['operand1', 'operand2']
+    
+    def infer_type(self, env):
+        if (self.operand1.infer_type(env).is_pyobject or
+            self.operand2.infer_type(env).is_pyobject):
+            return py_object_type
+        else:
+            return PyrexTypes.c_bint_type
 
     def calculate_constant_result(self):
         if self.operator == 'and':
@@ -4692,6 +4820,13 @@ class CondExprNode(ExprNode):
     false_val = None
     
     subexprs = ['test', 'true_val', 'false_val']
+    
+    def type_dependencies(self):
+        return self.true_val.type_dependencies() + self.false_val.type_dependencies()
+    
+    def infer_types(self, env):
+        return self.compute_result_type(self.true_val.infer_types(env),
+                                        self.false_val.infer_types(env))
 
     def calculate_constant_result(self):
         if self.test.constant_result:
@@ -4776,6 +4911,10 @@ richcmp_constants = {
 class CmpNode(object):
     #  Mixin class containing code common to PrimaryCmpNodes
     #  and CascadedCmpNodes.
+    
+    def infer_types(self, env):
+        # TODO: Actually implement this (after merging with -unstable).
+        return py_object_type
 
     def calculate_cascaded_constant_result(self, operand1_result):
         func = compile_time_binary_operators[self.operator]
@@ -5294,6 +5433,8 @@ class NoneCheckNode(CoercionNode):
 class CoerceToPyTypeNode(CoercionNode):
     #  This node is used to convert a C data type
     #  to a Python object.
+    
+    type = py_object_type
 
     def __init__(self, arg, env):
         CoercionNode.__init__(self, arg)
@@ -5366,9 +5507,10 @@ class CoerceToBooleanNode(CoercionNode):
     #  This node is used when a result needs to be used
     #  in a boolean context.
     
+    type = PyrexTypes.c_bint_type
+    
     def __init__(self, arg, env):
         CoercionNode.__init__(self, arg)
-        self.type = PyrexTypes.c_bint_type
         if arg.type.is_pyobject:
             self.is_temp = 1
 
@@ -5472,6 +5614,12 @@ class CloneNode(CoercionNode):
     
     def result(self):
         return self.arg.result()
+    
+    def type_dependencies(self):
+        return self.arg.type_dependencies()
+    
+    def infer_type(self, env):
+        return self.arg.infer_type(env)
         
     def analyse_types(self, env):
         self.type = self.arg.type