From 8af9a568886fe0c8b1d7e656fb1af8b3773f31b7 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 14 May 2009 00:58:13 -0700 Subject: [PATCH] Complex numeber comparison, etc. --- Cython/Compiler/ExprNodes.py | 56 ++++++++++++++++++++++++++++------- Cython/Compiler/Nodes.py | 6 ++-- Cython/Compiler/PyrexTypes.py | 40 ++++++++++++++++++------- 3 files changed, 80 insertions(+), 22 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 70fe414f..7cd13af7 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -4213,8 +4213,11 @@ class NumBinopNode(BinopNode): self.operator, self.operand2.result()) else: + func = self.type.binary_op(self.operator) + if func is None: + error(self.pos, "binary operator %s not supported for %s" % (self.operator, self.type)) return "%s(%s, %s)" % ( - self.type.binop(self.operator), + func, self.operand1.result(), self.operand2.result()) @@ -4318,7 +4321,7 @@ class DivNode(NumBinopNode): return "float division" def generate_evaluation_code(self, code): - if not self.type.is_pyobject: + if not self.type.is_pyobject and not self.type.is_complex: if self.cdivision is None: self.cdivision = (code.globalstate.directives['cdivision'] or not self.type.signed @@ -4331,7 +4334,11 @@ class DivNode(NumBinopNode): def generate_div_warning_code(self, code): if not self.type.is_pyobject: if self.zerodivision_check: - code.putln("if (unlikely(%s == 0)) {" % self.operand2.result()) + if not self.infix: + zero_test = "%s(%s)" % (self.type.unary_op('zero'), self.operand2.result()) + else: + zero_test = "%s == 0" % self.operand2.result() + code.putln("if (unlikely(%s)) {" % zero_test) code.putln('PyErr_Format(PyExc_ZeroDivisionError, "%s");' % self.zero_division_message()) code.putln(code.error_goto(self.pos)) code.putln("}") @@ -4344,7 +4351,7 @@ class DivNode(NumBinopNode): code.putln('PyErr_Format(PyExc_OverflowError, "value too large to perform division");') code.putln(code.error_goto(self.pos)) code.putln("}") - if code.globalstate.directives['cdivision_warnings']: + if code.globalstate.directives['cdivision_warnings'] and self.operand != '/': code.globalstate.use_utility_code(cdivision_warning_utility_code) code.putln("if ((%s < 0) ^ (%s < 0)) {" % ( self.operand1.result(), @@ -4355,7 +4362,9 @@ class DivNode(NumBinopNode): code.putln("}") def calculate_result_code(self): - if self.type.is_float and self.operator == '//': + if self.type.is_complex: + return NumBinopNode.calculate_result_code(self) + elif self.type.is_float and self.operator == '//': return "floor(%s / %s)" % ( self.operand1.result(), self.operand2.result()) @@ -4705,7 +4714,13 @@ class CmpNode(object): or (self.cascade and self.cascade.is_python_result())) def check_types(self, env, operand1, op, operand2): - if not self.types_okay(operand1, op, operand2): + if operand1.type.is_complex or operand2.type.is_complex: + if op not in ('==', '!='): + error(self.pos, "complex types unordered") + common_type = PyrexTypes.widest_numeric_type(operand1.type, operand2.type) + self.operand1 = operand1.coerce_to(common_type, env) + self.operand2 = operand2.coerce_to(common_type, env) + elif not self.types_okay(operand1, op, operand2): error(self.pos, "Invalid types for '%s' (%s, %s)" % (self.operator, operand1.type, operand2.type)) @@ -4754,6 +4769,16 @@ class CmpNode(object): richcmp_constants[op], code.error_goto_if_null(result_code, self.pos))) code.put_gotref(result_code) + elif operand1.type.is_complex and not code.globalstate.directives['c99_complex']: + if op == "!=": negation = "!" + else: negation = "" + code.putln("%s = %s(%s%s(%s, %s));" % ( + result_code, + coerce_result, + negation, + operand1.type.unary_op('eq'), + operand1.result(), + operand2.result())) else: type1 = operand1.type type2 = operand2.type @@ -4881,10 +4906,21 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode): self.not_const() def calculate_result_code(self): - return "(%s %s %s)" % ( - self.operand1.result(), - self.c_operator(self.operator), - self.operand2.result()) + if self.operand1.type.is_complex: + if self.operator == "!=": + negation = "!" + else: + negation = "" + return "(%s%s(%s, %s))" % ( + negation, + self.operand1.type.binary_op('=='), + self.operand1.result(), + self.operand2.result()) + else: + return "(%s %s %s)" % ( + self.operand1.result(), + self.c_operator(self.operator), + self.operand2.result()) def generate_evaluation_code(self, code): self.operand1.generate_evaluation_code(code) diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index 084caf12..2ff86dba 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -1921,6 +1921,10 @@ class DefNode(FuncDefNode): has_star_or_kw_args = self.star_arg is not None \ or self.starstar_arg is not None or has_kwonly_args + for arg in self.args: + if not arg.type.is_pyobject and arg.type.from_py_function is None: + arg.type.create_from_py_utility_code(env) + if not self.signature_has_generic_args(): if has_star_or_kw_args: error(self.pos, "This method cannot have * or keyword arguments") @@ -1951,8 +1955,6 @@ class DefNode(FuncDefNode): error(arg.pos, "Non-default argument following default argument") elif not arg.is_self_arg: positional_args.append(arg) - if arg.type.from_py_function is None: - arg.type.create_from_py_utility_code(env) self.generate_tuple_and_keyword_parsing_code( positional_args, kw_only_args, end_label, code) diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index 3cf7395c..54b81ad1 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -753,18 +753,33 @@ class CComplexType(CNumericType): self.from_py_function = "__pyx_PyObject_As_" + self.specalization_name() return True - def binop(self, op): + def lookup_op(self, nargs, op): try: - return self.binops[op] + return self.binops[nargs, op] except KeyError: - if op in "+-*/": - from ExprNodes import compile_time_binary_operators - op_name = compile_time_binary_operators[op].__name__ - self.binops[op] = func_name = "%s_%s" % (self.specalization_name(), op_name) - return func_name - else: - error("Binary '%s' not supported in for %s" % (op, self)) - return "" + pass + try: + op_name = complex_ops[nargs, op] + self.binops[nargs, op] = func_name = "%s_%s" % (self.specalization_name(), op_name) + return func_name + except KeyError: + return None + + def unary_op(self, op): + return self.lookup_op(1, op) + + def binary_op(self, op): + return self.lookup_op(2, op) + +complex_ops = { + (1, '-'): 'neg', + (1, 'zero'): 'is_zero', + (2, '+'): 'add', + (2, '-') : 'sub', + (2, '*'): 'mul', + (2, '/'): 'div', + (2, '=='): 'eq', +} complex_generic_utility_code = UtilityCode( proto=""" @@ -804,6 +819,7 @@ proto=""" #define %(type_name)s_from_parts(x, y) ((x) + (y)*(%(type)s)_Complex_I) #define %(type_name)s_is_zero(a) ((a) == 0) + #define %(type_name)s_eq(a, b) ((a) == (b)) #define %(type_name)s_add(a, b) ((a)+(b)) #define %(type_name)s_sub(a, b) ((a)-(b)) #define %(type_name)s_mul(a, b) ((a)*(b)) @@ -819,6 +835,10 @@ proto=""" return (a.real == 0) & (a.imag == 0); } + static INLINE int %(type_name)s_eq(%(type)s a, %(type)s b) { + return (a.real == b.real) & (a.imag == b.imag); + } + static INLINE %(type)s %(type_name)s_add(%(type)s a, %(type)s b) { %(type)s z; z.real = a.real + b.real; -- 2.26.2