Main.py has transformation pipelines for pxds; allowing buffer funcs in pxds
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Thu, 31 Jul 2008 11:53:22 +0000 (13:53 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Thu, 31 Jul 2008 11:53:22 +0000 (13:53 +0200)
Cython/Compiler/Errors.py
Cython/Compiler/Main.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Parsing.py
Cython/Includes/numpy.pxd
tests/errors/e_pxdimpl.pyx [new file with mode: 0644]
tests/errors/e_pxdimpl_imported.pxd [new file with mode: 0644]
tests/run/bufaccess.pyx

index 76a699aff25c8297526eec95cde20222372f8991..57ce5334c8ff80e8eea18535c283b93b7670d20a 100644 (file)
@@ -104,7 +104,7 @@ def report_error(err):
 def error(position, message):
     #print "Errors.error:", repr(position), repr(message) ###
     err = CompileError(position, message)
-#    if position is not None: raise Exception(err) # debug
+   # if position is not None: raise Exception(err) # debug
     report_error(err)
     return err
 
index db3f25210cbbdf11f61a567592e54ebcff51586a..2f03afe25490c5f7feecc960e1fad847981ffd74 100644 (file)
@@ -27,6 +27,11 @@ module_name_pattern = re.compile(r"[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_
 
 verbose = 0
 
+def dumptree(t):
+    # For quick debugging in pipelines
+    print t.dump()
+    return t
+
 class Context:
     #  This class encapsulates the context needed for compiling
     #  one or more Cython implementation files along with their
@@ -42,18 +47,78 @@ class Context:
         #self.modules = {"__builtin__" : BuiltinScope()}
         import Builtin
         self.modules = {"__builtin__" : Builtin.builtin_scope}
-        self.pxds = {}
-        self.pyxs = {}
         self.include_directories = include_directories
         self.future_directives = set()
 
-        import os.path
-
         standard_include_path = os.path.abspath(
             os.path.join(os.path.dirname(__file__), '..', 'Includes'))
         self.include_directories = include_directories + [standard_include_path]
 
+    def create_pipeline(self, pxd):
+        from Visitor import PrintTree
+        from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse
+        from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
+        from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
+        from Optimize import FlattenInListTransform, SwitchTransform, OptimizeRefcounting
+        from Buffer import IntroduceBufferAuxiliaryVars
+        from ModuleNode import check_c_classes
+
+        if pxd:
+            _check_c_classes = None
+            _specific_post_parse = PxdPostParse(self)
+        else:
+            _check_c_classes = check_c_classes
+            _specific_post_parse = None
+        return [
+            NormalizeTree(self),
+            PostParse(self),
+            _specific_post_parse,
+            FlattenInListTransform(),
+            WithTransform(self),
+            DecoratorTransform(self),
+            AnalyseDeclarationsTransform(self),
+            IntroduceBufferAuxiliaryVars(self),
+            _check_c_classes,
+            AnalyseExpressionsTransform(self),
+            SwitchTransform(),
+            OptimizeRefcounting(self),
+            #        CreateClosureClasses(context),
+            ]
+
+    def create_pyx_pipeline(self, options, result):
+        return [create_parse(self)] + self.create_pipeline(pxd=False) + [
+            create_generate_code(self, options, result)
+            ]
+
+    def create_pxd_pipeline(self, scope, module_name):
+        def parse_pxd(source_desc):
+            tree = self.parse(source_desc, scope, pxd=True,
+                              full_module_name=module_name)
+            tree.scope = scope
+            tree.is_pxd = True
+            return tree
+        return [parse_pxd] + self.create_pipeline(pxd=True)
+
+    def process_pxd(self, source_desc, scope, module_name):
+        pipeline = self.create_pxd_pipeline(scope, module_name)
+        return self.run_pipeline(pipeline, source_desc)
         
+    def nonfatal_error(self, exc):
+        return Errors.report_error(exc)
+
+    def run_pipeline(self, pipeline, source):
+        errors_occurred = False
+        data = source
+        try:
+            for phase in pipeline:
+                if phase is not None:
+                    data = phase(data)
+        except CompileError, err:
+            errors_occurred = True
+            Errors.report_error(err)
+        return (errors_occurred, data)
+
     def find_module(self, module_name, 
             relative_to = None, pos = None, need_pxd = 1):
         # Finds and returns the module scope corresponding to
@@ -106,9 +171,7 @@ class Context:
                     if debug_find_module:
                         print("Context.find_module: Parsing %s" % pxd_pathname)
                     source_desc = FileSourceDescriptor(pxd_pathname)
-                    pxd_tree = self.parse(source_desc, scope, pxd = 1,
-                                          full_module_name = module_name)
-                    pxd_tree.analyse_declarations(scope)
+                    self.process_pxd(source_desc, scope, module_name)
                 except CompileError:
                     pass
         return scope
@@ -330,20 +393,6 @@ class Context:
                     verbose_flag = options.show_version,
                     cplus = options.cplus)
 
-    def nonfatal_error(self, exc):
-        return Errors.report_error(exc)
-
-    def run_pipeline(self, pipeline, source):
-        errors_occurred = False
-        data = source
-        try:
-            for phase in pipeline:
-                data = phase(data)
-        except CompileError, err:
-            errors_occurred = True
-            Errors.report_error(err)
-        return (errors_occurred, data)
-
 def create_parse(context):
     def parse(compsrc):
         source_desc = compsrc.source_desc
@@ -353,6 +402,7 @@ def create_parse(context):
         tree = context.parse(source_desc, scope, pxd = 0, full_module_name = full_module_name)
         tree.compilation_source = compsrc
         tree.scope = scope
+        tree.is_pxd = False
         return tree
     return parse
 
@@ -364,34 +414,6 @@ def create_generate_code(context, options, result):
         return result
     return generate_code
 
-def create_default_pipeline(context, options, result):
-    from Visitor import PrintTree
-    from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
-    from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
-    from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
-    from Optimize import FlattenInListTransform, SwitchTransform, OptimizeRefcounting
-    from Buffer import IntroduceBufferAuxiliaryVars
-    from ModuleNode import check_c_classes
-    def printit(x): print x.dump()
-    return [
-        create_parse(context),
-#        printit,
-        NormalizeTree(context),
-        PostParse(context),
-        FlattenInListTransform(),
-        WithTransform(context),
-        DecoratorTransform(context),
-        AnalyseDeclarationsTransform(context),
-        IntroduceBufferAuxiliaryVars(context),
-        check_c_classes,
-        AnalyseExpressionsTransform(context),
-#        BufferTransform(context),
-        SwitchTransform(),
-        OptimizeRefcounting(context),
-#        CreateClosureClasses(context),
-        create_generate_code(context, options, result)
-    ]
-
 def create_default_resultobj(compilation_source, options):
     result = CompilationResult()
     result.main_source_file = compilation_source.source_desc.filename
@@ -428,7 +450,7 @@ def run_pipeline(source, options, full_module_name = None):
     result = create_default_resultobj(source, options)
     
     # Get pipeline
-    pipeline = create_default_pipeline(context, options, result)
+    pipeline = context.create_pyx_pipeline(options, result)
 
     context.setup_errors(options)
     errors_occurred, enddata = context.run_pipeline(pipeline, source)
index 1fc7778e88dd7b0433752f411c53c9f405b9d00b..b9a3b425afa5e11e69ac53827d06ebee4c78ccb6 100644 (file)
@@ -213,6 +213,42 @@ class PostParse(CythonTransform):
         node.keyword_args = None
         return node
 
+class PxdPostParse(CythonTransform):
+    """
+    Basic interpretation/validity checking that should only be
+    done on pxd trees.
+    """
+    ERR_FUNCDEF_NOT_ALLOWED = 'function definition not allowed here'
+
+    def __call__(self, node):
+        self.scope_type = 'pxd'
+        return super(PxdPostParse, self).__call__(node)
+
+    def visit_CClassDefNode(self, node):
+        old = self.scope_type
+        self.scope_type = 'cclass'
+        self.visitchildren(node)
+        self.scope_type = old
+        return node
+
+    def visit_FuncDefNode(self, node):
+        # FuncDefNode always come with an implementation (without
+        # an imp they are CVarDefNodes..)
+        ok = False
+
+        if (isinstance(node, DefNode) and self.scope_type == 'cclass'
+            and node.name in ('__getbuffer__', '__releasebuffer__')):
+            ok = True
+
+
+        if not ok:
+            self.context.nonfatal_error(PostParseError(node.pos,
+                self.ERR_FUNCDEF_NOT_ALLOWED))
+            return None
+        else:
+            return node
+
+
 class WithTransform(CythonTransform):
 
     # EXCINFO is manually set to a variable that contains
index e2b370944474ccbfc5929f472aff314eb08c2652..1b3febee332a9e07552232b5e455d2f5332ed7d3 100644 (file)
@@ -1398,7 +1398,7 @@ def p_statement(s, ctx, first_statement = 0):
             if ctx.api:
                 error(s.pos, "'api' not allowed with this statement")
             elif s.sy == 'def':
-                if ctx.level not in ('module', 'class', 'c_class', 'property'):
+                if ctx.level not in ('module', 'class', 'c_class', 'c_class_pxd', 'property'):
                     s.error('def statement not allowed here')
                 s.level = ctx.level
                 return p_def_statement(s)
index 6c95c3dd886540791405c8faea4cdc647b806bf3..5a9db9cda1fd9412b0a1c209be213ef3e8fa80ec 100644 (file)
@@ -2,16 +2,25 @@ cdef extern from "Python.h":
     ctypedef int Py_intptr_t
     
 cdef extern from "numpy/arrayobject.h":
+    ctypedef void PyArrayObject
+    
     ctypedef class numpy.ndarray [object PyArrayObject]:
-        cdef char *data
-        cdef int nd
-        cdef Py_intptr_t *dimensions
-        cdef Py_intptr_t *strides
-        cdef object base
-        # descr not implemented yet here...
-        cdef int flags
-        cdef int itemsize
-        cdef object weakreflist
+        cdef:
+            char *data
+            int nd
+            Py_intptr_t *dimensions
+            Py_intptr_t *strides
+            object base
+            # descr not implemented yet here...
+            int flags
+            int itemsize
+            object weakreflist
+
+        def __getbuffer__(self, Py_buffer* info, int flags):
+       
+            pass
+
+
 
     ctypedef unsigned int npy_uint8
     ctypedef unsigned int npy_uint16
@@ -27,4 +36,5 @@ cdef extern from "numpy/arrayobject.h":
     ctypedef float        npy_float96
     ctypedef float        npy_float128
 
+
 ctypedef npy_int64 Tint64
diff --git a/tests/errors/e_pxdimpl.pyx b/tests/errors/e_pxdimpl.pyx
new file mode 100644 (file)
index 0000000..b8452b6
--- /dev/null
@@ -0,0 +1,7 @@
+cimport e_pxdimpl_imported
+
+_ERRORS = """
+6:4: function definition not allowed here
+18:4: function definition not allowed here
+23:8: function definition not allowed here
+"""
diff --git a/tests/errors/e_pxdimpl_imported.pxd b/tests/errors/e_pxdimpl_imported.pxd
new file mode 100644 (file)
index 0000000..e2700cf
--- /dev/null
@@ -0,0 +1,25 @@
+
+cdef class A:
+    cdef int test(self)
+
+    # Should give error:
+    def somefunc(self):
+        pass
+
+    # While this should *not* be an error...:
+    def __getbuffer__(self, Py_buffer* info, int flags):
+        pass
+    # This neither:
+    def __releasebuffer__(self, Py_buffer* info):
+        pass
+
+    # Terminate with an error to be sure the compiler is
+    # not terminating prior to previous errors
+    def terminate(self):
+        pass
+
+cdef extern from "foo.h":
+    cdef class pxdimpl.B [object MyB]:
+        def otherfunc(self):
+            pass
+
index 8cd7217d67e056309e2a03c99542b5a4efdbdb43..fab8b85599cc378569b69be8680b46cad8ac150d 100644 (file)
@@ -11,7 +11,7 @@
 
 cimport stdlib
 cimport python_buffer
-# Add all test_X function docstrings as unit tests
+cimport stdio
 
 __test__ = {}
 setup_string = """
@@ -571,8 +571,6 @@ available_flags = (
     ('WRITABLE', python_buffer.PyBUF_WRITABLE)
 )
 
-cimport stdio
-
 cdef class MockBuffer:
     cdef object format
     cdef void* buffer