From: Dag Sverre Seljebotn <dagss@student.matnat.uio.no>
Date: Tue, 8 Jul 2008 22:01:57 +0000 (+0200)
Subject: Buffer assignment appears to be working
X-Git-Tag: 0.9.8.1~49^2~106
X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=3da50c6dc783d8656298db777b9a455d1bf7174e;p=cython.git

Buffer assignment appears to be working
---

diff --git a/Cython/Compiler/Buffer.py b/Cython/Compiler/Buffer.py
index 7f672064..b8978bf9 100644
--- a/Cython/Compiler/Buffer.py
+++ b/Cython/Compiler/Buffer.py
@@ -5,18 +5,97 @@ from Cython.Compiler.ExprNodes import *
 from Cython.Compiler.TreeFragment import TreeFragment
 from Cython.Utils import EncodedString
 from Cython.Compiler.Errors import CompileError
+import PyrexTypes
 from sets import Set as set
 
+class PureCFuncNode(Node):
+    def __init__(self, pos, cname, type, c_code):
+        self.pos = pos
+        self.cname = cname
+        self.type = type
+        self.c_code = c_code
+
+    def analyse_types(self, env):
+        self.entry = env.declare_cfunction(
+            "<pure c function:%s>" % self.cname,
+            self.type, self.pos, cname=self.cname,
+            defining=True)
+
+    def generate_function_definitions(self, env, code, transforms):
+        # TODO: Fix constness, don't hack it
+        assert self.type.optional_arg_count == 0
+        arg_decls = [arg.declaration_code() for arg in self.type.args]
+        sig = self.type.return_type.declaration_code(
+            self.type.function_header_code(self.cname, ", ".join(arg_decls)))
+        code.putln("")
+        code.putln("%s {" % sig)
+        code.put(self.c_code)
+        code.putln("}")
+
+    def generate_execution_code(self, code):
+        pass
+
 class BufferTransform(CythonTransform):
     """
     Run after type analysis. Takes care of the buffer functionality.
     """
     scope = None
+    tschecker_functype = PyrexTypes.CFuncType(
+        PyrexTypes.c_char_ptr_type,
+        [PyrexTypes.CFuncTypeArg(EncodedString("ts"), PyrexTypes.c_char_ptr_type,
+                      (0, 0, None), cname="ts")],
+        exception_value = "NULL"
+    )  
 
     def __call__(self, node):
         cymod = self.context.modules[u'__cython__']
-        self.buffer_type = cymod.entries[u'Py_buffer'].type
-        return super(BufferTransform, self).__call__(node)
+        self.bufstruct_type = cymod.entries[u'Py_buffer'].type
+        self.tscheckers = {}
+        self.module_scope = node.scope
+        self.module_pos = node.pos
+        result = super(BufferTransform, self).__call__(node)
+        result.body.stats += [node for node in self.tscheckers.values()]
+        return result
+
+    def tschecker_simple(self, dtype):
+        char = dtype.typestring
+        return """
+  if (*ts != '%s') {
+    PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch");  
+    return NULL;
+  } else return ts + 1;
+""" % char
+
+    def tschecker(self, dtype):
+        # Creates a type string checker function for the given type.
+        # Each checker is created as a function entry in the module scope
+        # and a PureCNode and put in the self.ts_checkers dict.
+        # Also the entry is returned.
+        #
+        # TODO: __eq__ and __hash__ for types
+        funcnode = self.tscheckers.get(dtype, None)
+        if funcnode is None:
+            assert dtype.is_int or dtype.is_float or dtype.is_struct_or_union
+            # Use prefixes to seperate user defined types from builtins
+            # (consider "typedef float unsigned_int")
+            builtin = not (dtype.is_struct_or_union or dtype.is_typedef)
+            if not builtin:
+                prefix = "user"
+            else:
+                prefix = "builtin"
+            cname = "check_typestring_%s_%s" % (prefix,
+                       dtype.declaration_code("").replace(" ", "_"))
+
+            if dtype.typestring is not None and len(dtype.typestring) == 1:
+                code = self.tschecker_simple(dtype)
+            else:
+                assert False
+
+            funcnode = PureCFuncNode(self.module_pos, cname,
+                                     self.tschecker_functype, code)
+            funcnode.analyse_types(self.module_scope)
+            self.tscheckers[dtype] = funcnode
+        return funcnode.entry
 
     def handle_scope(self, node, scope):
         # For all buffers, insert extra variables in the scope.
@@ -27,11 +106,15 @@ class BufferTransform(CythonTransform):
                    if entry.type.buffer_options is not None]
                    
         for name, entry in bufvars:
