Ran update-copyright.py.
[h5config.git] / h5config / storage / hdf5.py
1 # Copyright (C) 2011-2012 W. Trevor King <wking@tremily.us>
2 #
3 # This file is part of h5config.
4 #
5 # h5config is free software: you can redistribute it and/or modify it under the
6 # terms of the GNU General Public License as published by the Free Software
7 # Foundation, either version 3 of the License, or (at your option) any later
8 # version.
9 #
10 # h5config is distributed in the hope that it will be useful, but WITHOUT ANY
11 # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
12 # A PARTICULAR PURPOSE.  See the GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License along with
15 # h5config.  If not, see <http://www.gnu.org/licenses/>.
16
17 """HDF5 backend implementation
18 """
19
20 import os.path as _os_path
21 import sys as _sys
22
23 import h5py as _h5py
24 import numpy as _numpy
25
26 from .. import LOG as _LOG
27 from .. import config as _config
28 from . import FileStorage as _FileStorage
29 from . import is_string as _is_string
30
31
def pprint_HDF5(*args, **kwargs):
    """Print the text that `pformat_HDF5` returns for the same arguments.

    All positional and keyword arguments are forwarded unchanged.
    """
    text = pformat_HDF5(*args, **kwargs)
    print(text)
34
def pformat_HDF5(filename, group='/'):
    """Return an indented text listing of `group` in the file `filename`.

    The file is opened read-only and the lines produced by
    `_pformat_hdf5` for the requested group are joined with newlines.

    Returns the string 'EMPTY' if the file cannot be opened and is
    zero bytes long, and `None` if it cannot be opened but is not
    empty.  Any other `IOError` is re-raised.
    """
    try:
        with _h5py.File(filename, 'r') as f:
            cwg = f[group]
            ret = '\n'.join(_pformat_hdf5(cwg))
    except IOError as e:
        # Exceptions have no `.message` attribute on Python 3, so look
        # at the string form instead.  The comparison is
        # case-insensitive because the capitalisation of the error
        # text comes from the HDF5 library, not from this module.
        if 'unable to open' in str(e).lower():
            if _os_path.getsize(filename) == 0:
                return 'EMPTY'
            return None
        raise
    return ret
47
def _pformat_hdf5(cwg, depth=0):
    """Return a list of indented lines describing `cwg` and its members.

    The first line is the name of `cwg`, indented by two spaces per
    `depth` level.  Member groups are formatted recursively one level
    deeper.  Datasets contribute their `str()` form followed, one
    level deeper again, by their contents.  Any other member
    contributes only its `str()` form.
    """
    indent = '  '
    child_depth = depth + 1
    child_prefix = indent * child_depth
    lines = [indent * depth + cwg.name]
    for _name, member in cwg.items():
        if isinstance(member, _h5py.Group):
            lines += _pformat_hdf5(member, child_depth)
            continue
        lines.append(child_prefix + str(member))
        if isinstance(member, _h5py.Dataset):
            # Show the dataset contents beneath its summary line.
            lines.append(indent * (child_depth + 1) + str(member[...]))
    return lines
61
def h5_create_group(cwg, path, force=False):
    """Create the group where the settings are stored (if necessary).

    Walk `path` one component at a time starting from `cwg`, creating
    each missing group along the way, and return the final group.

    Empty components (from a leading, trailing, or doubled '/', or
    from an empty `path`) are skipped, so '/', '' and '//' all resolve
    to `cwg` itself.

    If a component names an existing dataset, the dataset is deleted
    and replaced by a new group when `force` is true; otherwise a
    `ValueError` carrying the dataset is raised.
    """
    gpath = ['']  # leading '' makes '/'.join(gpath) start with '/'
    for group in path.split('/'):
        if not group:
            # Nothing to create for an empty path component.
            continue
        gpath.append(group)
        if group not in cwg:
            _LOG.debug('creating group {} in {}'.format(
                    '/'.join(gpath), cwg.file))
            cwg.create_group(group)
        _cwg = cwg[group]
        if isinstance(_cwg, _h5py.Dataset):
            if force:
                _LOG.info('overwrite {} in {} ({}) with a group'.format(
                        '/'.join(gpath), _cwg.file, _cwg))
                del cwg[group]
                _cwg = cwg.create_group(group)
            else:
                raise ValueError(_cwg)
        cwg = _cwg
    return cwg
