Added libbe.storage.base and test suite.
[be.git] / libbe / storage / base.py
1 # Copyright
2
3 """
4 Abstract bug repository data storage to easily support multiple backends.
5 """
6
7 import copy
8 import os
9 import pickle
10
11 from libbe.error import NotSupported
12 from libbe.util.tree import Tree
13 from libbe.util import InvalidObject
14 from libbe import TESTING
15
16 if TESTING == True:
17     import doctest
18     import os.path
19     import sys
20     import unittest
21
22     from libbe.util.utility import Dir
23
24 class ConnectionError (Exception):
25     pass
26
27 class InvalidID (KeyError):
28     pass
29
30 class InvalidRevision (KeyError):
31     pass
32
33 class EmptyCommit(Exception):
34     def __init__(self):
35         Exception.__init__(self, 'No changes to commit')
36
37 class Entry (Tree):
38     def __init__(self, id, value=None, parent=None, children=None):
39         if children == None:
40             Tree.__init__(self)
41         else:
42             Tree.__init__(self, children)
43         self.id = id
44         self.value = value
45         self.parent = parent
46         if self.parent != None:
47             parent.append(self)
48
49     def __str__(self):
50         return '<Entry %s: %s>' % (self.id, self.value)
51
52     def __repr__(self):
53         return str(self)
54
55     def __cmp__(self, other, local=False):
56         if other == None:
57             return cmp(1, None)
58         if cmp(self.id, other.id) != 0:
59             return cmp(self.id, other.id)
60         if cmp(self.value, other.value) != 0:
61             return cmp(self.value, other.value)
62         if local == False:
63             if self.parent == None:
64                 if cmp(self.parent, other.parent) != 0:
65                     return cmp(self.parent, other.parent)
66             elif self.parent.__cmp__(other.parent, local=True) != 0:
67                 return self.parent.__cmp__(other.parent, local=True)
68             for sc,oc in zip(self, other):
69                 if sc.__cmp__(oc, local=True) != 0:
70                     return sc.__cmp__(oc, local=True)
71         return 0
72
73     def _objects_to_ids(self):
74         if self.parent != None:
75             self.parent = self.parent.id
76         for i,c in enumerate(self):
77             self[i] = c.id
78         return self
79
80     def _ids_to_objects(self, dict):
81         if self.parent != None:
82             self.parent = dict[self.parent]
83         for i,c in enumerate(self):
84             self[i] = dict[c]
85         return self
86
87 class Storage (object):
88     """
89     This class declares all the methods required by a Storage
90     interface.  This implementation just keeps the data in a
91     dictionary and uses pickle for persistent storage.
92     """
93     name = 'Storage'
94
95     def __init__(self, repo, options=None):
96         self.repo = repo
97         self.options = options
98         self.read_only = False
99         self.versioned = False
100         self.can_init = True
101
102     def __str__(self):
103         return '<%s %s>' % (self.__class__.__name__, id(self))
104
105     def __repr__(self):
106         return str(self)
107
108     def version(self):
109         """Return a version string for this backend."""
110         return '0'
111
112     def init(self):
113         """Create a new storage repository."""
114         if self.can_init == False:
115             raise NotSupported('init',
116                                'Cannot initialize this repository format.')
117         if self.read_only == True:
118             raise NotSupported('init', 'Cannot initialize read only storage.')
119         return self._init()
120
121     def _init(self):
122         f = open(self.repo, 'wb')
123         root = Entry(id='__ROOT__')
124         d = {root.id:root}
125         pickle.dump(dict((k,v._objects_to_ids()) for k,v in d.items()), f, -1)
126         f.close()
127
128     def destroy(self):
129         """Remove the storage repository."""
130         if self.read_only == True:
131             raise NotSupported('destroy', 'Cannot destroy read only storage.')
132         return self._destroy()
133
134     def _destroy(self):
135         os.remove(self.repo)
136
137     def connect(self):
138         """Open a connection to the repository."""
139         try:
140             f = open(self.repo, 'rb')
141         except IOError:
142             raise ConnectionError(self)
143         d = pickle.load(f)
144         self._data = dict((k,v._ids_to_objects(d)) for k,v in d.items())
145         f.close()
146
147     def disconnect(self):
148         """Close the connection to the repository."""
149         if self.read_only == True:
150             return
151         f = open(self.repo, 'wb')
152         pickle.dump(dict((k,v._objects_to_ids())
153                          for k,v in self._data.items()), f, -1)
154         f.close()
155         self._data = None
156
157     def add(self, *args, **kwargs):
158         """Add an entry"""
159         if self.read_only == True:
160             raise NotSupported('add', 'Cannot add entry to read only storage.')
161         self._add(*args, **kwargs)
162
163     def _add(self, id, parent=None):
164         if parent == None:
165             parent = '__ROOT__'
166         p = self._data[parent]
167         self._data[id] = Entry(id, parent=p)
168
169     def remove(self, *args, **kwargs):
170         """Remove an entry."""
171         if self.read_only == True:
172             raise NotSupported('remove',
173                                'Cannot remove entry from read only storage.')
174         self._remove(*args, **kwargs)
175
176     def _remove(self, id):
177         e = self._data.pop(id)
178         e.parent.remove(e)
179
180     def recursive_remove(self, *args, **kwargs):
181         """Remove an entry and all its decendents."""
182         if self.read_only == True:
183             raise NotSupported('recursive_remove',
184                                'Cannot remove entries from read only storage.')
185         self._recursive_remove(*args, **kwargs)
186
187     def _recursive_remove(self, id):
188         for entry in self._data[id].traverse():
189             self._remove(entry.id)
190
191     def children(self, id=None, revision=None):
192         """Return a list of specified entry's children's ids."""
193         if id == None:
194             id = '__ROOT__'
195         return [c.id for c in self._data[id] if not c.id.startswith('__')]
196
197     def get(self, id, default=InvalidObject, revision=None):
198         """
199         Get contents of and entry as they were in a given revision.
200         revision==None specifies the current revision.
201
202         If there is no id, return default, unless default is not
203         given, in which case raise InvalidID.
204         """
205         if id in self._data:
206             return self._data[id].value
207         elif default == InvalidObject:
208             raise InvalidID(id)
209         return default
210
211     def set(self, *args, **kwargs):
212         """
213         Set the entry contents.
214         """
215         if self.read_only == True:
216             raise NotSupported('set', 'Cannot set entry in read only storage.')
217         self._set(*args, **kwargs)
218
219     def _set(self, id, value):
220         if id not in self._data:
221             raise InvalidID(id)
222         self._data[id].value = value
223
224 class VersionedStorage (Storage):
225     """
226     This class declares all the methods required by a Storage
227     interface that supports versioning.  This implementation just
228     keeps the data in a list and uses pickle for persistent
229     storage.
230     """
231     name = 'VersionedStorage'
232
233     def __init__(self, *args, **kwargs):
234         Storage.__init__(self, *args, **kwargs)
235         self.versioned = True
236
237     def _init(self):
238         f = open(self.repo, 'wb')
239         root = Entry(id='__ROOT__')
240         summary = Entry(id='__COMMIT__SUMMARY__', value='Initial commit')
241         body = Entry(id='__COMMIT__BODY__')
242         initial_commit = {root.id:root, summary.id:summary, body.id:body}
243         d = dict((k,v._objects_to_ids()) for k,v in initial_commit.items())
244         pickle.dump([d, copy.deepcopy(d)], f, -1) # [inital tree, working tree]
245         f.close()
246
247     def connect(self):
248         """Open a connection to the repository."""
249         try:
250             f = open(self.repo, 'rb')
251         except IOError:
252             raise ConnectionError(self)
253         d = pickle.load(f)
254         self._data = [dict((k,v._ids_to_objects(t)) for k,v in t.items())
255                       for t in d]
256         f.close()
257
258     def disconnect(self):
259         """Close the connection to the repository."""
260         if self.read_only == True:
261             return
262         f = open(self.repo, 'wb')
263         pickle.dump([dict((k,v._objects_to_ids())
264                           for k,v in t.items()) for t in self._data], f, -1)
265         f.close()
266         self._data = None
267
268     def _add(self, id, parent=None):
269         if parent == None:
270             parent = '__ROOT__'
271         p = self._data[-1][parent]
272         self._data[-1][id] = Entry(id, parent=p)
273
274     def _remove(self, id):
275         e = self._data[-1].pop(id)
276         e.parent.remove(e)
277
278     def _recursive_remove(self, id):
279         for entry in self._data[-1][id].traverse():
280             self._remove(entry.id)
281
282     def children(self, id=None, revision=None):
283         """Return a list of specified entry's children's ids."""
284         if id == None:
285             id = '__ROOT__'
286         if revision == None:
287             revision = -1
288         return [c.id for c in self._data[revision][id]
289                 if not c.id.startswith('__')]
290
291     def get(self, id, default=InvalidObject, revision=None):
292         """
293         Get contents of and entry as they were in a given revision.
294         revision==None specifies the current revision.
295
296         If there is no id, return default, unless default is not
297         given, in which case raise InvalidID.
298         """
299         if revision == None:
300             revision = -1
301         if id in self._data[revision]:
302             return self._data[revision][id].value
303         elif default == InvalidObject:
304             raise InvalidID(id)
305         return default
306
307     def _set(self, id, value):
308         if id not in self._data[-1]:
309             raise InvalidID(id)
310         self._data[-1][id].value = value
311
312     def commit(self, *args, **kwargs):
313         """
314         Commit the current repository, with a commit message string
315         summary and body.  Return the name of the new revision.
316
317         If allow_empty == False (the default), raise EmptyCommit if
318         there are no changes to commit.
319         """
320         if self.read_only == True:
321             raise NotSupported('commit', 'Cannot commit to read only storage.')
322         return self._commit(*args, **kwargs)
323
324     def _commit(self, summary, body=None, allow_empty=False):
325         if self._data[-1] == self._data[-2] and allow_empty == False:
326             raise EmptyCommit
327         self._data[-1]["__COMMIT__SUMMARY__"].value = summary
328         self._data[-1]["__COMMIT__BODY__"].value = body
329         rev = len(self._data)-1
330         self._data.append(copy.deepcopy(self._data[-1]))
331         return rev
332
333     def revision_id(self, index=None):
334         """
335         Return the name of the <index>th revision.  The choice of
336         which branch to follow when crossing branches/merges is not
337         defined.  Revision indices start at 1; ID 0 is the blank
338         repository.
339
340         Return None if index==None.
341
342         If the specified revision does not exist, raise InvalidRevision.
343         """
344         if index == None:
345             return None
346         try:
347             if int(index) != index:
348                 raise InvalidRevision(index)
349         except ValueError:
350             raise InvalidRevision(index)
351         L = len(self._data) - 1  # -1 b/c of initial commit
352         if index >= -L and index <= L:
353             return index % L
354         raise InvalidRevision(i)
355
356 if TESTING == True:
357     class StorageTestCase (unittest.TestCase):
358         """Test cases for base Storage class."""
359
360         Class = Storage
361
362         def __init__(self, *args, **kwargs):
363             super(StorageTestCase, self).__init__(*args, **kwargs)
364             self.dirname = None
365
366         def setUp(self):
367             """Set up test fixtures for Storage test case."""
368             super(StorageTestCase, self).setUp()
369             self.dir = Dir()
370             self.dirname = self.dir.path
371             self.s = self.Class(repo=os.path.join(self.dirname, 'repo.pkl'))
372             self.assert_failed_connect()
373             self.s.init()
374             self.s.connect()
375
376         def tearDown(self):
377             super(StorageTestCase, self).tearDown()
378             self.s.disconnect()
379             self.s.destroy()
380             self.assert_failed_connect()
381
382         def assert_failed_connect(self):
383             try:
384                 self.s.connect()
385                 self.fail(
386                     "Connected to %(name)s repository before initialising"
387                     % vars(self.Class))
388             except ConnectionError:
389                 pass
390
391     class Storage_init_TestCase (StorageTestCase):
392         """Test cases for Storage.init method."""
393
394         def test_connect_should_succeed_after_init(self):
395             """Should connect after initialization."""
396             self.s.connect()
397
398     class Storage_add_remove_TestCase (StorageTestCase):
399         """Test cases for Storage.add, .remove, and .recursive_remove methods."""
400
401         def test_initially_empty(self):
402             """New repository should be empty."""
403             self.failUnless(len(self.s.children()) == 0, self.s.children())
404
405         def test_add_rooted(self):
406             """
407             Adding entries should increase the number of children (rooted).
408             """
409             ids = []
410             for i in range(10):
411                 ids.append(str(i))
412                 self.s.add(ids[-1])
413                 s = sorted(self.s.children())
414                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
415
416         def test_add_nonrooted(self):
417             """
418             Adding entries should increase the number of children (nonrooted).
419             """
420             self.s.add('parent')
421             ids = []
422             for i in range(10):
423                 ids.append(str(i))
424                 self.s.add(ids[-1], 'parent')
425                 s = sorted(self.s.children('parent'))
426                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
427                 s = self.s.children()
428                 self.failUnless(s == ['parent'], s)
429                 
430         def test_remove_rooted(self):
431             """
432             Removing entries should decrease the number of children (rooted).
433             """
434             ids = []
435             for i in range(10):
436                 ids.append(str(i))
437                 self.s.add(ids[-1])
438             for i in range(10):
439                 self.s.remove(ids.pop())
440                 s = sorted(self.s.children())
441                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
442
443         def test_remove_nonrooted(self):
444             """
445             Removing entries should decrease the number of children (nonrooted).
446             """
447             self.s.add('parent')
448             ids = []
449             for i in range(10):
450                 ids.append(str(i))
451                 self.s.add(ids[-1], 'parent')
452             for i in range(10):
453                 self.s.remove(ids.pop())
454                 s = sorted(self.s.children('parent'))
455                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
456                 s = self.s.children()
457                 self.failUnless(s == ['parent'], s)
458
459         def test_recursive_remove(self):
460             """
461             Recursive remove should empty the tree.
462             """
463             self.s.add('parent')
464             ids = []
465             for i in range(10):
466                 ids.append(str(i))
467                 self.s.add(ids[-1], 'parent')
468                 for j in range(10): # add some grandkids
469                     self.s.add(str(20*i+j), ids[-i])
470             self.s.recursive_remove('parent')
471             s = sorted(self.s.children())
472             self.failUnless(s == [], s)
473
474     class Storage_get_set_TestCase (StorageTestCase):
475         """Test cases for Storage.get and .set methods."""
476
477         id = 'unlikely id'
478         val = 'unlikely value'
479
480         def test_get_default(self):
481             """
482             Get should return specified default if id not in Storage.
483             """
484             ret = self.s.get(self.id, default=self.val)
485             self.failUnless(ret == self.val,
486                     "%s.get() returned %s not %s"
487                     % (vars(self.Class)['name'], ret, self.val))
488
489         def test_get_default_exception(self):
490             """
491             Get should raise exception if id not in Storage and no default.
492             """
493             try:
494                 ret = self.s.get(self.id)
495                 self.fail(
496                     "%s.get() returned %s instead of raising InvalidID"
497                     % (vars(self.Class)['name'], ret))
498             except InvalidID:
499                 pass
500
501         def test_get_initial_value(self):
502             """
503             Data value should be None before any value has been set.
504             """
505             self.s.add(self.id)
506             ret = self.s.get(self.id)
507             self.failUnless(ret == None,
508                     "%s.get() returned %s not None"
509                     % (vars(self.Class)['name'], ret))
510
511         def test_set_exception(self):
512             """
513             Set should raise exception if id not in Storage.
514             """
515             try:
516                 self.s.set(self.id, self.val)
517                 self.fail(
518                     "%(name)s.set() did not raise InvalidID"
519                     % vars(self.Class))
520             except InvalidID:
521                 pass
522
523         def test_set(self):
524             """
525             Set should define the value returned by get.
526             """
527             self.s.add(self.id)
528             self.s.set(self.id, self.val)
529             ret = self.s.get(self.id)
530             self.failUnless(ret == self.val,
531                     "%s.get() returned %s not %s"
532                     % (vars(self.Class)['name'], ret, self.val))
533
534     class Storage_persistence_TestCase (StorageTestCase):
535         """Test cases for Storage.disconnect and .connect methods."""
536
537         id = 'unlikely id'
538         val = 'unlikely value'
539
540         def test_get_set_persistence(self):
541             """
542             Set should define the value returned by get after reconnect.
543             """
544             self.s.add(self.id)
545             self.s.set(self.id, self.val)
546             self.s.disconnect()
547             self.s.connect()
548             ret = self.s.get(self.id)
549             self.failUnless(ret == self.val,
550                     "%s.get() returned %s not %s"
551                     % (vars(self.Class)['name'], ret, self.val))
552
553         def test_add_nonrooted_persistence(self):
554             """
555             Adding entries should increase the number of children after reconnect.
556             """
557             self.s.add('parent')
558             ids = []
559             for i in range(10):
560                 ids.append(str(i))
561                 self.s.add(ids[-1], 'parent')
562             self.s.disconnect()
563             self.s.connect()
564             s = sorted(self.s.children('parent'))
565             self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
566             s = self.s.children()
567             self.failUnless(s == ['parent'], s)
568
569     class VersionedStorageTestCase (StorageTestCase):
570         """Test cases for base VersionedStorage class."""
571
572         Class = VersionedStorage
573
574     class VersionedStorage_commit_TestCase (VersionedStorageTestCase):
575         """Test cases for VersionedStorage methods."""
576
577         id = 'I' #unlikely id'
578         val = 'X'
579         commit_msg = 'C' #ommitting something interesting'
580         commit_body = 'B' #ome\nlonger\ndescription\n'
581
582         def test_revision_id_exception(self):
583             """
584             Invalid revision id should raise InvalidRevision.
585             """
586             try:
587                 rev = self.s.revision_id('highly unlikely revision id')
588                 self.fail(
589                     "%s.revision_id() didn't raise InvalidRevision, returned %s."
590                     % (vars(self.Class)['name'], rev))
591             except InvalidRevision:
592                 pass
593
594         def test_empty_commit_raises_exception(self):
595             """
596             Empty commit should raise exception.
597             """
598             try:
599                 self.s.commit(self.commit_msg, self.commit_body)
600                 self.fail(
601                     "Empty %(name)s.commit() didn't raise EmptyCommit."
602                     % vars(self.Class))
603             except EmptyCommit:
604                 pass
605
606         def test_empty_commit_allowed(self):
607             """
608             Empty commit should _not_ raise exception if allow_empty=True.
609             """
610             self.s.commit(self.commit_msg, self.commit_body,
611                           allow_empty=True)
612
613         def test_commit_revision_ids(self):
614             """
615             Commit / revision_id should agree on revision ids.
616             """
617             revs = []
618             for s in range(10):
619                 revs.append(self.s.commit(self.commit_msg,
620                                           self.commit_body,
621                                           allow_empty=True))
622             for i in range(10):
623                 rev = self.s.revision_id(i+1) 
624                 self.failUnless(rev == revs[i],
625                                 "%s.revision_id(%d) returned %s not %s"
626                                 % (vars(self.Class)['name'], i+1, rev, revs[i]))
627             for i in range(-1, -9, -1):
628                 rev = self.s.revision_id(i)
629                 self.failUnless(rev == revs[i],
630                                 "%s.revision_id(%d) returned %s not %s"
631                                 % (vars(self.Class)['name'], i, rev, revs[i]))
632
633         def test_get_previous_version(self):
634             """
635             Get should be able to return the previous version.
636             """
637             def val(i):
638                 return '%s:%d' % (self.val, i+1)
639             self.s.add(self.id)
640             revs = []
641             for i in range(10):
642                 self.s.set(self.id, val(i))
643                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
644                                           self.commit_body))
645             for i in range(10):
646                 ret = self.s.get(self.id, revision=revs[i])
647                 self.failUnless(ret == val(i),
648                                 "%s.get() returned %s not %s for revision %d"
649                                 % (vars(self.Class)['name'], ret, val(i), revs[i]))
650         
651     def make_storage_testcase_subclasses(storage_class, namespace):
652         """Make StorageTestCase subclasses for storage_class in namespace."""
653         storage_testcase_classes = [
654             c for c in (
655                 ob for ob in globals().values() if isinstance(ob, type))
656             if issubclass(c, StorageTestCase) \
657                 and not issubclass(c, VersionedStorageTestCase)]
658
659         for base_class in storage_testcase_classes:
660             testcase_class_name = storage_class.__name__ + base_class.__name__
661             testcase_class_bases = (base_class,)
662             testcase_class_dict = dict(base_class.__dict__)
663             testcase_class_dict['Class'] = storage_class
664             testcase_class = type(
665                 testcase_class_name, testcase_class_bases, testcase_class_dict)
666             setattr(namespace, testcase_class_name, testcase_class)
667
668     def make_versioned_storage_testcase_subclasses(storage_class, namespace):
669         """Make VersionedStorageTestCase subclasses for storage_class in namespace."""
670         storage_testcase_classes = [
671             c for c in (
672                 ob for ob in globals().values() if isinstance(ob, type))
673             if issubclass(c, StorageTestCase)]
674
675         for base_class in storage_testcase_classes:
676             testcase_class_name = storage_class.__name__ + base_class.__name__
677             testcase_class_bases = (base_class,)
678             testcase_class_dict = dict(base_class.__dict__)
679             testcase_class_dict['Class'] = storage_class
680             testcase_class = type(
681                 testcase_class_name, testcase_class_bases, testcase_class_dict)
682             setattr(namespace, testcase_class_name, testcase_class)
683
684     make_storage_testcase_subclasses(VersionedStorage, sys.modules[__name__])
685
686     unitsuite =unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
687     suite = unittest.TestSuite([unitsuite, doctest.DocTestSuite()])