Save dtypes to YAML (now we only drop ndarrays).
[hooke.git] / hooke / util / yaml.py
1 # Copyright
2
3 """Add representers to YAML to support Hooke.
4
Without introspection, YAML cannot decide how to save some
objects.  By refusing to save these objects, we obviously lose
that information, so make sure the things you drop are either
stored somewhere else or not important.
9
10 >>> import yaml
11 >>> a = numpy.array([1,2,3])
12 >>> print yaml.dump(a)
13 null
14 ...
15 <BLANKLINE>
16
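Without our representer, PyYAML falls back to its default
behavior.  That used to be a crash (see below); now it dumps the
array's full internal state.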
18
19 >>> yaml.Dumper.yaml_representers.pop(numpy.ndarray)  # doctest: +ELLIPSIS
20 <function none_representer at 0x...>
21 >>> print yaml.dump(a)
22 !!python/object/apply:numpy.core.multiarray._reconstruct
23 args:
24 - !!python/name:numpy.ndarray ''
25 - !!python/tuple [0]
26 - b
27 state: !!python/tuple
28 - 1
29 - !!python/tuple [3]
30 - null
31 - false
32 - "\\x01\\0\\0\\0\\x02\\0\\0\\0\\x03\\0\\0\\0"
33 <BLANKLINE>
34
35 Hmm, at one point that crashed like this::
36
37     Traceback (most recent call last):
38       ...
39         if data in [None, ()]:
40     TypeError: data type not understood
41
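Presumably it no longer crashes because this module replaces
``SafeRepresenter.ignore_aliases`` (the function in the traceback
above) with a version that tolerates numpy objects.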
43
44 Restore the representer for future tests.
45
46 >>> yaml.add_representer(numpy.ndarray, none_representer)
47 """
48
49 from __future__ import absolute_import
50 import copy_reg
51 import sys
52 import types
53
54 import numpy
55 import yaml
56 import yaml.constructor
57 from yaml.constructor import ConstructorError
58 import yaml.representer
59
60 from ..curve import Data, Curve
61 from ..playlist import FilePlaylist
62
63
# YAML tag used by data_representer / data_constructor for Data objects
# (only their info mapping is stored under it).
DATA_INFO_TAG = u'!hooke.curve.DataInfo'


# Set to True to print every object YAML's alias check sees.  That helps
# isolate data types (usually defined by external C modules such as numpy,
# e.g. numpy.ndarray) which YAML cannot inspect.
_DEBUG_IGNORE_ALIASES = False


def ignore_aliases(data):
    """Return True if `data` should never be emitted as an anchor/alias.

    Replacement for PyYAML's ``SafeRepresenter.ignore_aliases``.  The
    stock version tests ``data in [None, ()]``, which calls ``==`` on
    arbitrary objects.  numpy objects misbehave there:
    ``numpy.dtype(numpy.int32) in [None, ()]`` raised TypeError (see
    http://projects.scipy.org/numpy/ticket/1001), and with newer numpy
    releases comparing a multi-element ndarray with None yields an array
    whose truth value raises ValueError.  Identity and isinstance checks
    avoid calling ``==`` on anything that is not a tuple.

    Any result other than True lets YAML use aliases for `data`.
    """
    if _DEBUG_IGNORE_ALIASES:
        sys.stdout.write('%s %r %s %r\n'
                         % (data, data, type(data), type(data)))
        sys.stdout.flush()
    if data is None:
        return True
    # Comparing two tuples never delegates to numpy's elementwise __eq__
    # when their lengths differ.
    if isinstance(data, tuple) and data == ():
        return True
    if isinstance(data, (str, unicode, bool, int, float)):
        return True
    return False
yaml.representer.SafeRepresenter.ignore_aliases = staticmethod(
    ignore_aliases)
97
98
def none_representer(dumper, data):
    """Represent any `data` as YAML null, discarding its contents.

    Registered below for numpy.ndarray, which YAML cannot introspect, so
    array contents must be stored somewhere else.
    """
    # `data` is deliberately ignored; only a null scalar is emitted.
    null_value = None
    return dumper.represent_none(null_value)
yaml.add_representer(numpy.ndarray, none_representer)
102
def bool_representer(dumper, data):
    """Represent a numpy boolean scalar via the dumper's bool representer."""
    represent = dumper.represent_bool
    return represent(data)
yaml.add_representer(numpy.bool_, bool_representer)
106
def int_representer(dumper, data):
    """Represent a numpy integer scalar via the dumper's int representer.

    Registered for numpy.int32 here and for numpy.int64 further down.
    """
    return dumper.represent_int(data)
yaml.add_representer(numpy.int32, int_representer)
# No representer is keyed on a dtype *instance* such as
# numpy.dtype(numpy.int32): PyYAML looks representers up by type(data), so
# such a key can never match, and represent_int would not produce a
# loadable integer for a dtype anyway.  dtypes are left to PyYAML's
# default handling.
111
def long_representer(dumper, data):
    """Represent `data` via the dumper's ``represent_long`` method.

    NOTE(review): nothing in this module registers this function;
    numpy.int64 below is registered with int_representer instead.
    """
    represent = dumper.represent_long
    return represent(data)
yaml.add_representer(numpy.int64, int_representer)
115
def float_representer(dumper, data):
    """Represent a numpy floating-point scalar via the float representer.

    Registered for numpy.float32 and numpy.float64.
    """
    represent = dumper.represent_float
    return represent(data)
yaml.add_representer(numpy.float32, float_representer)
yaml.add_representer(numpy.float64, float_representer)
120
def data_representer(dumper, data):
    """Represent a Data object as a DATA_INFO_TAG mapping of its info.

    Only a copy of ``data.info`` is dumped; the array contents are not
    saved.  Info entries whose key starts with ``'raw '`` are omitted.
    """
    info = dict(data.info)
    # Collect the keys first: deleting from a dict while iterating over
    # it only works on Python 2, where dict.keys() returns a list.
    raw_keys = [key for key in info if key.startswith('raw ')]
    for key in raw_keys:
        del info[key]
    return dumper.represent_mapping(DATA_INFO_TAG, info)
yaml.add_representer(Data, data_representer)
128
def data_constructor(loader, node):
    """Load a DATA_INFO_TAG mapping back into a placeholder Data object.

    Only the info mapping is stored under this tag, so the result is an
    empty (0, 0) float32 Data carrying the restored info.
    """
    restored_info = loader.construct_mapping(node)
    return Data(info=restored_info, dtype=numpy.float32, shape=(0, 0))
yaml.add_constructor(DATA_INFO_TAG, data_constructor)
133
def _strip_default_attrs(state):
    """Return a copy of `state` without entries equal to their defaults.

    `state` is a dict containing a ``'_default_attrs'`` mapping of
    attribute names to default values.  Entries of `state` equal to their
    default are dropped, as is ``'_default_attrs'`` itself.  The input is
    not modified, because the reduce protocol may hand back the object's
    own ``__dict__`` (not a copy) as its state.
    """
    state = dict(state)
    defaults = state.pop('_default_attrs')
    for key in defaults:
        if key in state and state[key] == defaults[key]:
            del state[key]
    return state


def object_representer(dumper, data):
    """Represent `data` as a python/object YAML node via the reduce protocol.

    Follows the same approach as PyYAML's generic object representer.  In
    addition, if the reduced state is a dict containing a
    ``'_default_attrs'`` mapping, attributes still equal to their defaults
    (and the mapping itself) are left out of the dump.

    Raises yaml.representer.RepresenterError if `data` cannot be reduced.
    """
    cls = type(data)
    if cls in copy_reg.dispatch_table:
        reduced = copy_reg.dispatch_table[cls](data)
    elif hasattr(data, '__reduce_ex__'):
        reduced = data.__reduce_ex__(2)
    elif hasattr(data, '__reduce__'):
        reduced = data.__reduce__()
    else:
        # RepresenterError is not imported into this module, so reach it
        # through yaml.representer (a bare name would raise NameError).
        raise yaml.representer.RepresenterError(
            "cannot represent object: %r" % (data,))
    # Pad the reduce tuple to its full 5-item form.
    reduced = (list(reduced) + [None] * 5)[:5]
    function, args, state, listitems, dictitems = reduced
    args = list(args)
    if state is None:
        state = {}
    if isinstance(state, dict) and '_default_attrs' in state:
        state = _strip_default_attrs(state)
    if listitems is not None:
        listitems = list(listitems)
    if dictitems is not None:
        dictitems = dict(dictitems)
    if function.__name__ == '__newobj__':
        # Protocol-2 reduce: args[0] is the class to instantiate.
        function = args[0]
        args = args[1:]
        tag = u'tag:yaml.org,2002:python/object/new:'
        newobj = True
    else:
        tag = u'tag:yaml.org,2002:python/object/apply:'
        newobj = False
    function_name = u'%s.%s' % (function.__module__, function.__name__)
    if (not args and not listitems and not dictitems
            and isinstance(state, dict) and newobj):
        return dumper.represent_mapping(
            u'tag:yaml.org,2002:python/object:' + function_name, state)
    if (not listitems and not dictitems
            and isinstance(state, dict) and not state):
        return dumper.represent_sequence(tag + function_name, args)
    value = {}
    if args:
        value['args'] = args
    if state or not isinstance(state, dict):
        value['state'] = state
    if listitems:
        value['listitems'] = listitems
    if dictitems:
        value['dictitems'] = dictitems
    return dumper.represent_mapping(tag + function_name, value)
yaml.add_representer(FilePlaylist, object_representer)
yaml.add_representer(Curve, object_representer)
186
187
# Monkey patch PyYAML bug 159.
#   Yaml failed to restore loops in objects when __setstate__ is defined
#   http://pyyaml.org/ticket/159
# With viktor.x.voroshylo@jpmchase.com's patch
def _find_constructor(self, node):
    """Return ``(constructor, tag_suffix)`` for `node` on loader `self`.

    `tag_suffix` is None unless a multi-constructor matched.  In that case
    the constructor must be called as ``constructor(self, tag_suffix,
    node)``.  Raises ConstructorError if no constructor applies.
    """
    if node.tag in self.yaml_constructors:
        return self.yaml_constructors[node.tag], None
    for tag_prefix in self.yaml_multi_constructors:
        # A None key is the catch-all handled below; str.startswith(None)
        # would raise TypeError.
        if tag_prefix is not None and node.tag.startswith(tag_prefix):
            return (self.yaml_multi_constructors[tag_prefix],
                    node.tag[len(tag_prefix):])
    if None in self.yaml_multi_constructors:
        return self.yaml_multi_constructors[None], node.tag
    if None in self.yaml_constructors:
        return self.yaml_constructors[None], None
    if isinstance(node, yaml.nodes.ScalarNode):
        return self.__class__.construct_scalar, None
    if isinstance(node, yaml.nodes.SequenceNode):
        return self.__class__.construct_sequence, None
    if isinstance(node, yaml.nodes.MappingNode):
        return self.__class__.construct_mapping, None
    raise ConstructorError(None, None,
            "could not determine a constructor for the tag %r" % (node.tag,),
            node.start_mark)


def construct_object(self, node, deep=False):
    """Construct (or fetch the already constructed) object for `node`.

    Replacement for ``yaml.constructor.BaseConstructor.construct_object``.
    During deep construction, the object first yielded by a generator
    constructor is recorded in ``self.recursive_objects``.  A recursive
    reference to the node therefore resolves to that partially built
    object instead of raising.  That lets objects with ``__setstate__``
    contain loops.

    `deep` forces eager construction (generators are exhausted
    immediately).  ``self.deep_construct`` is restored on every exit path.
    """
    if node in self.constructed_objects:
        return self.constructed_objects[node]
    if node in self.recursive_objects:
        obj = self.recursive_objects[node]
        if obj is None:
            raise ConstructorError(None, None,
                 "found unconstructable recursive node", node.start_mark)
        return obj
    if deep:
        old_deep = self.deep_construct
        self.deep_construct = True
    try:
        # None marks "under construction, nothing usable yet".
        self.recursive_objects[node] = None
        constructor, tag_suffix = _find_constructor(self, node)
        if tag_suffix is None:
            data = constructor(self, node)
        else:
            data = constructor(self, tag_suffix, node)
        if isinstance(data, types.GeneratorType):
            generator = data
            data = next(generator)
            if self.deep_construct:
                # Expose the partially built object so recursive references
                # reached while finishing the generator can resolve to it.
                self.recursive_objects[node] = data
                for dummy in generator:
                    pass
            else:
                self.state_generators.append(generator)
        self.constructed_objects[node] = data
        del self.recursive_objects[node]
        return data
    finally:
        if deep:
            self.deep_construct = old_deep
yaml.constructor.BaseConstructor.construct_object = construct_object