From d3ef6b9cd1fafe3022437f5957ba6df4f3f11fc6 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Sat, 3 Oct 2009 22:37:02 -0700 Subject: [PATCH] Type inference methods for expression nodes. --- Cython/Compiler/ExprNodes.py | 218 +++++++++++++++++++++++++++++------ 1 file changed, 183 insertions(+), 35 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 31f76d48..24c9287e 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -307,6 +307,27 @@ class ExprNode(Node): temp_bool = bool.coerce_to_temp(env) return temp_bool + # --------------- Type Inference ----------------- + + def type_dependencies(self): + # Returns the list of entries whose types must be determined + # before the type of self can be infered. + if hasattr(self, 'type') and self.type is not None: + return () + return sum([node.type_dependencies() for node in self.subexpr_nodes()], ()) + + def infer_type(self, env): + # Attempt to deduce the type of self. + # Differs from analyse_types as it avoids unnecessary + # analysis of subexpressions, but can assume everything + # in self.type_dependencies() has been resolved. + if hasattr(self, 'type') and self.type is not None: + return self.type + elif hasattr(self, 'entry') and self.entry is not None: + return self.entry.type + else: + self.not_implemented("infer_type") + # --------------- Type Analysis ------------------ def analyse_as_module(self, env): @@ -858,6 +879,8 @@ class LongNode(AtomicExprNode): # # value string + type = py_object_type + def calculate_constant_result(self): self.constant_result = long(self.value) @@ -865,7 +888,6 @@ class LongNode(AtomicExprNode): return long(self.value) def analyse_types(self, env): - self.type = py_object_type self.is_temp = 1 gil_message = "Constructing Python long int" @@ -954,6 +976,9 @@ class NameNode(AtomicExprNode): create_analysed_rvalue = staticmethod(create_analysed_rvalue) + def type_dependencies(self): + return self.entry + def compile_time_value(self, denv): try: return denv.lookup(self.name) @@ -1298,12 +1323,13 @@ class BackquoteNode(ExprNode): # # arg ExprNode + type = py_object_type + subexprs = ['arg'] def analyse_types(self, env): self.arg.analyse_types(env) self.arg = self.arg.coerce_to_pyobject(env) - self.type = py_object_type self.is_temp = 1 gil_message = "Backquote expression" @@ -1329,15 +1355,16 @@ class ImportNode(ExprNode): # module_name IdentifierStringNode dotted name of module # name_list ListNode or None list of names to be imported + type = py_object_type + subexprs = ['module_name', 'name_list'] - + def analyse_types(self, env): self.module_name.analyse_types(env) self.module_name = self.module_name.coerce_to_pyobject(env) if self.name_list: self.name_list.analyse_types(env) self.name_list.coerce_to_pyobject(env) - self.type = py_object_type self.is_temp = 1 env.use_utility_code(import_utility_code) @@ -1367,12 +1394,13 @@ class IteratorNode(ExprNode): # # sequence ExprNode + type = py_object_type + subexprs = ['sequence'] def analyse_types(self, env): self.sequence.analyse_types(env) self.sequence = self.sequence.coerce_to_pyobject(env) - self.type = py_object_type self.is_temp = 1 gil_message = "Iterating over Python object" @@ -1424,10 +1452,11 @@ class NextNode(AtomicExprNode): # # iterator ExprNode + type = py_object_type + def __init__(self, iterator, env): self.pos = iterator.pos self.iterator = iterator - self.type = py_object_type self.is_temp = 1 def generate_result_code(self, code): @@ -1480,9 +1509,10 @@ class ExcValueNode(AtomicExprNode): # of an ExceptClauseNode to fetch the current # exception value. + type = py_object_type + def __init__(self, pos, env): ExprNode.__init__(self, pos) - self.type = py_object_type def set_var(self, var): self.var = var @@ -1598,6 +1628,19 @@ class IndexNode(ExprNode): return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env))) return None + def type_dependencies(self): + return self.base.type_dependencies() + + def infer_type(self, env): + if isinstance(self.base, StringNode): + return py_object_type + base_type = self.base.infer_type(env) + if base_type.is_ptr or base_type.is_array: + return base_type.base_type + else: + # TODO: Handle buffers (hopefully without too much redundancy). + return py_object_type + def analyse_types(self, env): self.analyse_base_and_index_types(env, getting = 1) @@ -2106,6 +2149,9 @@ class SliceNode(ExprNode): # start ExprNode # stop ExprNode # step ExprNode + + type = py_object_type + is_temp = 1 def calculate_constant_result(self): self.constant_result = self.base.constant_result[ @@ -2137,8 +2183,6 @@ class SliceNode(ExprNode): self.start = self.start.coerce_to_pyobject(env) self.stop = self.stop.coerce_to_pyobject(env) self.step = self.step.coerce_to_pyobject(env) - self.type = py_object_type - self.is_temp = 1 gil_message = "Constructing Python slice object" @@ -2154,6 +2198,7 @@ class SliceNode(ExprNode): class CallNode(ExprNode): + def analyse_as_type_constructor(self, env): type = self.function.analyse_as_type(env) if type and type.is_struct_or_union: @@ -2206,6 +2251,20 @@ class SimpleCallNode(CallNode): except Exception, e: self.compile_time_value_error(e) + def type_dependencies(self): + # TODO: Update when Danilo's C++ code merged in to handle the + # the case of function overloading. + return self.function.type_dependencies() + + def infer_type(self, env): + func_type = self.function.infer_type(env) + if func_type.is_ptr: + func_type = func_type.base_type + if func_type.is_cfunction: + return func_type.return_type + else: + return py_object_type + def analyse_as_type(self, env): attr = self.function.as_cython_attribute() if attr == 'pointer': @@ -2466,6 +2525,8 @@ class GeneralCallNode(CallNode): # keyword_args ExprNode or None Dict of keyword arguments # starstar_arg ExprNode or None Dict of extra keyword args + type = py_object_type + subexprs = ['function', 'positional_args', 'keyword_args', 'starstar_arg'] nogil_check = Node.gil_error @@ -2643,6 +2704,15 @@ class AttributeNode(ExprNode): return getattr(obj, attr) except Exception, e: self.compile_time_value_error(e) + + def infer_type(self, env): + if self.analyse_as_cimported_attribute(env, 0): + return self.entry.type + elif self.analyse_as_unbound_cmethod(env): + return self.entry.type + else: + self.analyse_attribute(env) + return self.type def analyse_target_declaration(self, env): pass @@ -3207,6 +3277,8 @@ class SequenceNode(ExprNode): class TupleNode(SequenceNode): # Tuple constructor. + + type = tuple_type gil_message = "Constructing Python tuple" @@ -3216,7 +3288,6 @@ class TupleNode(SequenceNode): self.is_literal = 1 else: SequenceNode.analyse_types(self, env, skip_children) - self.type = tuple_type def calculate_result_code(self): if len(self.args) > 0: @@ -3275,6 +3346,13 @@ class ListNode(SequenceNode): obj_conversion_errors = [] gil_message = "Constructing Python list" + + def type_dependencies(self): + return () + + def infer_type(self, env): + # TOOD: Infer non-object list arrays. + return list_type def analyse_expressions(self, env): SequenceNode.analyse_expressions(self, env) @@ -3382,6 +3460,9 @@ class ComprehensionNode(ExprNode): subexprs = ["target"] child_attrs = ["loop", "append"] + def infer_type(self, env): + return self.target.infer_type(env) + def analyse_types(self, env): self.target.analyse_expressions(env) self.type = self.target.type @@ -3458,10 +3539,12 @@ class DictComprehensionAppendNode(ComprehensionAppendNode): class SetNode(ExprNode): # Set constructor. + type = set_type + subexprs = ['args'] gil_message = "Constructing Python set" - + def analyse_types(self, env): for i in range(len(self.args)): arg = self.args[i] @@ -3525,6 +3608,13 @@ class DictNode(ExprNode): except Exception, e: self.compile_time_value_error(e) + def type_dependencies(self): + return () + + def infer_type(self, env): + # TOOD: Infer struct constructors. + return dict_type + def analyse_types(self, env): hold_errors() for item in self.key_value_pairs: @@ -3688,12 +3778,13 @@ class UnboundMethodNode(ExprNode): # # function ExprNode Function object + type = py_object_type + is_temp = 1 + subexprs = ['function'] def analyse_types(self, env): self.function.analyse_types(env) - self.type = py_object_type - self.is_temp = 1 gil_message = "Constructing an unbound method" @@ -3714,10 +3805,12 @@ class PyCFunctionNode(AtomicExprNode): # # pymethdef_cname string PyMethodDef structure + type = py_object_type + is_temp = 1 + def analyse_types(self, env): - self.type = py_object_type - self.is_temp = 1 - + pass + gil_message = "Constructing Python function" def generate_result_code(self, code): @@ -3771,6 +3864,9 @@ class UnopNode(ExprNode): return func(operand) except Exception, e: self.compile_time_value_error(e) + + def infer_type(self, env): + return self.operand.infer_type(env) def analyse_types(self, env): self.operand.analyse_types(env) @@ -3819,7 +3915,11 @@ class NotNode(ExprNode): # 'not' operator # # operand ExprNode + + type = PyrexTypes.c_bint_type + subexprs = ['operand'] + def calculate_constant_result(self): self.constant_result = not self.operand.constant_result @@ -3830,12 +3930,12 @@ class NotNode(ExprNode): except Exception, e: self.compile_time_value_error(e) - subexprs = ['operand'] + def infer_type(self, env): + return PyrexTypes.c_bint_type def analyse_types(self, env): self.operand.analyse_types(env) self.operand = self.operand.coerce_to_boolean(env) - self.type = PyrexTypes.c_bint_type def calculate_result_code(self): return "(!%s)" % self.operand.result() @@ -3903,6 +4003,9 @@ class AmpersandNode(ExprNode): # operand ExprNode subexprs = ['operand'] + + def infer_type(self, env): + return PyrexTypes.c_ptr_type(self.operand.infer_type(env)) def analyse_types(self, env): self.operand.analyse_types(env) @@ -3961,6 +4064,15 @@ class TypecastNode(ExprNode): subexprs = ['operand'] base_type = declarator = type = None + def type_dependencies(self): + return () + + def infer_types(self, env): + if self.type is None: + base_type = self.base_type.analyse(env) + _, self.type = self.declarator.analyse(base_type, env) + return self.type + def analyse_types(self, env): if self.type is None: base_type = self.base_type.analyse(env) @@ -4182,6 +4294,10 @@ class BinopNode(ExprNode): return func(operand1, operand2) except Exception, e: self.compile_time_value_error(e) + + def infer_type(self, env): + return self.result_type(self.operand1.infer_type(env), + self.operand1.infer_type(env)) def analyse_types(self, env): self.operand1.analyse_types(env) @@ -4196,13 +4312,21 @@ class BinopNode(ExprNode): self.analyse_c_operation(env) def is_py_operation(self): - return (self.operand1.type.is_pyobject - or self.operand2.type.is_pyobject) + return self.is_py_operation_types(self.operand1.type, self.operand2.type) + + def is_py_operation_types(self, type1, type2): + return type1.is_pyobject or type2.is_pyobject + + def result_type(self, type1, type2): + if self.is_py_operation_types(type1, type2): + return py_object_type + else: + return self.compute_c_result_type(type1, type2) def nogil_check(self, env): if self.is_py_operation(): self.gil_error() - + def coerce_operands_to_pyobjects(self, env): self.operand1 = self.operand1.coerce_to_pyobject(env) self.operand2 = self.operand2.coerce_to_pyobject(env) @@ -4321,12 +4445,11 @@ class IntBinopNode(NumBinopNode): class AddNode(NumBinopNode): # '+' operator. - def is_py_operation(self): - if self.operand1.type.is_string \ - and self.operand2.type.is_string: - return 1 + def is_py_operation_types(self, type1, type2): + if type1.is_string and type2.is_string: + return 1 else: - return NumBinopNode.is_py_operation(self) + return NumBinopNode.is_py_operation_types(self, type1, type2) def compute_c_result_type(self, type1, type2): #print "AddNode.compute_c_result_type:", type1, self.operator, type2 ### @@ -4355,14 +4478,12 @@ class SubNode(NumBinopNode): class MulNode(NumBinopNode): # '*' operator. - def is_py_operation(self): - type1 = self.operand1.type - type2 = self.operand2.type + def is_py_operation_types(self, type1, type2): if (type1.is_string and type2.is_int) \ or (type2.is_string and type1.is_int): return 1 else: - return NumBinopNode.is_py_operation(self) + return NumBinopNode.is_py_operation_types(self, type1, type2) class DivNode(NumBinopNode): @@ -4499,10 +4620,10 @@ class DivNode(NumBinopNode): class ModNode(DivNode): # '%' operator. - def is_py_operation(self): - return (self.operand1.type.is_string - or self.operand2.type.is_string - or NumBinopNode.is_py_operation(self)) + def is_py_operation_types(self, type1, type2): + return (type1.is_string + or type2.is_string + or NumBinopNode.is_py_operation_types(self, type1, type2)) def zero_division_message(self): if self.type.is_int: @@ -4578,6 +4699,13 @@ class BoolBinopNode(ExprNode): # operand2 ExprNode subexprs = ['operand1', 'operand2'] + + def infer_type(self, env): + if (self.operand1.infer_type(env).is_pyobject or + self.operand2.infer_type(env).is_pyobject): + return py_object_type + else: + return PyrexTypes.c_bint_type def calculate_constant_result(self): if self.operator == 'and': @@ -4692,6 +4820,13 @@ class CondExprNode(ExprNode): false_val = None subexprs = ['test', 'true_val', 'false_val'] + + def type_dependencies(self): + return self.true_val.type_dependencies() + self.false_val.type_dependencies() + + def infer_types(self, env): + return self.compute_result_type(self.true_val.infer_types(env), + self.false_val.infer_types(env)) def calculate_constant_result(self): if self.test.constant_result: @@ -4776,6 +4911,10 @@ richcmp_constants = { class CmpNode(object): # Mixin class containing code common to PrimaryCmpNodes # and CascadedCmpNodes. + + def infer_types(self, env): + # TODO: Actually implement this (after merging with -unstable). + return py_object_type def calculate_cascaded_constant_result(self, operand1_result): func = compile_time_binary_operators[self.operator] @@ -5294,6 +5433,8 @@ class NoneCheckNode(CoercionNode): class CoerceToPyTypeNode(CoercionNode): # This node is used to convert a C data type # to a Python object. + + type = py_object_type def __init__(self, arg, env): CoercionNode.__init__(self, arg) @@ -5366,9 +5507,10 @@ class CoerceToBooleanNode(CoercionNode): # This node is used when a result needs to be used # in a boolean context. + type = PyrexTypes.c_bint_type + def __init__(self, arg, env): CoercionNode.__init__(self, arg) - self.type = PyrexTypes.c_bint_type if arg.type.is_pyobject: self.is_temp = 1 @@ -5472,6 +5614,12 @@ class CloneNode(CoercionNode): def result(self): return self.arg.result() + + def type_dependencies(self): + return self.arg.type_dependencies() + + def infer_type(self, env): + return self.arg.infer_type(env) def analyse_types(self, env): self.type = self.arg.type -- 2.26.2