cdivision - raise zero division errors
authorRobert Bradshaw <robertwb@math.washington.edu>
Tue, 14 Apr 2009 22:12:47 +0000 (15:12 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Tue, 14 Apr 2009 22:12:47 +0000 (15:12 -0700)
Cython/Compiler/ExprNodes.py
tests/run/cdivision_CEP_516.pyx

index cf0135cad250b6c841ec28627fab24e67a501655..9d274b94bab4c2fd71acbe7520a83da95be5df4d 100644 (file)
@@ -4237,12 +4237,22 @@ class DivNode(NumBinopNode):
     
     cdivision = None
     cdivision_warnings = False
+    zerodivision_check = None
     
     def analyse_types(self, env):
         NumBinopNode.analyse_types(self, env)
-        if not self.type.is_pyobject and env.directives['cdivision_warnings']:
-            self.operand1 = self.operand1.coerce_to_simple(env)
-            self.operand2 = self.operand2.coerce_to_simple(env)
+        if not self.type.is_pyobject:
+            self.zerodivision_check = self.cdivision is None and not env.directives['cdivision']
+            if self.zerodivision_check or env.directives['cdivision_warnings']:
+                # Need to check ahead of time to warn or raise zero division error
+                self.operand1 = self.operand1.coerce_to_simple(env)
+                self.operand2 = self.operand2.coerce_to_simple(env)
+    
+    def zero_division_message(self):
+        if self.type.is_int:
+            return "integer division or modulo by zero"
+        else:
+            return "float division"
 
     def generate_evaluation_code(self, code):
         if not self.type.is_pyobject:
@@ -4253,18 +4263,24 @@ class DivNode(NumBinopNode):
             if not self.cdivision:
                 code.globalstate.use_utility_code(div_int_utility_code.specialize(self.type))
         NumBinopNode.generate_evaluation_code(self, code)
-        if not self.type.is_pyobject and code.globalstate.directives['cdivision_warnings']:
-            self.generate_div_warning_code(code)
+        self.generate_div_warning_code(code)
     
     def generate_div_warning_code(self, code):
-        code.globalstate.use_utility_code(cdivision_warning_utility_code)
-        code.putln("if ((%s < 0) ^ (%s < 0)) {" % (
-                        self.operand1.result(),
-                        self.operand2.result()))
-        code.putln(code.set_error_info(self.pos));
-        code.put("if (__Pyx_cdivision_warning()) ")
-        code.put_goto(code.error_label)
-        code.putln("}")
+        if not self.type.is_pyobject:
+            if self.zerodivision_check:
+                code.putln("if (unlikely(%s == 0)) {" % self.operand2.result())
+                code.putln('PyErr_Format(PyExc_ZeroDivisionError, "%s");' % self.zero_division_message())
+                code.putln(code.error_goto(self.pos))
+                code.putln("}")
+            if code.globalstate.directives['cdivision_warnings']:
+                code.globalstate.use_utility_code(cdivision_warning_utility_code)
+                code.putln("if ((%s < 0) ^ (%s < 0)) {" % (
+                                self.operand1.result(),
+                                self.operand2.result()))
+                code.putln(code.set_error_info(self.pos));
+                code.put("if (__Pyx_cdivision_warning()) ")
+                code.put_goto(code.error_label)
+                code.putln("}")
     
     def calculate_result_code(self):
         if self.type.is_float and self.operator == '//':
@@ -4290,6 +4306,12 @@ class ModNode(DivNode):
             or self.operand2.type.is_string
             or NumBinopNode.is_py_operation(self))
 
+    def zero_division_message(self):
+        if self.type.is_int:
+            return "integer division or modulo by zero"
+        else:
+            return "float divmod()"
+    
     def generate_evaluation_code(self, code):
         if not self.type.is_pyobject:
             if self.cdivision is None:
@@ -4301,8 +4323,7 @@ class ModNode(DivNode):
                     code.globalstate.use_utility_code(
                         mod_float_utility_code.specialize(self.type, math_h_modifier=self.type.math_h_modifier))
         NumBinopNode.generate_evaluation_code(self, code)
-        if not self.type.is_pyobject and code.globalstate.directives['cdivision_warnings']:
-            self.generate_div_warning_code(code)
+        self.generate_div_warning_code(code)
     
     def calculate_result_code(self):
         if self.cdivision:
index 51f9eba16f9ff7f6629d3f2769325db00a9a467a..82a55139af9e80470fe608719c7b49c862ede533 100644 (file)
@@ -42,12 +42,28 @@ division with oppositely signed operands, C and Python semantics differ
 >>> div_int_c_warn(-17, 10)
 division with oppositely signed operands, C and Python semantics differ
 -1
->>> complex_expression(-150, 20, 20, -7)
-verbose_call(-150)
-division with oppositely signed operands, C and Python semantics differ
+>>> complex_expression(-150, 20, 19, -7)
 verbose_call(20)
 division with oppositely signed operands, C and Python semantics differ
+verbose_call(19)
+division with oppositely signed operands, C and Python semantics differ
 -2
+
+>>> mod_div_zero_int(25, 10, 2)
+verbose_call(5)
+2
+>>> mod_div_zero_int(25, 10, 0)
+verbose_call(5)
+'integer division or modulo by zero'
+>>> mod_div_zero_int(25, 0, 0)
+'integer division or modulo by zero'
+
+>>> mod_div_zero_float(25, 10, 2)
+2.5
+>>> mod_div_zero_float(25, 10, 0)
+'float division'
+>>> mod_div_zero_float(25, 0, 0)
+'float divmod()'
 """
 
 cimport cython
@@ -109,8 +125,25 @@ def div_int_c_warn(int a, int b):
 @cython.cdivision(False)
 @cython.cdivision_warnings(True)
 def complex_expression(int a, int b, int c, int d):
-    return (verbose_call(a) // b) % (verbose_call(c) // d)
+    return (a // verbose_call(b)) % (verbose_call(c) // d)
 
 cdef int verbose_call(int x):
     print "verbose_call(%s)" % x
     return x
+
+
+# These may segfault with cdivision
+
+@cython.cdivision(False)
+def mod_div_zero_int(int a, int b, int c):
+    try:
+        return verbose_call(a % b) / c
+    except ZeroDivisionError, ex:
+        return ex.message
+
+@cython.cdivision(False)
+def mod_div_zero_float(float a, float b, float c):
+    try:
+        return (a % b) / c
+    except ZeroDivisionError, ex:
+        return ex.message