79e968200741299b030f7af040cb53c0dea47580
[h5config.git] / h5config / storage / hdf5.py
1 # Copyright (C) 2011 W. Trevor King <wking@drexel.edu>
2 #
3 # This file is part of h5config.
4 #
5 # h5config is free software; you can redistribute it and/or modify it
6 # under the terms of the GNU General Public License as published by the
7 # Free Software Foundation, either version 3 of the License, or (at your
8 # option) any later version.
9 #
10 # h5config is distributed in the hope that it will be useful, but
11 # WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 # General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with h5config.  If not, see <http://www.gnu.org/licenses/>.
17
18 """HDF5 backend implementation
19 """
20
21 import os.path as _os_path
22 import types as _types
23
24 import h5py as _h5py
25 import numpy as _numpy
26
27 from .. import LOG as _LOG
28 from .. import config as _config
29 from . import FileStorage as _FileStorage
30
31
def pprint_HDF5(*args, **kwargs):
    """Print the text produced by `pformat_HDF5`.

    All positional and keyword arguments are forwarded unchanged to
    `pformat_HDF5`; whatever it returns is printed.
    """
    formatted = pformat_HDF5(*args, **kwargs)
    print(formatted)
34
def pformat_HDF5(filename, group='/'):
    """Return a text dump of `group` inside the HDF5 file `filename`.

    The file is opened read-only and the lines produced by
    `_pformat_hdf5` are joined with newlines.

    Returns the formatted string on success.  If opening fails with an
    "unable to open" `IOError`, returns ``'EMPTY'`` when the file is
    zero bytes long and ``None`` otherwise.  Any other `IOError` is
    re-raised.  The size check uses `os.path.getsize`, which raises on
    its own if `filename` does not exist.
    """
    try:
        with _h5py.File(filename, 'r') as f:
            ret = '\n'.join(_pformat_hdf5(f[group]))
    except IOError as e:
        # `str(e)` instead of the Python-2-only `e.message` attribute,
        # so the handler itself cannot fail with AttributeError.
        # NOTE(review): the comparison is case-insensitive because the
        # capitalisation of this message appears to differ between
        # h5py/HDF5 versions -- confirm against the versions in use.
        if 'unable to open' in str(e).lower():
            if _os_path.getsize(filename) == 0:
                return 'EMPTY'
            return None
        raise
    return ret
47
48 def _pformat_hdf5(cwg, depth=0):
49     lines = []
50     lines.append('  '*depth + cwg.name)
51     depth += 1
52     for key,value in cwg.iteritems():
53         if isinstance(value, _h5py.Group):
54             lines.extend(_pformat_hdf5(value, depth))
55         elif isinstance(value, _h5py.Dataset):
56             lines.append('  '*depth + str(value))
57             lines.append('  '*(depth+1) + str(value[...]))
58         else:
59             lines.append('  '*depth + str(value))
60     return lines
61
def h5_create_group(cwg, path, force=False):
    """Return the group at `path` below `cwg`, creating it if necessary.

    `path` is split on ``'/'`` and walked one component at a time,
    creating every missing intermediate group.  Empty components are
    ignored, so ``'/'``, ``''`` and ``'//'`` all return `cwg` itself and
    ``'a//b'`` is treated like ``'a/b'``.

    If a component already exists as a dataset, it is deleted and
    replaced by a group when `force` is true; otherwise `ValueError` is
    raised with the offending dataset as its argument.
    """
    # Drop empty components so a stray or doubled slash never turns into
    # an attempt to create a group with an empty name.
    names = [name for name in path.split('/') if name]
    gpath = ['']  # components walked so far, for log messages
    for group in names:
        gpath.append(group)
        if group not in cwg:
            _LOG.debug('creating group {} in {}'.format(
                    '/'.join(gpath), cwg.file))
            cwg.create_group(group)
        _cwg = cwg[group]
        if isinstance(_cwg, _h5py.Dataset):
            if not force:
                raise ValueError(_cwg)
            _LOG.info('overwrite {} in {} ({}) with a group'.format(
                    '/'.join(gpath), _cwg.file, _cwg))
            del cwg[group]
            _cwg = cwg.create_group(group)
        cwg = _cwg
    return cwg
