More Python 3 fixes, mostly about string/byte/unicode handling.
[h5config.git] / h5config / storage / hdf5.py
1 # Copyright (C) 2011 W. Trevor King <wking@drexel.edu>
2 #
3 # This file is part of h5config.
4 #
5 # h5config is free software; you can redistribute it and/or modify it
6 # under the terms of the GNU General Public License as published by the
7 # Free Software Foundation, either version 3 of the License, or (at your
8 # option) any later version.
9 #
10 # h5config is distributed in the hope that it will be useful, but
11 # WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 # General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with h5config.  If not, see <http://www.gnu.org/licenses/>.
17
18 """HDF5 backend implementation
19 """
20
21 import os.path as _os_path
22 import sys as _sys
23
24 import h5py as _h5py
25 import numpy as _numpy
26
27 from .. import LOG as _LOG
28 from .. import config as _config
29 from . import FileStorage as _FileStorage
30 from . import is_string as _is_string
31
32
def pprint_HDF5(*args, **kwargs):
    """Print the listing built by `pformat_HDF5`.

    Every positional and keyword argument is handed straight through to
    `pformat_HDF5`; whatever it returns is printed.
    """
    text = pformat_HDF5(*args, **kwargs)
    print(text)
35
def pformat_HDF5(filename, group='/'):
    """Return a printable tree of the groups and datasets in an HDF5 file.

    Parameters
    ----------
    filename : str
        Path of the HDF5 file to inspect (opened read-only).
    group : str
        Path of the group to start the listing from (default: root).

    Returns
    -------
    str or None
        The newline-joined listing.  ``'EMPTY'`` when h5py reports it is
        unable to open the file and the file exists with size zero;
        ``None`` when h5py is unable to open it for any other reason
        (including a missing file).

    Raises
    ------
    IOError
        Any other I/O error reported while opening or reading the file.
    """
    try:
        with _h5py.File(filename, 'r') as f:
            cwg = f[group]
            ret = '\n'.join(_pformat_hdf5(cwg))
    except IOError as e:
        # Python 3 exceptions have no ``.message`` attribute; ``str(e)``
        # works on both Python 2 and 3.  Lower-case the text because h5py
        # versions differ in capitalisation ("Unable to open file ...").
        if 'unable to open' in str(e).lower():
            # Guard getsize(): a missing file would raise OSError here.
            if _os_path.isfile(filename) and _os_path.getsize(filename) == 0:
                return 'EMPTY'
            return None
        raise
    return ret
48
def _pformat_hdf5(cwg, depth=0):
    """Return the indented text lines describing `cwg` and its members.

    The group's own name is indented by `depth` levels (two spaces each).
    Sub-groups are expanded recursively one level deeper.  Datasets get
    their repr one level deeper, followed by their contents one further
    level down.  Any other member contributes just its ``str()``.
    """
    member_indent = '  ' * (depth + 1)
    contents_indent = '  ' * (depth + 2)
    out = ['  ' * depth + cwg.name]
    for member in cwg.values():
        if isinstance(member, _h5py.Group):
            out += _pformat_hdf5(member, depth + 1)
            continue
        out.append(member_indent + str(member))
        if isinstance(member, _h5py.Dataset):
            out.append(contents_indent + str(member[...]))
    return out
62
def h5_create_group(cwg, path, force=False):
    """Create the group where the settings are stored (if necessary).

    Walk `path` one component at a time starting from `cwg`, creating any
    missing group along the way, and return the innermost group.  A
    leading '/' does not make `path` absolute: the walk always starts at
    `cwg`.  Empty components (from leading, trailing or doubled '/') are
    ignored, so '/', '' and '//' all return `cwg` unchanged.

    If a component already exists as a dataset, it is deleted and replaced
    by a new, empty group when `force` is true.  Otherwise ValueError is
    raised with that dataset as its argument.
    """
    # Skipping empty components avoids calling create_group('') for paths
    # like '' or 'a//b', which previously crashed.
    components = [part for part in path.split('/') if part]
    if not components:
        return cwg
    gpath = ['']  # components walked so far, used only for log messages
    for group in components:
        gpath.append(group)
        if group not in cwg.keys():
            _LOG.debug('creating group {} in {}'.format(
                    '/'.join(gpath), cwg.file))
            cwg.create_group(group)
        _cwg = cwg[group]
        if isinstance(_cwg, _h5py.Dataset):
            if force:
                _LOG.info('overwrite {} in {} ({}) with a group'.format(
                        '/'.join(gpath), _cwg.file, _cwg))
                del cwg[group]
                _cwg = cwg.create_group(group)
            else:
                raise ValueError(_cwg)
        cwg = _cwg
    return cwg
