Split struct and util modules out of binarywave.
[igor.git] / igor / struct.py
1 # Copyright
2
3 "Structure and Field classes for declaring structures "
4
5 from __future__ import absolute_import
6 import struct as _struct
7
8 import numpy as _numpy
9
10
11 _buffer = buffer  # save builtin buffer for clobbered situations
12
13
14 class Field (object):
15     """Represent a Structure field.
16
17     See Also
18     --------
19     Structure
20     """
21     def __init__(self, format, name, default=None, help=None, count=1):
22         self.format = format # See the struct documentation
23         self.name = name
24         self.default = None
25         self.help = help
26         self.count = count
27         self.total_count = _numpy.prod(count)
28
29
30 class Structure (_struct.Struct):
31     """Represent a C structure.
32
33     A convenient wrapper around struct.Struct that uses Fields and
34     adds dict-handling methods for transparent name assignment.
35
36     See Also
37     --------
38     Field
39
40     Examples
41     --------
42
43     Represent the C structure::
44
45         struct thing {
46           short version;
47           long size[3];
48         }
49
50     As
51
52     >>> import array
53     >>> from pprint import pprint
54     >>> thing = Structure(name='thing',
55     ...     fields=[Field('h', 'version'), Field('l', 'size', count=3)])
56     >>> thing.set_byte_order('>')
57     >>> b = array.array('b', range(2+4*3))
58     >>> d = thing.unpack_dict_from(buffer=b)
59     >>> pprint(d)
60     {'size': array([ 33752069, 101124105, 168496141]), 'version': 1}
61     >>> [hex(x) for x in d['size']]
62     ['0x2030405L', '0x6070809L', '0xa0b0c0dL']
63
64     You can even get fancy with multi-dimensional arrays.
65
66     >>> thing = Structure(name='thing',
67     ...     fields=[Field('h', 'version'), Field('l', 'size', count=(3,2))])
68     >>> thing.set_byte_order('>')
69     >>> b = array.array('b', range(2+4*3*2))
70     >>> d = thing.unpack_dict_from(buffer=b)
71     >>> d['size'].shape
72     (3, 2)
73     >>> pprint(d)
74     {'size': array([[ 33752069, 101124105],
75            [168496141, 235868177],
76            [303240213, 370612249]]),
77      'version': 1}
78     """
79     def __init__(self, name, fields, byte_order='='):
80         # '=' for native byte order, standard size and alignment
81         # See http://docs.python.org/library/struct for details
82         self.name = name
83         self.fields = fields
84         self.set_byte_order(byte_order)
85
86     def __str__(self):
87         return self.name
88
89     def set_byte_order(self, byte_order):
90         """Allow changing the format byte_order on the fly.
91         """
92         if (hasattr(self, 'format') and self.format != None
93             and self.format.startswith(byte_order)):
94             return  # no need to change anything
95         format = []
96         for field in self.fields:
97             format.extend([field.format]*field.total_count)
98         super(Structure, self).__init__(
99             format=byte_order+''.join(format).replace('P', 'L'))
100
101     def _flatten_args(self, args):
102         # handle Field.count > 0
103         flat_args = []
104         for a,f in zip(args, self.fields):
105             if f.total_count > 1:
106                 flat_args.extend(a)
107             else:
108                 flat_args.append(a)
109         return flat_args
110
111     def _unflatten_args(self, args):
112         # handle Field.count > 0
113         unflat_args = []
114         i = 0
115         for f in self.fields:
116             if f.total_count > 1:
117                 data = _numpy.array(args[i:i+f.total_count])
118                 data = data.reshape(f.count)
119                 unflat_args.append(data)
120             else:
121                 unflat_args.append(args[i])
122             i += f.total_count
123         return unflat_args
124         
125     def pack(self, *args):
126         return super(Structure, self)(*self._flatten_args(args))
127
128     def pack_into(self, buffer, offset, *args):
129         return super(Structure, self).pack_into(
130             buffer, offset, *self._flatten_args(args))
131
132     def _clean_dict(self, dict):
133         for f in self.fields:
134             if f.name not in dict:
135                 if f.default != None:
136                     dict[f.name] = f.default
137                 else:
138                     raise ValueError('{} field not set for {}'.format(
139                             f.name, self.__class__.__name__))
140         return dict
141
142     def pack_dict(self, dict):
143         dict = self._clean_dict(dict)
144         return self.pack(*[dict[f.name] for f in self.fields])
145
146     def pack_dict_into(self, buffer, offset, dict={}):
147         dict = self._clean_dict(dict)
148         return self.pack_into(buffer, offset,
149                               *[dict[f.name] for f in self.fields])
150
151     def unpack(self, string):
152         return self._unflatten_args(
153             super(Structure, self).unpack(string))
154
155     def unpack_from(self, buffer, offset=0):
156         try:
157             args = super(Structure, self).unpack_from(buffer, offset)
158         except _struct.error as e:
159             if not self.name in ('WaveHeader2', 'WaveHeader5'):
160                 raise
161             # HACK!  For WaveHeader5, when npnts is 0, wData is
162             # optional.  If we couldn't unpack the structure, fill in
163             # wData with zeros and try again, asserting that npnts is
164             # zero.
165             if len(buffer) - offset < self.size:
166                 # missing wData?  Pad with zeros
167                 buffer += _buffer('\x00'*(self.size + offset - len(buffer)))
168             args = super(Structure, self).unpack_from(buffer, offset)
169             unpacked = self._unflatten_args(args)
170             data = dict(zip([f.name for f in self.fields],
171                             unpacked))
172             assert data['npnts'] == 0, data['npnts']
173         return self._unflatten_args(args)
174
175     def unpack_dict(self, string):
176         return dict(zip([f.name for f in self.fields],
177                         self.unpack(string)))
178
179     def unpack_dict_from(self, buffer, offset=0):
180         return dict(zip([f.name for f in self.fields],
181                         self.unpack_from(buffer, offset)))