-            # Variable has buffer opts, declare auxiliary vars
+            
             bufopts = entry.type.buffer_options
 
+            # Get or make a type string checker
+            tschecker = self.tschecker(bufopts.dtype)
+
+            # Declare auxiliary vars
             bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name),
-                                        self.buffer_type, node.pos)
+                                        self.bufstruct_type, node.pos)
 
             temp_var =  scope.declare_var(temp_name_handle(u"%s_tmp" % name),
                                         entry.type, node.pos)
@@ -49,7 +132,7 @@ class BufferTransform(CythonTransform):
                 var = scope.declare_var(varname, PyrexTypes.c_uint_type, node.pos, is_cdef=True)
                 shapevars.append(var)
             entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, 
-                                                shapevars)
+                                                shapevars, tschecker)
             entry.buffer_aux.temp_var = temp_var
         self.scope = scope
 
@@ -64,13 +147,16 @@ class BufferTransform(CythonTransform):
         self.visitchildren(node)
         return node
 
+    # Notes: The cast to <char*> gets around Cython not supporting const types
     acquire_buffer_fragment = TreeFragment(u"""
         TMP = LHS
         if TMP is not None:
             __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
         TMP = RHS
-        __cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
-        ASSIGN_AUX
+        if TMP is not None:
+            __cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
+            TSCHECKER(<char*>BUFINFO.format)
+            ASSIGN_AUX
         LHS = TMP
     """)
 