84
85
class HDF5_Storage (_FileStorage):
    """Back a `Config` class with an HDF5 file.

    The `.save` and `.load` methods have an optional `group` argument
    that allows you to save and load settings from an externally
    opened HDF5 file.  This can make it easier to stash several
    related `Config` classes in a single file.  For example

    >>> import os
    >>> import tempfile
    >>> from ..test import TestConfig
    >>> fd,filename = tempfile.mkstemp(
    ...     suffix='.'+HDF5_Storage.extension, prefix='pypiezo-')
    >>> os.close(fd)

    >>> f = _h5py.File(filename, 'a')
    >>> c = TestConfig(storage=HDF5_Storage(
    ...     filename='untouched_file.h5', group='/untouched/group'))
    >>> c['alive'] = True
    >>> group = f.create_group('base')
    >>> c.save(group=group)
    >>> pprint_HDF5(filename)  # doctest: +REPORT_UDIFF, +ELLIPSIS
    /
      /base
        <HDF5 dataset "age": shape (), type "<f8">
          1.3
        <HDF5 dataset "alive": shape (), type "|b1">
          True
        <HDF5 dataset "bids": shape (3,), type "<f8">
          [ 5.4  3.2  1. ]
        <HDF5 dataset "children": shape (), type "|S1">
    <BLANKLINE>
        <HDF5 dataset "claws": shape (2,), type "<i8">
          [1 2]
        <HDF5 dataset "daisies": shape (), type "<i...">
          13
        <HDF5 dataset "name": shape (), type "|S1">
    <BLANKLINE>
        <HDF5 dataset "species": shape (), type "|S14">
          Norwegian Blue
        <HDF5 dataset "spouse": shape (), type "|S1">
    <BLANKLINE>
        <HDF5 dataset "words": shape (2,), type "|S7">
          ['cracker' 'wants']
    >>> c.clear()
    >>> c['alive']
    False
    >>> c.load(group=group)
    >>> c['alive']
    True

    >>> f.close()
    >>> os.remove(filename)
    """
    # File extension advertised by this backend (used by the doctest
    # above to build a temporary file name).
    extension = 'h5'

    def __init__(self, group='/', **kwargs):
        """Store the target group; pass other arguments to the base class.

        `group` is either an already-open `h5py.Group` or a string
        path.  A string must start with '/' and is normalised to end
        with '/'.  When an open group is given, the backing file is
        marked as already checked, so `_check_file` does nothing.
        """
        super(HDF5_Storage, self).__init__(**kwargs)
        if isinstance(group, _h5py.Group):
            self._file_checked = True
        else:
            assert group.startswith('/'), group
            if not group.endswith('/'):
                group += '/'
            self._file_checked = False
        self.group = group

    def _check_file(self):
        """Run `_setup_file` the first time this is called, then no-op."""
        if self._file_checked:
            return
        self._setup_file()
        self._file_checked = True

    def _setup_file(self):
        """Make sure `self.group` exists in the file `self._filename`.

        `_create_basedir` is inherited; presumably it creates the
        directory that will hold the file (confirm against
        `_FileStorage`).  The file is then opened in append mode and
        every group along `self.group` is created if missing.
        """
        self._create_basedir(filename=self._filename)
        with _h5py.File(self._filename, 'a') as f:
            cwg = f  # current working group
            h5_create_group(cwg, self.group)

    def _load(self, config, group=None):
        """Populate `config` from an HDF5 group.

        If `group` is `None`, `self.group` is used: directly when it
        is an open `h5py.Group`, otherwise by opening
        `self._filename` read-only (the file is closed again before
        returning).  Settings whose name is missing from the group
        are left untouched.

        * `ConfigListSetting`: each integer-named subgroup is loaded
          into a fresh `s.config_class()` instance, in numeric order.
        * `ConfigSetting`: the subgroup is loaded recursively into the
          existing value, or into a new `s.config_class()` if the
          current value is false-y.
        * Anything else: the dataset is read, converted from its
          numpy representation according to the setting type, and
          strings are passed through `s.convert_from_text`.

        A name that holds a dataset where a group is expected is
        skipped silently (the `ValueError` from `h5_create_group` is
        swallowed).  Failures reading a dataset are logged and
        re-raised.
        """
        f = None
        try:
            if group is None:
                if isinstance(self.group, _h5py.Group):
                    group = self.group
                else:
                    self._check_file()
                    f = _h5py.File(self._filename, 'r')
                    group = f[self.group]
            for s in config.settings:
                if s.name not in group:
                    continue
                if isinstance(s, _config.ConfigListSetting):
                    try:
                        # The name exists, so nothing is created here;
                        # this only rejects datasets via ValueError.
                        cwg = h5_create_group(group, s.name)
                    except ValueError:
                        pass
                    else:
                        value = []
                        for i in sorted(int(x) for x in cwg.keys()):
                            instance = s.config_class()
                            try:
                                _cwg = h5_create_group(cwg, str(i))
                            except ValueError:
                                pass
                            else:
                                self._load(config=instance, group=_cwg)
                                value.append(instance)
                        config[s.name] = value
                elif isinstance(s, _config.ConfigSetting):
                    try:
                        cwg = h5_create_group(group, s.name)
                    except ValueError:
                        pass
                    else:
                        if not config[s.name]:
                            config[s.name] = s.config_class()
                        self._load(config=config[s.name], group=cwg)
                else:
                    try:
                        v = group[s.name][...]
                    except Exception as e:
                        _LOG.error('Could not access {}/{}: {}'.format(
                                group.name, s.name, e))
                        raise
                    if isinstance(v, _numpy.ndarray):
                        if isinstance(s, _config.BooleanSetting):
                            v = bool(v)  # array(True, dtype=bool) -> True
                        # `numpy.bytes_` is the byte-string scalar type;
                        # the old `numpy.string_` alias was removed in
                        # NumPy 2.0.
                        elif v.dtype.type == _numpy.bytes_:
                            if isinstance(s, _config.ListSetting):
                                try:
                                    v = list(v)
                                except TypeError:
                                    # list() raised, so `v` cannot be
                                    # iterated; use an empty list.
                                    v = []
                                if _sys.version_info >= (3,):
                                    for i,v_ in enumerate(v):
                                        if isinstance(v_, bytes):
                                            v[i] = str(v_, 'utf-8')
                            else:  # array('abc', dtype='|S3') -> 'abc'
                                if _sys.version_info >= (3,):
                                    v = str(v, 'utf-8')
                                else:
                                    v = str(v)
                        elif isinstance(s, _config.IntegerSetting):
                            v = int(v)  # array(3, dtype='int32') -> 3
                        elif isinstance(s, _config.FloatSetting):
                            v = float(v)  # array(1.2, dtype='float64') -> 1.2
                        elif isinstance(s, _config.NumericSetting):
                            raise NotImplementedError(type(s))
                        elif isinstance(s, _config.ListSetting):
                            v = list(v)  # convert from numpy array
                    if _is_string(v):
                        # convert back from None, etc.
                        v = s.convert_from_text(v)
                    config[s.name] = v
        finally:
            # Only close a file this method opened itself.
            if f is not None:
                f.close()

    def _save(self, config, group=None):
        """Write every setting of `config` into an HDF5 group.

        If `group` is `None`, `self.group` is used: directly when it
        is an open `h5py.Group`, otherwise by opening
        `self._filename` in append mode (the file is closed again
        before returning).

        * Boolean, numeric and list settings are stored as their
          native value, except that `None` and `[]` are stored as the
          text from `s.convert_to_text`.
        * A non-empty `ConfigListSetting` is stored as a subgroup
          holding one integer-named subgroup per entry.
        * A truthy `ConfigSetting` is stored recursively in a
          subgroup.
        * Everything else (including empty nested configs) is stored
          as the text from `s.convert_to_text`.

        Any existing entry with the setting's name is deleted before
        the new value is written.  Raises `ValueError` carrying
        `(value, type(value))` when the assignment raises `TypeError`.
        """
        f = None
        try:
            if group is None:
                if isinstance(self.group, _h5py.Group):
                    group = self.group
                else:
                    self._check_file()
                    f = _h5py.File(self._filename, 'a')
                    group = f[self.group]
            for s in config.settings:
                value = None
                if isinstance(s, (_config.BooleanSetting,
                                  _config.NumericSetting,
                                  _config.ListSetting)):
                    value = config[s.name]
                    if value in [None, []]:
                        value = s.convert_to_text(value)
                elif isinstance(s, _config.ConfigListSetting):
                    configs = config[s.name]
                    if configs:
                        cwg = h5_create_group(group, s.name, force=True)
                        for i,cfg in enumerate(configs):
                            _cwg = h5_create_group(cwg, str(i), force=True)
                            self._save(config=cfg, group=_cwg)
                        continue
                elif isinstance(s, _config.ConfigSetting):
                    cfg = config[s.name]
                    if cfg:
                        cwg = h5_create_group(group, s.name, force=True)
                        self._save(config=cfg, group=cwg)
                        continue
                if value is None:  # not set yet, or invalid
                    value = s.convert_to_text(config[s.name])
                if _sys.version_info >= (3,):  # convert strings to bytes
                    if isinstance(value, str):
                        value = value.encode('utf-8')
                    elif isinstance(value, list):
                        value = list(value)  # shallow copy
                        for i,v in enumerate(value):
                            if isinstance(v, str):
                                value[i] = v.encode('utf-8')
                # Drop any previous entry so the assignment below
                # always creates a fresh one.
                try:
                    del group[s.name]
                except KeyError:
                    pass
                try:
                    group[s.name] = value
                except TypeError:
                    raise ValueError((value, type(value)))
        finally:
            # Only close a file this method opened itself.
            if f is not None:
                f.close()
296             if f:
297                 f.close()