84
85
class HDF5_Storage (_FileStorage):
    """Back a `Config` class with an HDF5 file.

    The `.save` and `.load` methods have an optional `group` argument
    that allows you to save and load settings from an externally
    opened HDF5 file.  This can make it easier to stash several
    related `Config` classes in a single file.  For example

    >>> import os
    >>> import tempfile
    >>> from ..test import TestConfig
    >>> fd,filename = tempfile.mkstemp(
    ...     suffix='.'+HDF5_Storage.extension, prefix='pypiezo-')
    >>> os.close(fd)

    >>> f = _h5py.File(filename, 'a')
    >>> c = TestConfig(storage=HDF5_Storage(
    ...     filename='untouched_file.h5', group='/untouched/group'))
    >>> c['alive'] = True
    >>> group = f.create_group('base')
    >>> c.save(group=group)
    >>> pprint_HDF5(filename)  # doctest: +REPORT_UDIFF, +ELLIPSIS
    /
      /base
        <HDF5 dataset "age": shape (), type "<f8">
          1.3
        <HDF5 dataset "alive": shape (), type "|b1">
          True
        <HDF5 dataset "bids": shape (3,), type "<f8">
          [ 5.4  3.2  1. ]
        <HDF5 dataset "children": shape (), type "|S1">
    <BLANKLINE>
        <HDF5 dataset "claws": shape (2,), type "<i8">
          [1 2]
        <HDF5 dataset "daisies": shape (), type "<i...">
          13
        <HDF5 dataset "name": shape (), type "|S1">
    <BLANKLINE>
        <HDF5 dataset "species": shape (), type "|S14">
          Norwegian Blue
        <HDF5 dataset "spouse": shape (), type "|S1">
    <BLANKLINE>
        <HDF5 dataset "words": shape (2,), type "|S7">
          ['cracker' 'wants']
    >>> c.clear()
    >>> c['alive']
    False
    >>> c.load(group=group)
    >>> c['alive']
    True

    >>> f.close()
    >>> os.remove(filename)
    """
    extension = 'h5'

    def __init__(self, group='/', **kwargs):
        """Set the default group used by `_load` and `_save`.

        `group` is either an already-open `h5py.Group` or an absolute
        group path (a string starting with ``'/'``).  String paths are
        normalised to end with ``'/'``.  Remaining keyword arguments go
        to the parent class.
        """
        super(HDF5_Storage, self).__init__(**kwargs)
        if isinstance(group, _h5py.Group):
            # An open group needs no file setup of our own.
            self._file_checked = True
        else:
            assert group.startswith('/'), group
            if not group.endswith('/'):
                group += '/'
            self._file_checked = False
        self.group = group

    def _check_file(self):
        """Run `_setup_file` once; later calls are no-ops."""
        if self._file_checked:
            return
        self._setup_file()
        self._file_checked = True

    def _setup_file(self):
        """Make sure the backing file and the default group exist."""
        self._create_basedir(filename=self._filename)
        with _h5py.File(self._filename, 'a') as f:
            h5_create_group(f, self.group)

    def _open_group(self, group, mode):
        """Resolve the group to work on for a load or save.

        Returns ``(file, group)``.  `file` is the `h5py.File` opened
        here, which the caller must close, or `None` when the group was
        supplied by the caller or at construction time.
        """
        if group is not None:
            return (None, group)
        if isinstance(self.group, _h5py.Group):
            return (None, self.group)
        self._check_file()
        f = _h5py.File(self._filename, mode)
        try:
            return (f, f[self.group])
        except Exception:
            # Do not leak the handle if the group lookup fails.
            f.close()
            raise

    def _value_from_dataset(self, s, v):
        """Convert `v`, read from HDF5, to the Python value for setting `s`."""
        if isinstance(v, _numpy.ndarray):
            if isinstance(s, _config.BooleanSetting):
                v = bool(v)  # array(True, dtype=bool) -> True
            elif v.dtype.type == _numpy.string_:
                if isinstance(s, _config.ListSetting):
                    try:
                        v = list(v)
                    except TypeError:
                        # `list()` could not iterate the array; treat
                        # it as an empty list.
                        v = []
                else:
                    v = str(v)  # array('abc', dtype='|S3') -> 'abc'
            elif isinstance(s, _config.IntegerSetting):
                v = int(v)  # array(3, dtype='int32') -> 3
            elif isinstance(s, _config.FloatSetting):
                v = float(v)  # array(1.2, dtype='float64') -> 1.2
            elif isinstance(s, _config.NumericSetting):
                raise NotImplementedError(type(s))
            elif isinstance(s, _config.ListSetting):
                v = list(v)  # convert from numpy array
        if isinstance(v, _types.StringTypes):
            # convert back from None, etc.
            v = s.convert_from_text(v)
        return v

    def _load(self, config, group=None):
        """Load the settings of `config` from HDF5.

        `group` is an optional open `h5py.Group` to read from.  Without
        it, the group given at construction time is used, opening the
        backing file read-only if that group is a path.  Settings with
        no entry in the group are left untouched.  Nested config
        settings are loaded recursively from sub-groups; an entry that
        turns out to be a dataset rather than a group is skipped.
        """
        f, group = self._open_group(group, 'r')
        try:
            for s in config.settings:
                if s.name not in group:
                    continue
                if isinstance(s, _config.ConfigListSetting):
                    try:
                        cwg = h5_create_group(group, s.name)
                    except ValueError:
                        pass  # a dataset sits where a group was expected
                    else:
                        value = []
                        # Members are named by list index; restore
                        # them in numeric order.
                        for i in sorted(int(x) for x in cwg.keys()):
                            instance = s.config_class()
                            try:
                                _cwg = h5_create_group(cwg, str(i))
                            except ValueError:
                                pass
                            else:
                                self._load(config=instance, group=_cwg)
                                value.append(instance)
                        config[s.name] = value
                elif isinstance(s, _config.ConfigSetting):
                    try:
                        cwg = h5_create_group(group, s.name)
                    except ValueError:
                        pass
                    else:
                        if not config[s.name]:
                            config[s.name] = s.config_class()
                        self._load(config=config[s.name], group=cwg)
                else:
                    try:
                        v = group[s.name][...]
                    except Exception as e:
                        _LOG.error('Could not access {}/{}: {}'.format(
                                group.name, s.name, e))
                        raise
                    config[s.name] = self._value_from_dataset(s, v)
        finally:
            # Explicit None check: do not rely on the truthiness of a
            # container-like file object to decide whether to close it.
            if f is not None:
                f.close()

    def _save(self, config, group=None):
        """Save the settings of `config` to HDF5.

        `group` is an optional open `h5py.Group` to write to.  Without
        it, the group given at construction time is used, opening the
        backing file in append mode if that group is a path.  Non-empty
        nested configs are saved recursively into sub-groups, replacing
        any dataset of the same name.  Every other setting is stored as
        a dataset, replacing an existing entry of the same name.
        """
        f, group = self._open_group(group, 'a')
        try:
            for s in config.settings:
                value = None
                if isinstance(s, (_config.BooleanSetting,
                                  _config.NumericSetting,
                                  _config.ListSetting)):
                    value = config[s.name]
                    # Explicit test rather than `value in [None, []]`,
                    # which compares with `==` and raises for
                    # multi-element numpy arrays.
                    if value is None or (
                            isinstance(value, list) and not value):
                        value = s.convert_to_text(value)
                elif isinstance(s, _config.ConfigListSetting):
                    configs = config[s.name]
                    if configs:
                        cwg = h5_create_group(group, s.name, force=True)
                        for i, cfg in enumerate(configs):
                            _cwg = h5_create_group(cwg, str(i), force=True)
                            self._save(config=cfg, group=_cwg)
                        continue
                elif isinstance(s, _config.ConfigSetting):
                    cfg = config[s.name]
                    if cfg:
                        cwg = h5_create_group(group, s.name, force=True)
                        self._save(config=cfg, group=cwg)
                        continue
                if value is None:  # not set yet, or invalid
                    value = s.convert_to_text(config[s.name])
                # Remove any existing entry before assigning.
                try:
                    del group[s.name]
                except KeyError:
                    pass
                group[s.name] = value
        finally:
            if f is not None:
                f.close()
278             if f:
279                 f.close()