implemented 'print >> stream'
authorStefan Behnel <scoder@users.berlios.de>
Tue, 9 Mar 2010 08:29:47 +0000 (09:29 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Tue, 9 Mar 2010 08:29:47 +0000 (09:29 +0100)
Cython/Compiler/Nodes.py
Cython/Compiler/Parsing.py
tests/run/print.pyx

index 89acf9179bcef87f3c877cfa1e98e74c2cb19af5..cbae103906fc7395d138c3f4fe884a56c95e9504 100644 (file)
@@ -3447,11 +3447,15 @@ class PrintStatNode(StatNode):
     #  print statement
     #
     #  arg_tuple         TupleNode
+    #  stream            ExprNode or None (stdout)
     #  append_newline    boolean
 
-    child_attrs = ["arg_tuple"]
+    child_attrs = ["arg_tuple", "stream"]
 
     def analyse_expressions(self, env):
+        if self.stream:
+            self.stream.analyse_expressions(env)
+            self.stream = self.stream.coerce_to_pyobject(env)
         self.arg_tuple.analyse_expressions(env)
         self.arg_tuple = self.arg_tuple.coerce_to_pyobject(env)
         env.use_utility_code(printing_utility_code)
@@ -3462,12 +3466,18 @@ class PrintStatNode(StatNode):
     gil_message = "Python print statement"
 
     def generate_execution_code(self, code):
+        if self.stream:
+            self.stream.generate_evaluation_code(code)
+            stream_result = self.stream.py_result()
+        else:
+            stream_result = '0'
         if len(self.arg_tuple.args) == 1 and self.append_newline:
             arg = self.arg_tuple.args[0]
             arg.generate_evaluation_code(code)
             
             code.putln(
-                "if (__Pyx_PrintOne(%s) < 0) %s" % (
+                "if (__Pyx_PrintOne(%s, %s) < 0) %s" % (
+                    stream_result,
                     arg.py_result(),
                     code.error_goto(self.pos)))
             arg.generate_disposal_code(code)
@@ -3475,14 +3485,21 @@ class PrintStatNode(StatNode):
         else:
             self.arg_tuple.generate_evaluation_code(code)
             code.putln(
-                "if (__Pyx_Print(%s, %d) < 0) %s" % (
+                "if (__Pyx_Print(%s, %s, %d) < 0) %s" % (
+                    stream_result,
                     self.arg_tuple.py_result(),
                     self.append_newline,
                     code.error_goto(self.pos)))
             self.arg_tuple.generate_disposal_code(code)
             self.arg_tuple.free_temps(code)
 
+        if self.stream:
+            self.stream.generate_disposal_code(code)
+            self.stream.free_temps(code)
+
     def annotate(self, code):
+        if self.stream:
+            self.stream.annotate(code)
         self.arg_tuple.annotate(code)
 
 
@@ -5028,7 +5045,7 @@ else:
 
 printing_utility_code = UtilityCode(
 proto = """
-static int __Pyx_Print(PyObject *, int); /*proto*/
+static int __Pyx_Print(PyObject*, PyObject *, int); /*proto*/
 #if PY_MAJOR_VERSION >= 3
 static PyObject* %s = 0;
 static PyObject* %s = 0;
@@ -5044,13 +5061,14 @@ static PyObject *__Pyx_GetStdout(void) {
     return f;
 }
 
-static int __Pyx_Print(PyObject *arg_tuple, int newline) {
-    PyObject *f;
+static int __Pyx_Print(PyObject* f, PyObject *arg_tuple, int newline) {
     PyObject* v;
     int i;
 
-    if (!(f = __Pyx_GetStdout()))
-        return -1;
+    if (!f) {
+        if (!(f = __Pyx_GetStdout()))
+            return -1;
+    }
     for (i=0; i < PyTuple_GET_SIZE(arg_tuple); i++) {
         if (PyFile_SoftSpace(f, 1)) {
             if (PyFile_WriteString(" ", f) < 0)
@@ -5078,7 +5096,7 @@ static int __Pyx_Print(PyObject *arg_tuple, int newline) {
 
 #else /* Python 3 has a print function */
 
-static int __Pyx_Print(PyObject *arg_tuple, int newline) {
+static int __Pyx_Print(PyObject* stream, PyObject *arg_tuple, int newline) {
     PyObject* kwargs = 0;
     PyObject* result = 0;
     PyObject* end_string;
@@ -5087,27 +5105,43 @@ static int __Pyx_Print(PyObject *arg_tuple, int newline) {
         if (!%(PRINT_FUNCTION)s)
             return -1;
     }
+    if (stream) {
+        kwargs = PyDict_New();
+        if (unlikely(!kwargs))
+            return -1;
+        if (unlikely(PyDict_SetItemString(kwargs, "file", stream) < 0))
+            goto bad;
+        }
+    }
     if (!newline) {
-        if (!%(PRINT_KWARGS)s) {
+        if (!kwargs)
+            kwargs = %(PRINT_KWARGS)s;
+        if (!kwargs) {
             %(PRINT_KWARGS)s = PyDict_New();
-            if (!%(PRINT_KWARGS)s)
-                return -1;
-            end_string = PyUnicode_FromStringAndSize(" ", 1);
-            if (!end_string)
+            if unlikely((!%(PRINT_KWARGS)s))
                 return -1;
-            if (PyDict_SetItemString(%(PRINT_KWARGS)s, "end", end_string) < 0) {
-                Py_DECREF(end_string);
-                return -1;
-            }
+            kwargs = %(PRINT_KWARGS)s;
+        }
+        end_string = PyUnicode_FromStringAndSize(" ", 1);
+        if (unlikely(!end_string))
+            goto bad;
+        if (PyDict_SetItemString(%(PRINT_KWARGS)s, "end", end_string) < 0) {
             Py_DECREF(end_string);
+            goto bad;
         }
-        kwargs = %(PRINT_KWARGS)s;
+        Py_DECREF(end_string);
     }
     result = PyObject_Call(%(PRINT_FUNCTION)s, arg_tuple, kwargs);
+    if (unlikely(kwargs) && (kwargs != %(PRINT_FUNCTION)s))
+        Py_DECREF(kwargs);
     if (!result)
         return -1;
     Py_DECREF(result);
     return 0;
+bad:
+    if (kwargs != %(PRINT_FUNCTION)s)
+        Py_XDECREF(kwargs);
+    return -1;
 }
 
 #endif
@@ -5119,15 +5153,16 @@ static int __Pyx_Print(PyObject *arg_tuple, int newline) {
 
 printing_one_utility_code = UtilityCode(
 proto = """
-static int __Pyx_PrintOne(PyObject *o); /*proto*/
+static int __Pyx_PrintOne(PyObject* stream, PyObject *o); /*proto*/
 """,
 impl = r"""
 #if PY_MAJOR_VERSION < 3
 
-static int __Pyx_PrintOne(PyObject *o) {
-    PyObject *f;
-    if (!(f = __Pyx_GetStdout()))
-        return -1;
+static int __Pyx_PrintOne(PyObject* f, PyObject *o) {
+    if (!f) {
+        if (!(f = __Pyx_GetStdout()))
+            return -1;
+    }
     if (PyFile_SoftSpace(f, 0)) {
         if (PyFile_WriteString(" ", f) < 0)
             return -1;
@@ -5139,19 +5174,19 @@ static int __Pyx_PrintOne(PyObject *o) {
     return 0;
     /* the line below is just to avoid compiler
      * compiler warnings about unused functions */
-    return __Pyx_Print(NULL, 0);
+    return __Pyx_Print(f, NULL, 0);
 }
 
 #else /* Python 3 has a print function */
 
-static int __Pyx_PrintOne(PyObject *o) {
+static int __Pyx_PrintOne(PyObject* stream, PyObject *o) {
     int res;
     PyObject* arg_tuple = PyTuple_New(1);
     if (unlikely(!arg_tuple))
         return -1;
     Py_INCREF(o);
     PyTuple_SET_ITEM(arg_tuple, 0, o);
-    res = __Pyx_Print(arg_tuple, 1);
+    res = __Pyx_Print(stream, arg_tuple, 1);
     Py_DECREF(arg_tuple);
     return res;
 }
index c97d700c8eb4ccf459910729dc4789b3c86063c3..ef26ca0b6a96774cbbf0d365825747c69e89eb58 100644 (file)
@@ -929,11 +929,17 @@ def p_expression_or_assignment(s):
 def p_print_statement(s):
     # s.sy == 'print'
     pos = s.position()
+    ends_with_comma = 0
     s.next()
     if s.sy == '>>':
-        s.error("'print >>' not yet implemented")
+        s.next()
+        stream = p_simple_expr(s)
+        if s.sy == ',':
+            s.next()
+            ends_with_comma = s.sy in ('NEWLINE', 'EOF')
+    else:
+        stream = None
     args = []
-    ends_with_comma = 0
     if s.sy not in ('NEWLINE', 'EOF'):
         args.append(p_simple_expr(s))
         while s.sy == ',':
@@ -944,7 +950,8 @@ def p_print_statement(s):
             args.append(p_simple_expr(s))
     arg_tuple = ExprNodes.TupleNode(pos, args = args)
     return Nodes.PrintStatNode(pos,
-        arg_tuple = arg_tuple, append_newline = not ends_with_comma)
+        arg_tuple = arg_tuple, stream = stream,
+        append_newline = not ends_with_comma)
 
 def p_exec_statement(s):
     # s.sy == 'exec'
index 82ef95cfd301e57107007e462140626fb4e1a26f..f1b570a1bae15e6cb1f55bdf6d9fec938761b4b2 100644 (file)
@@ -14,3 +14,29 @@ def f(a, b):
     print a, b
     print a, b,
     print 42, u"spam"
+
+
+try:
+    from StringIO import StringIO
+except ImportError:
+    from io import StringIO
+
+def s(stream, a, b):
+    """
+    >>> stream = StringIO()
+    >>> s(stream, 1, 'test')
+    >>> print(stream.getvalue())
+    <BLANKLINE>
+    1
+    1 test
+    1 test
+    1 test 42 spam
+    <BLANKLINE>
+    """
+    print >> stream
+    print >> stream, a
+    print >> stream, a,
+    print >> stream, b
+    print >> stream, a, b
+    print >> stream, a, b,
+    print >> stream, 42, u"spam"