Adjust Hooke internals to use unsafe YAML (not just builtin Python types).
[hooke.git] / hooke / util / yaml.py
index f86e3b181d8dead69fbd8c355225820100d156b1..e7f99806a49c6382716f49764d8ac72e3d1092f1 100644 (file)
@@ -17,25 +17,47 @@ null
 The default behavior is to crash.
 
 >>> yaml.Dumper.yaml_representers.pop(numpy.ndarray)  # doctest: +ELLIPSIS
-<function ndarray_representer at 0x...>
+<function none_representer at 0x...>
 >>> print yaml.dump(a)
-Traceback (most recent call last):
-  ...
-    if data in [None, ()]:
-TypeError: data type not understood
+!!python/object/apply:numpy.core.multiarray._reconstruct
+args:
+- !!python/name:numpy.ndarray ''
+- !!python/tuple [0]
+- b
+state: !!python/tuple
+- 1
+- !!python/tuple [3]
+- null
+- false
+- "\\x01\\0\\0\\0\\x02\\0\\0\\0\\x03\\0\\0\\0"
+<BLANKLINE>
+
+Hmm, at one point that crashed like this::
+
+    Traceback (most recent call last):
+      ...
+        if data in [None, ()]:
+    TypeError: data type not understood
+
+Must be because of the other representers I've loaded since.
 
 Restore the representer for future tests.
 
->>> yaml.add_representer(numpy.ndarray, ndarray_representer)
+>>> yaml.add_representer(numpy.ndarray, none_representer)
 """
 
 from __future__ import absolute_import
+import copy_reg
 import sys
+import types
 
 import numpy
-import yaml #from yaml.representer import Representer
+import yaml
+import yaml.constructor
+import yaml.representer
 
-from ..curve import Data
+from ..curve import Data, Curve
+from ..playlist import FilePlaylist
 
 
 if False: # YAML dump debugging code
@@ -95,10 +117,121 @@ yaml.add_representer(numpy.float64, float_representer)
 
 def data_representer(dumper, data):
     info = dict(data.info)
-    print 'KEYS', info.keys()
     for key in info.keys():
-        if key.startswith('raw '):# or 'peak' in key: #or key not in ['surface deflection offset (m)', 'z piezo sensitivity (m/V)', 'z piezo scan (V/bit)', 'z piezo gain', 'deflection range (V)', 'z piezo range (V)', 'spring constant (N/m)', 'z piezo scan size (V)', 'deflection sensitivity (V/bit)', 'z piezo ramp size (V/bit)', 'surface deflection offset', 'z piezo offset (V)', 'name']:
+        if key.startswith('raw '):
             del(info[key])
-    print 'AAAS', info.keys()
     return dumper.represent_mapping(u'!hooke.curve.DataInfo', info)
 yaml.add_representer(Data, data_representer)
+
+def object_representer(dumper, data):
+    cls = type(data)
+    if cls in copy_reg.dispatch_table:
+        reduce = copy_reg.dispatch_table[cls](data)
+    elif hasattr(data, '__reduce_ex__'):
+        reduce = data.__reduce_ex__(2)
+    elif hasattr(data, '__reduce__'):
+        reduce = data.__reduce__()
+    else:
+        raise RepresenterError("cannot represent object: %r" % data)
+    reduce = (list(reduce)+[None]*5)[:5]
+    function, args, state, listitems, dictitems = reduce
+    args = list(args)
+    if state is None:
+        state = {}
+    if isinstance(state, dict) and '_default_attrs' in state:
+        for key in state['_default_attrs']:
+            if key in state and state[key] == state['_default_attrs'][key]:
+                del(state[key])
+        del(state['_default_attrs'])
+    if listitems is not None:
+        listitems = list(listitems)
+    if dictitems is not None:
+        dictitems = dict(dictitems)
+    if function.__name__ == '__newobj__':
+        function = args[0]
+        args = args[1:]
+        tag = u'tag:yaml.org,2002:python/object/new:'
+        newobj = True
+    else:
+        tag = u'tag:yaml.org,2002:python/object/apply:'
+        newobj = False
+    function_name = u'%s.%s' % (function.__module__, function.__name__)
+    if not args and not listitems and not dictitems \
+            and isinstance(state, dict) and newobj:
+        return dumper.represent_mapping(
+                u'tag:yaml.org,2002:python/object:'+function_name, state)
+    if not listitems and not dictitems  \
+            and isinstance(state, dict) and not state:
+        return dumper.represent_sequence(tag+function_name, args)
+    value = {}
+    if args:
+        value['args'] = args
+    if state or not isinstance(state, dict):
+        value['state'] = state
+    if listitems:
+        value['listitems'] = listitems
+    if dictitems:
+        value['dictitems'] = dictitems
+    return dumper.represent_mapping(tag+function_name, value)
+yaml.add_representer(FilePlaylist, object_representer)
+yaml.add_representer(Curve, object_representer)
+
+
+# Monkey patch PyYAML bug 159.
+#   Yaml failed to restore loops in objects when __setstate__ is defined
+#   http://pyyaml.org/ticket/159
+# With viktor.x.voroshylo@jpmchase.com's patch
+def construct_object(self, node, deep=False):
+    if deep:
+        old_deep = self.deep_construct
+        self.deep_construct = True
+    if node in self.constructed_objects:
+        return self.constructed_objects[node]
+    if node in self.recursive_objects:
+        obj = self.recursive_objects[node]
+        if obj is None :
+            raise ConstructorError(None, None,
+                 "found unconstructable recursive node", node.start_mark)
+        return obj
+    self.recursive_objects[node] = None
+    constructor = None
+    tag_suffix = None
+    if node.tag in self.yaml_constructors:
+        constructor = self.yaml_constructors[node.tag]
+    else:
+        for tag_prefix in self.yaml_multi_constructors:
+            if node.tag.startswith(tag_prefix):
+                tag_suffix = node.tag[len(tag_prefix):]
+                constructor = self.yaml_multi_constructors[tag_prefix]
+                break
+        else:
+            if None in self.yaml_multi_constructors:
+                tag_suffix = node.tag
+                constructor = self.yaml_multi_constructors[None]
+            elif None in self.yaml_constructors:
+                constructor = self.yaml_constructors[None]
+            elif isinstance(node, ScalarNode):
+                constructor = self.__class__.construct_scalar
+            elif isinstance(node, SequenceNode):
+                constructor = self.__class__.construct_sequence
+            elif isinstance(node, MappingNode):
+                constructor = self.__class__.construct_mapping
+    if tag_suffix is None:
+        data = constructor(self, node)
+    else:
+        data = constructor(self, tag_suffix, node)
+    if isinstance(data, types.GeneratorType):
+        generator = data
+        data = generator.next()
+        if self.deep_construct:
+            self.recursive_objects[node] = data
+            for dummy in generator:
+                pass
+        else:
+            self.state_generators.append(generator)
+    self.constructed_objects[node] = data
+    del self.recursive_objects[node]
+    if deep:
+        self.deep_construct = old_deep
+    return data
+yaml.constructor.BaseConstructor.construct_object = construct_object