85
86
class HDF5_Storage (_FileStorage):
    """Back a `Config` class with an HDF5 file.

    The `.save` and `.load` methods have an optional `group` argument
    that allows you to save and load settings from an externally
    opened HDF5 file.  This can make it easier to stash several
    related `Config` classes in a single file.  For example

    >>> import os
    >>> import tempfile
    >>> from ..test import TestConfig
    >>> fd,filename = tempfile.mkstemp(
    ...     suffix='.'+HDF5_Storage.extension, prefix='pypiezo-')
    >>> os.close(fd)

    >>> f = _h5py.File(filename, 'a')
    >>> c = TestConfig(storage=HDF5_Storage(
    ...     filename='untouched_file.h5', group='/untouched/group'))
    >>> c['alive'] = True
    >>> group = f.create_group('base')
    >>> c.save(group=group)
    >>> pprint_HDF5(filename)  # doctest: +REPORT_UDIFF, +ELLIPSIS
    /
      /base
        <HDF5 dataset "age": shape (), type "<f8">
          1.3
        <HDF5 dataset "alive": shape (), type "|b1">
          True
        <HDF5 dataset "bids": shape (3,), type "<f8">
          [ 5.4  3.2  1. ]
        <HDF5 dataset "children": shape (), type "|S1">
    <BLANKLINE>
        <HDF5 dataset "claws": shape (2,), type "<i8">
          [1 2]
        <HDF5 dataset "daisies": shape (), type "<i...">
          13
        <HDF5 dataset "name": shape (), type "|S1">
    <BLANKLINE>
        <HDF5 dataset "species": shape (), type "|S14">
          Norwegian Blue
        <HDF5 dataset "spouse": shape (), type "|S1">
    <BLANKLINE>
        <HDF5 dataset "words": shape (2,), type "|S7">
          ['cracker' 'wants']
    >>> c.clear()
    >>> c['alive']
    False
    >>> c.load(group=group)
    >>> c['alive']
    True

    >>> f.close()
    >>> os.remove(filename)
    """
    extension = 'h5'

    def __init__(self, group='/', **kwargs):
        # `group` is either an already-open h5py group (used as-is, so no
        # file setup is needed) or an absolute path inside the backing
        # file, normalised to end with '/'.
        super(HDF5_Storage, self).__init__(**kwargs)
        if isinstance(group, _h5py.Group):
            self._file_checked = True
        else:
            assert group.startswith('/'), group
            if not group.endswith('/'):
                group += '/'
            self._file_checked = False
        self.group = group

    def _check_file(self):
        """Run `_setup_file` at most once for this instance."""
        if self._file_checked:
            return
        self._setup_file()
        self._file_checked = True

    def _setup_file(self):
        """Create the base directory, the file and the `self.group` path."""
        self._create_basedir(filename=self._filename)
        # Mode 'a' opens read/write, creating the file if it is missing.
        with _h5py.File(self._filename, 'a') as f:
            h5_create_group(f, self.group)

    @staticmethod
    def _value_from_dataset(s, v):
        """Convert raw dataset contents `v` into the Python value for `s`."""
        if isinstance(v, _numpy.ndarray):
            if isinstance(s, _config.BooleanSetting):
                v = bool(v)  # array(True, dtype=bool) -> True
            # `bytes_` is the same type as the old `string_` alias, which
            # NumPy 2 removed.
            elif v.dtype.type == _numpy.bytes_:
                if isinstance(s, _config.ListSetting):
                    try:
                        v = list(v)
                    except TypeError:
                        # 0-d array: a scalar text placeholder (e.g. what
                        # `_save` writes for an empty list).
                        v = []
                    if _sys.version_info >= (3,):
                        v = [str(x, 'utf-8') if isinstance(x, bytes) else x
                             for x in v]
                else:  # array(b'abc', dtype='|S3') -> 'abc'
                    # `.item()` strips NumPy's NUL padding.  Decoding the
                    # raw array buffer would read an empty string stored
                    # as '|S1' back as '\x00'.
                    v = v.item()
                    if _sys.version_info >= (3,) and isinstance(v, bytes):
                        v = v.decode('utf-8')
            elif isinstance(s, _config.IntegerSetting):
                v = int(v)  # array(3, dtype='int32') -> 3
            elif isinstance(s, _config.FloatSetting):
                v = float(v)  # array(1.2, dtype='float64') -> 1.2
            elif isinstance(s, _config.NumericSetting):
                raise NotImplementedError(type(s))
            elif isinstance(s, _config.ListSetting):
                v = list(v)  # convert from numpy array
        if _is_string(v):
            # convert back from None, etc.
            v = s.convert_from_text(v)
        return v

    @staticmethod
    def _text_to_bytes(value):
        """On Python 3, UTF-8-encode a `str` or the `str` items of a list.

        A list is copied rather than modified in place, so the caller's
        config value is left untouched.
        """
        if _sys.version_info >= (3,):
            if isinstance(value, str):
                return value.encode('utf-8')
            if isinstance(value, list):
                return [v.encode('utf-8') if isinstance(v, str) else v
                        for v in value]
        return value

    def _load(self, config, group=None):
        """Read stored settings from `group` into `config`.

        If `group` is None, use `self.group`.  That is either the h5py group
        given to the constructor, or a path inside `self._filename`; in the
        latter case the file is opened read-only for this call.  Settings
        absent from the group are skipped.  Nested config settings are
        loaded recursively from sub-groups.
        """
        f = None
        try:
            if group is None:
                if isinstance(self.group, _h5py.Group):
                    group = self.group
                else:
                    self._check_file()
                    f = _h5py.File(self._filename, 'r')
                    group = f[self.group]
            for s in config.settings:
                if s.name not in group.keys():
                    continue
                if isinstance(s, _config.ConfigListSetting):
                    try:
                        cwg = h5_create_group(group, s.name)
                    except ValueError:  # stored as a dataset, not a group
                        continue
                    value = []
                    # Sub-groups are named '0', '1', ...; sort numerically.
                    for i in sorted(int(x) for x in cwg.keys()):
                        instance = s.config_class()
                        try:
                            _cwg = h5_create_group(cwg, str(i))
                        except ValueError:
                            continue
                        self._load(config=instance, group=_cwg)
                        value.append(instance)
                    config[s.name] = value
                elif isinstance(s, _config.ConfigSetting):
                    try:
                        cwg = h5_create_group(group, s.name)
                    except ValueError:  # stored as a dataset, not a group
                        continue
                    if not config[s.name]:
                        config[s.name] = s.config_class()
                    self._load(config=config[s.name], group=cwg)
                else:
                    try:
                        v = group[s.name][...]
                    except Exception as e:
                        _LOG.error('Could not access {}/{}: {}'.format(
                                group.name, s.name, e))
                        raise
                    config[s.name] = self._value_from_dataset(s, v)
        finally:
            if f:
                f.close()

    def _save(self, config, group=None):
        """Write the settings of `config` into `group`.

        If `group` is None, use `self.group`.  That is either the h5py group
        given to the constructor, or a path inside `self._filename`; in the
        latter case the file is opened in append mode for this call.  Each
        simple setting replaces any existing dataset of the same name.
        Non-empty nested config settings are written recursively into
        sub-groups.

        Raises ValueError when h5py rejects a value's type.
        """
        f = None
        try:
            if group is None:
                if isinstance(self.group, _h5py.Group):
                    group = self.group
                else:
                    self._check_file()
                    f = _h5py.File(self._filename, 'a')
                    group = f[self.group]
            for s in config.settings:
                value = None
                if isinstance(s, (_config.BooleanSetting,
                                  _config.NumericSetting,
                                  _config.ListSetting)):
                    value = config[s.name]
                    # Explicit tests instead of `value in [None, []]`,
                    # which compares numpy arrays elementwise and then
                    # fails when truth-testing the result.
                    if value is None or (isinstance(value, list)
                                         and not value):
                        value = s.convert_to_text(value)
                elif isinstance(s, _config.ConfigListSetting):
                    configs = config[s.name]
                    if configs:
                        cwg = h5_create_group(group, s.name, force=True)
                        written = set()
                        for i, cfg in enumerate(configs):
                            key = str(i)
                            written.add(key)
                            _cwg = h5_create_group(cwg, key, force=True)
                            self._save(config=cfg, group=_cwg)
                        # Remove entries left over from an earlier, longer
                        # list; `_load` would otherwise read them back.
                        for key in list(cwg.keys()):
                            if key not in written:
                                del cwg[key]
                        continue
                elif isinstance(s, _config.ConfigSetting):
                    cfg = config[s.name]
                    if cfg:
                        cwg = h5_create_group(group, s.name, force=True)
                        self._save(config=cfg, group=cwg)
                        continue
                if value is None:  # not set yet, or invalid
                    value = s.convert_to_text(config[s.name])
                value = self._text_to_bytes(value)
                try:
                    del group[s.name]
                except KeyError:
                    pass
                try:
                    group[s.name] = value
                except TypeError:
                    raise ValueError((value, type(value)))
        finally:
            if f:
                f.close()