@@ -82,22 +168,6 @@ class BufferTransform(CythonTransform):
         TARGET = BUFINFO.shape[IDX]
     """)
 
-    def visit_SingleAssignmentNode(self, node):
-        # On assignments, two buffer-related things can happen:
-        # a) A buffer variable is assigned to (reacquisition)
-        # b) Buffer access assignment: arr[...] = ...
-        # Since we don't allow nested buffers, these don't overlap.
-        
-        self.visitchildren(node)
-        # Only acquire buffers on vars (not attributes) for now.
-        if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
-            # Is buffer variable
-            return self.reacquire_buffer(node)
-        elif (isinstance(node.lhs, IndexNode) and
-              isinstance(node.lhs.base, NameNode) and
-              node.lhs.base.entry.buffer_aux is not None):
-            return self.assign_into_buffer(node)
-        
     def reacquire_buffer(self, node):
         bufaux = node.lhs.entry.buffer_aux
         auxass = []
@@ -106,7 +176,7 @@ class BufferTransform(CythonTransform):
             ass = self.fetch_strides.substitute({
                 u"TARGET": NameNode(node.pos, name=entry.name),
                 u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
-                u"IDX": IntNode(node.pos, value=EncodedString(idx))
+                u"IDX": IntNode(node.pos, value=EncodedString(idx)),
             })
             auxass.append(ass)
 
@@ -125,7 +195,8 @@ class BufferTransform(CythonTransform):
             u"LHS" : node.lhs,
             u"RHS": node.rhs,
             u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
-            u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name)
+            u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name),
+            u"TSCHECKER": NameNode(node.pos, name=bufaux.tschecker.name)
         }, pos=node.pos)
         # Note: The below should probably be refactored into something
         # like fragment.substitute(..., context=self.context), with
@@ -165,6 +236,24 @@ class BufferTransform(CythonTransform):
 
         return tmp.stats[0].expr
 
+
+    def visit_SingleAssignmentNode(self, node):
+        # On assignments, two buffer-related things can happen:
+        # a) A buffer variable is assigned to (reacquisition)
+        # b) Buffer access assignment: arr[...] = ...
+        # Since we don't allow nested buffers, these don't overlap.
+        self.visitchildren(node)
+        # Only acquire buffers on vars (not attributes) for now.
+        if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
+            # Is buffer variable
+            return self.reacquire_buffer(node)
+        elif (isinstance(node.lhs, IndexNode) and
+              isinstance(node.lhs.base, NameNode) and
+              node.lhs.base.entry.buffer_aux is not None):
+            return self.assign_into_buffer(node)
+        else:
+            return node
+        
     buffer_access = TreeFragment(u"""
         (<unsigned char*>(BUF.buf + OFFSET))[0]
     """)
@@ -179,10 +268,3 @@ class BufferTransform(CythonTransform):
         else:
             return node
 
-    def visit_CallNode(self, node):
-###        print node.dump()
-        return node
-    
-#    def visit_FuncDefNode(self, node):
-#        print node.dump()
-    
diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py
index d342cc1a..2e96b41b 100644
--- a/Cython/Compiler/ExprNodes.py
+++ b/Cython/Compiler/ExprNodes.py
@@ -1251,8 +1251,19 @@ class IndexNode(ExprNode):
     #
     #  base     ExprNode
     #  index    ExprNode
+    #  indices  [ExprNode]
+    #  is_buffer_access boolean Whether this is a buffer access.
+    #
+    #  indices is used on buffer access, index on non-buffer access.
+    #  The former contains a clean list of index parameters, the
+    #  latter whatever Python object is needed for index access.
     
-    subexprs = ['base', 'index']
+    subexprs = ['base', 'index', 'indices']
+    indices = None
+
+    def __init__(self, pos, index, *args, **kw):
+        ExprNode.__init__(self, pos, index=index, *args, **kw)
+        self._index = index
     
     def compile_time_value(self, denv):
         base = self.base.compile_time_value(denv)
@@ -1273,7 +1284,7 @@ class IndexNode(ExprNode):
     
     def analyse_target_types(self, env):
         self.analyse_base_and_index_types(env, setting = 1)
-    
+
     def analyse_base_and_index_types(self, env, getting = 0, setting = 0):
         self.is_buffer_access = False
 
@@ -1282,19 +1293,20 @@ class IndexNode(ExprNode):
         if self.base.type.buffer_options is not None:
             if isinstance(self.index, TupleNode):
                 indices = self.index.args
-#                is_int_indices = 0 == sum([1 for i in self.index.args if not i.type.is_int])
             else:
-#                is_int_indices = self.index.type.is_int
                 indices = [self.index]
             all_ints = True
-            for index in indices:
-                index.analyse_types(env)
-                if not index.type.is_int:
+            for x in indices:
+                x.analyse_types(env)
+                if not x.type.is_int:
                     all_ints = False
             if all_ints:
+#                self.indices = [
+#                    x.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
+#                    for x in  indices]
                 self.indices = indices
                 self.index = None
-                self.type = self.base.type.buffer_options.dtype
+                self.type = self.base.type.buffer_options.dtype 
                 self.is_temp = 1
                 self.is_buffer_access = True
 
@@ -3935,6 +3947,10 @@ class CoerceToTempNode(CoercionNode):
 
     gil_message = "Creating temporary Python reference"
 
+    def analyse_types(self, env):
+        # The arg is always already analysed
+        pass
+
     def generate_result_code(self, code):
         #self.arg.generate_evaluation_code(code) # Already done
         # by generic generate_subexpr_evaluation_code!
diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py
index d2077607..3a2768f8 100644
--- a/Cython/Compiler/Nodes.py
+++ b/Cython/Compiler/Nodes.py
@@ -184,10 +184,10 @@ class Node(object):
         
         attrs = [(key, value) for key, value in self.__dict__.iteritems() if key not in filter_out]
         if len(attrs) == 0:
-            return "<%s>" % self.__class__.__name__
+            return "<%s (%d)>" % (self.__class__.__name__, id(self))
         else:
             indent = "  " * level
-            res = "<%s\n" % (self.__class__.__name__)
+            res = "<%s (%d)\n" % (self.__class__.__name__, id(self))
             for key, value in attrs:
                 res += "%s  %s: %s\n" % (indent, key, dump_child(value, level + 1))
             res += "%s>" % indent
diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py
index dcbe81e9..ed59d62f 100644
--- a/Cython/Compiler/ParseTreeTransforms.py
+++ b/Cython/Compiler/ParseTreeTransforms.py
@@ -1,4 +1,3 @@
-
 from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
 from Cython.Compiler.ModuleNode import ModuleNode
 from Cython.Compiler.Nodes import *
@@ -51,12 +50,6 @@ class NormalizeTree(CythonTransform):
         else:
             return node
 
-    def visit_PassStatNode(self, node):
-        if not self.is_in_statlist:
-            return StatListNode(pos=node.pos, stats=[])
-        else:
-            return []
-
     def visit_StatListNode(self, node):
         self.is_in_statlist = True
         self.visitchildren(node)
@@ -72,6 +65,18 @@ class NormalizeTree(CythonTransform):
     def visit_CStructOrUnionDefNode(self, node):
         return self.visit_StatNode(node, True)
 
+    # Eliminate PassStatNode
+    def visit_PassStatNode(self, node):
+        if not self.is_in_statlist:
+            return StatListNode(pos=node.pos, stats=[])
+        else:
+            return []
+
+    # Eliminate CascadedAssignmentNode
+    def visit_CascadedAssignmentNode(self, node):
+        tmpname = temp_name_handle()
+        
+
 
 class PostParseError(CompileError): pass
 
diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py
index 59ae0e3c..0e08ace5 100644
--- a/Cython/Compiler/PyrexTypes.py
+++ b/Cython/Compiler/PyrexTypes.py
@@ -61,6 +61,7 @@ class PyrexType(BaseType):
     #  default_value         string      Initial value
     #  parsetuple_format     string      Format char for PyArg_ParseTuple
     #  pymemberdef_typecode  string      Type code for PyMemberDef struct
+    #  typestring            string      String char defining the type (see Python struct module)
     #
     #  declaration_code(entity_code, 
     #      for_display = 0, dll_linkage = None, pyrex = 0)
@@ -416,9 +417,10 @@ class CNumericType(CType):
     
     sign_words = ("unsigned ", "", "signed ")
     
-    def __init__(self, rank, signed = 1, pymemberdef_typecode = None):
+    def __init__(self, rank, signed = 1, pymemberdef_typecode = None, typestring = None):
         self.rank = rank
         self.signed = signed
+        self.typestring = typestring
         ptf = self.parsetuple_formats[signed][rank]
         if ptf == '?':
             ptf = None
@@ -451,8 +453,9 @@ class CIntType(CNumericType):
     from_py_function = "__pyx_PyInt_AsLong"
     exception_value = -1
 
-    def __init__(self, rank, signed, pymemberdef_typecode = None, is_returncode = 0):
-        CNumericType.__init__(self, rank, signed, pymemberdef_typecode)
+    def __init__(self, rank, signed, pymemberdef_typecode = None, is_returncode = 0,
+                 typestring=None):
+        CNumericType.__init__(self, rank, signed, pymemberdef_typecode, typestring=typestring)
         self.is_returncode = is_returncode
         if self.from_py_function == '__pyx_PyInt_AsLong':
             self.from_py_function = self.get_type_conversion()
@@ -543,8 +546,8 @@ class CFloatType(CNumericType):
     to_py_function = "PyFloat_FromDouble"
     from_py_function = "__pyx_PyFloat_AsDouble"
     
-    def __init__(self, rank, pymemberdef_typecode = None):
-        CNumericType.__init__(self, rank, 1, pymemberdef_typecode)
+    def __init__(self, rank, pymemberdef_typecode = None, typestring=None):
+        CNumericType.__init__(self, rank, 1, pymemberdef_typecode, typestring = typestring)
     
     def assignable_from_resolved_type(self, src_type):
         return src_type.is_numeric or src_type is error_type
@@ -852,9 +855,12 @@ class CFuncTypeArg:
     #  type       PyrexType
     #  pos        source file position
     
-    def __init__(self, name, type, pos):
+    def __init__(self, name, type, pos, cname=None):
         self.name = name
-        self.cname = Naming.var_prefix + name
+        if cname is not None:
+            self.cname = cname
+        else:
+            self.cname = Naming.var_prefix + name
         self.type = type
         self.pos = pos
         self.not_none = False
@@ -1050,29 +1056,29 @@ c_void_type =         CVoidType()
 c_void_ptr_type =     CPtrType(c_void_type)
 c_void_ptr_ptr_type = CPtrType(c_void_ptr_type)
 
-c_uchar_type =       CIntType(0, 0, "T_UBYTE")
-c_ushort_type =      CIntType(1, 0, "T_USHORT")
-c_uint_type =        CUIntType(2, 0, "T_UINT")
-c_ulong_type =       CULongType(3, 0, "T_ULONG")
-c_ulonglong_type =   CULongLongType(4, 0, "T_ULONGLONG")
-
-c_char_type =        CIntType(0, 1, "T_CHAR")
-c_short_type =       CIntType(1, 1, "T_SHORT")
-c_int_type =         CIntType(2, 1, "T_INT")
-c_long_type =        CIntType(3, 1, "T_LONG")
-c_longlong_type =    CLongLongType(4, 1, "T_LONGLONG")
+c_uchar_type =       CIntType(0, 0, "T_UBYTE", typestring="B")
+c_ushort_type =      CIntType(1, 0, "T_USHORT", typestring="H")
+c_uint_type =        CUIntType(2, 0, "T_UINT", typestring="I")
+c_ulong_type =       CULongType(3, 0, "T_ULONG", typestring="L")
+c_ulonglong_type =   CULongLongType(4, 0, "T_ULONGLONG", typestring="Q")
+
+c_char_type =        CIntType(0, 1, "T_CHAR", typestring="b")
+c_short_type =       CIntType(1, 1, "T_SHORT", typestring="h")
+c_int_type =         CIntType(2, 1, "T_INT", typestring="i")
+c_long_type =        CIntType(3, 1, "T_LONG", typestring="l")
+c_longlong_type =    CLongLongType(4, 1, "T_LONGLONG", typestring="q")
 c_py_ssize_t_type =  CPySSizeTType(5, 1)
-c_bint_type =        CBIntType(2, 1, "T_INT")
+c_bint_type =        CBIntType(2, 1, "T_INT", typestring="i")
 
-c_schar_type =       CIntType(0, 2, "T_CHAR")
-c_sshort_type =      CIntType(1, 2, "T_SHORT")
-c_sint_type =        CIntType(2, 2, "T_INT")
-c_slong_type =       CIntType(3, 2, "T_LONG")
-c_slonglong_type =   CLongLongType(4, 2, "T_LONGLONG")
+c_schar_type =       CIntType(0, 2, "T_CHAR", typestring="b")
+c_sshort_type =      CIntType(1, 2, "T_SHORT", typestring="h")
+c_sint_type =        CIntType(2, 2, "T_INT", typestring="i")
+c_slong_type =       CIntType(3, 2, "T_LONG", typestring="l")
+c_slonglong_type =   CLongLongType(4, 2, "T_LONGLONG", typestring="q")
 
-c_float_type =       CFloatType(6, "T_FLOAT")
-c_double_type =      CFloatType(7, "T_DOUBLE")
-c_longdouble_type =  CFloatType(8)
+c_float_type =       CFloatType(6, "T_FLOAT", typestring="f")
+c_double_type =      CFloatType(7, "T_DOUBLE", typestring="d")
+c_longdouble_type =  CFloatType(8, typestring="g")
 
 c_null_ptr_type =     CNullPtrType(c_void_type)
 c_char_array_type =   CCharArrayType(None)
diff --git a/Cython/Compiler/Symtab.py b/Cython/Compiler/Symtab.py
index 24fc998a..05d38cc7 100644
--- a/Cython/Compiler/Symtab.py
+++ b/Cython/Compiler/Symtab.py
@@ -20,10 +20,12 @@ possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
 nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
 
 class BufferAux:
-    def __init__(self, buffer_info_var, stridevars, shapevars):
+    def __init__(self, buffer_info_var, stridevars, shapevars, tschecker):
         self.buffer_info_var = buffer_info_var
         self.stridevars = stridevars
         self.shapevars = shapevars
+        self.tschecker = tschecker
+        
     def __repr__(self):
         return "<BufferAux %r>" % self.__dict__
 
diff --git a/Cython/Compiler/Visitor.py b/Cython/Compiler/Visitor.py
index 31b6386b..52fe2c2b 100644
--- a/Cython/Compiler/Visitor.py
+++ b/Cython/Compiler/Visitor.py
@@ -181,10 +181,14 @@ def replace_node(ptr, value):
         getattr(parent, attrname)[listidx] = value
 
 tmpnamectr = 0
-def temp_name_handle(description):
+def temp_name_handle(description=None):
     global tmpnamectr
     tmpnamectr += 1
-    return EncodedString(Naming.temp_prefix + u"%d_%s" % (tmpnamectr, description))
+    if description is not None:
+        name = u"%d_%s" % (tmpnamectr, description)
+    else:
+        name = u"%d" % tmpnamectr
+    return EncodedString(Naming.temp_prefix + name)
 
 def get_temp_name_handle_desc(handle):
     if not handle.startswith(u"__cyt_"):
@@ -198,7 +202,7 @@ class PrintTree(TreeVisitor):
     Subclass and override repr_of to provide more information
     about nodes. """
     def __init__(self):
-        Transform.__init__(self)
+        TreeVisitor.__init__(self)
         self._indent = ""
 
     def indent(self):
@@ -208,6 +212,7 @@ class PrintTree(TreeVisitor):
 
     def __call__(self, tree, phase=None):
         print("Parse tree dump at phase '%s'" % phase)
+        self.visit(tree)
 
     # Don't do anything about process_list, the defaults gives
     # nice-looking name[idx] nodes which will visually appear
diff --git a/Includes/__cython__.pxd b/Includes/__cython__.pxd
index 39a83e54..c703e6f2 100644
--- a/Includes/__cython__.pxd
+++ b/Includes/__cython__.pxd
@@ -19,5 +19,10 @@ cdef extern from "Python.h":
     int PyObject_GetBuffer(PyObject* obj, Py_buffer* view, int flags) except -1
     void PyObject_ReleaseBuffer(PyObject* obj, Py_buffer* view)
 
+    void PyErr_Format(int, char*, ...)
+
+    enum:
+        PyExc_TypeError
+
 #                  int PyObject_GetBuffer(PyObject *obj, Py_buffer *view,
 #                       int flags)