Transitioned bugdir.py to new storage format.
[be.git] / libbe / storage / base.py
1 # Copyright
2
3 """
4 Abstract bug repository data storage to easily support multiple backends.
5 """
6
7 import copy
8 import os
9 import pickle
10 import types
11
12 from libbe.error import NotSupported
13 from libbe.util.tree import Tree
14 from libbe.util import InvalidObject
15 from libbe import TESTING
16
17 if TESTING == True:
18     import doctest
19     import os.path
20     import sys
21     import unittest
22
23     from libbe.util.utility import Dir
24
25 class ConnectionError (Exception):
26     pass
27
28 class InvalidID (KeyError):
29     pass
30
31 class InvalidRevision (KeyError):
32     pass
33
34 class NotWriteable (NotSupported):
35     def __init__(self, msg):
36         NotSupported.__init__(self, 'write', msg)
37
38 class NotReadable (NotSupported):
39     def __init__(self, msg):
40         NotSupported.__init__(self, 'read', msg)
41
42 class EmptyCommit(Exception):
43     def __init__(self):
44         Exception.__init__(self, 'No changes to commit')
45
46 class Entry (Tree):
47     def __init__(self, id, value=None, parent=None, children=None):
48         if children == None:
49             Tree.__init__(self)
50         else:
51             Tree.__init__(self, children)
52         self.id = id
53         self.value = value
54         self.parent = parent
55         if self.parent != None:
56             parent.append(self)
57
58     def __str__(self):
59         return '<Entry %s: %s>' % (self.id, self.value)
60
61     def __repr__(self):
62         return str(self)
63
64     def __cmp__(self, other, local=False):
65         if other == None:
66             return cmp(1, None)
67         if cmp(self.id, other.id) != 0:
68             return cmp(self.id, other.id)
69         if cmp(self.value, other.value) != 0:
70             return cmp(self.value, other.value)
71         if local == False:
72             if self.parent == None:
73                 if cmp(self.parent, other.parent) != 0:
74                     return cmp(self.parent, other.parent)
75             elif self.parent.__cmp__(other.parent, local=True) != 0:
76                 return self.parent.__cmp__(other.parent, local=True)
77             for sc,oc in zip(self, other):
78                 if sc.__cmp__(oc, local=True) != 0:
79                     return sc.__cmp__(oc, local=True)
80         return 0
81
82     def _objects_to_ids(self):
83         if self.parent != None:
84             self.parent = self.parent.id
85         for i,c in enumerate(self):
86             self[i] = c.id
87         return self
88
89     def _ids_to_objects(self, dict):
90         if self.parent != None:
91             self.parent = dict[self.parent]
92         for i,c in enumerate(self):
93             self[i] = dict[c]
94         return self
95
96 class Storage (object):
97     """
98     This class declares all the methods required by a Storage
99     interface.  This implementation just keeps the data in a
100     dictionary and uses pickle for persistent storage.
101     """
102     name = 'Storage'
103
104     def __init__(self, repo, encoding='utf-8', options=None):
105         self.repo = repo
106         self.encoding = encoding
107         self.options = options
108         self.readable = True  # soft limit (user choice)
109         self._readable = True # hard limit (backend choice)
110         self.writeable = True  # soft limit (user choice)
111         self._writeable = True # hard limit (backend choice)
112         self.versioned = False
113         self.can_init = True
114
115     def __str__(self):
116         return '<%s %s>' % (self.__class__.__name__, id(self))
117
118     def __repr__(self):
119         return str(self)
120
121     def version(self):
122         """Return a version string for this backend."""
123         return '0'
124
125     def is_readable(self):
126         return self.readable and self._readable
127
128     def is_writeable(self):
129         return self.writeable and self._writeable
130
131     def init(self):
132         """Create a new storage repository."""
133         if self.can_init == False:
134             raise NotSupported('init',
135                                'Cannot initialize this repository format.')
136         if self.is_writeable() == False:
137             raise NotWriteable('Cannot initialize unwriteable storage.')
138         return self._init()
139
140     def _init(self):
141         f = open(self.repo, 'wb')
142         root = Entry(id='__ROOT__')
143         d = {root.id:root}
144         pickle.dump(dict((k,v._objects_to_ids()) for k,v in d.items()), f, -1)
145         f.close()
146
147     def destroy(self):
148         """Remove the storage repository."""
149         if self.is_writeable() == False:
150             raise NotWriteable('Cannot destroy unwriteable storage.')
151         return self._destroy()
152
153     def _destroy(self):
154         os.remove(self.repo)
155
156     def connect(self):
157         """Open a connection to the repository."""
158         if self.is_readable() == False:
159             raise NotReadable('Cannot connect to unreadable storage.')
160         self._connect()
161
162     def _connect(self):
163         try:
164             f = open(self.repo, 'rb')
165         except IOError:
166             raise ConnectionError(self)
167         d = pickle.load(f)
168         self._data = dict((k,v._ids_to_objects(d)) for k,v in d.items())
169         f.close()
170
171     def disconnect(self):
172         """Close the connection to the repository."""
173         if self.is_writeable() == False:
174             return
175         f = open(self.repo, 'wb')
176         pickle.dump(dict((k,v._objects_to_ids())
177                          for k,v in self._data.items()), f, -1)
178         f.close()
179         self._data = None
180
181     def add(self, *args, **kwargs):
182         """Add an entry"""
183         if self.is_writeable() == False:
184             raise NotWriteable('Cannot add entry to unwriteable storage.')
185         try:  # Maybe we've already added that id?
186             self.get(id)
187             pass # yup, no need to add another
188         except InvalidID:
189             self._add(*args, **kwargs)
190
191     def _add(self, id, parent=None):
192         if parent == None:
193             parent = '__ROOT__'
194         p = self._data[parent]
195         self._data[id] = Entry(id, parent=p)
196
197     def remove(self, *args, **kwargs):
198         """Remove an entry."""
199         if self.is_writeable() == False:
200             raise NotSupported('write',
201                                'Cannot remove entry from unwriteable storage.')
202         self._remove(*args, **kwargs)
203
204     def _remove(self, id):
205         e = self._data.pop(id)
206         e.parent.remove(e)
207
208     def recursive_remove(self, *args, **kwargs):
209         """Remove an entry and all its decendents."""
210         if self.is_writeable() == False:
211             raise NotSupported('write',
212                                'Cannot remove entries from unwriteable storage.')
213         self._recursive_remove(*args, **kwargs)
214
215     def _recursive_remove(self, id):
216         for entry in self._data[id].traverse():
217             self._remove(entry.id)
218
219     def children(self, *args, **kwargs):
220         """Return a list of specified entry's children's ids."""
221         if self.is_readable() == False:
222             raise NotReadable('Cannot list children with unreadable storage.')
223         return self._children(*args, **kwargs)
224
225     def _children(self, id=None, revision=None):
226         if id == None:
227             id = '__ROOT__'
228         return [c.id for c in self._data[id] if not c.id.startswith('__')]
229
230     def get(self, *args, **kwargs):
231         """
232         Get contents of and entry as they were in a given revision.
233         revision==None specifies the current revision.
234
235         If there is no id, return default, unless default is not
236         given, in which case raise InvalidID.
237         """
238         if self.is_readable() == False:
239             raise NotReadable('Cannot get entry with unreadable storage.')
240         if 'decode' in kwargs:
241             decode = kwargs.pop('decode')
242         else:
243             decode = False
244         value = self._get(*args, **kwargs)
245         if decode == True:
246             return unicode(value, self.encoding)
247         return value
248
249     def _get(self, id, default=InvalidObject, revision=None):
250         if id in self._data:
251             return self._data[id].value
252         elif default == InvalidObject:
253             raise InvalidID(id)
254         return default
255
256     def set(self, id, value, *args, **kwargs):
257         """
258         Set the entry contents.
259         """
260         if self.is_writeable() == False:
261             raise NotWriteable('Cannot set entry in unwriteable storage.')
262         if type(value) == types.UnicodeType:
263             value = value.encode(self.encoding)
264         self._set(id, value, *args, **kwargs)
265
266     def _set(self, id, value):
267         if id not in self._data:
268             raise InvalidID(id)
269         self._data[id].value = value
270
271 class VersionedStorage (Storage):
272     """
273     This class declares all the methods required by a Storage
274     interface that supports versioning.  This implementation just
275     keeps the data in a list and uses pickle for persistent
276     storage.
277     """
278     name = 'VersionedStorage'
279
280     def __init__(self, *args, **kwargs):
281         Storage.__init__(self, *args, **kwargs)
282         self.versioned = True
283
284     def _init(self):
285         f = open(self.repo, 'wb')
286         root = Entry(id='__ROOT__')
287         summary = Entry(id='__COMMIT__SUMMARY__', value='Initial commit')
288         body = Entry(id='__COMMIT__BODY__')
289         initial_commit = {root.id:root, summary.id:summary, body.id:body}
290         d = dict((k,v._objects_to_ids()) for k,v in initial_commit.items())
291         pickle.dump([d, copy.deepcopy(d)], f, -1) # [inital tree, working tree]
292         f.close()
293
294     def _connect(self):
295         try:
296             f = open(self.repo, 'rb')
297         except IOError:
298             raise ConnectionError(self)
299         d = pickle.load(f)
300         self._data = [dict((k,v._ids_to_objects(t)) for k,v in t.items())
301                       for t in d]
302         f.close()
303
304     def disconnect(self):
305         """Close the connection to the repository."""
306         if self.is_writeable() == False:
307             return
308         f = open(self.repo, 'wb')
309         pickle.dump([dict((k,v._objects_to_ids())
310                           for k,v in t.items()) for t in self._data], f, -1)
311         f.close()
312         self._data = None
313
314     def _add(self, id, parent=None):
315         if parent == None:
316             parent = '__ROOT__'
317         p = self._data[-1][parent]
318         self._data[-1][id] = Entry(id, parent=p)
319
320     def _remove(self, id):
321         e = self._data[-1].pop(id)
322         e.parent.remove(e)
323
324     def _recursive_remove(self, id):
325         for entry in self._data[-1][id].traverse():
326             self._remove(entry.id)
327
328     def _children(self, id=None, revision=None):
329         if id == None:
330             id = '__ROOT__'
331         if revision == None:
332             revision = -1
333         return [c.id for c in self._data[revision][id]
334                 if not c.id.startswith('__')]
335
336     def _get(self, id, default=InvalidObject, revision=None):
337         if revision == None:
338             revision = -1
339         if id in self._data[revision]:
340             return self._data[revision][id].value
341         elif default == InvalidObject:
342             raise InvalidID(id)
343         return default
344
345     def _set(self, id, value):
346         if id not in self._data[-1]:
347             raise InvalidID(id)
348         self._data[-1][id].value = value
349
350     def commit(self, *args, **kwargs):
351         """
352         Commit the current repository, with a commit message string
353         summary and body.  Return the name of the new revision.
354
355         If allow_empty == False (the default), raise EmptyCommit if
356         there are no changes to commit.
357         """
358         if self.is_writeable() == False:
359             raise NotWriteable('Cannot commit to unwriteable storage.')
360         return self._commit(*args, **kwargs)
361
362     def _commit(self, summary, body=None, allow_empty=False):
363         if self._data[-1] == self._data[-2] and allow_empty == False:
364             raise EmptyCommit
365         self._data[-1]["__COMMIT__SUMMARY__"].value = summary
366         self._data[-1]["__COMMIT__BODY__"].value = body
367         rev = len(self._data)-1
368         self._data.append(copy.deepcopy(self._data[-1]))
369         return rev
370
371     def revision_id(self, index=None):
372         """
373         Return the name of the <index>th revision.  The choice of
374         which branch to follow when crossing branches/merges is not
375         defined.  Revision indices start at 1; ID 0 is the blank
376         repository.
377
378         Return None if index==None.
379
380         If the specified revision does not exist, raise InvalidRevision.
381         """
382         if index == None:
383             return None
384         try:
385             if int(index) != index:
386                 raise InvalidRevision(index)
387         except ValueError:
388             raise InvalidRevision(index)
389         L = len(self._data) - 1  # -1 b/c of initial commit
390         if index >= -L and index <= L:
391             return index % L
392         raise InvalidRevision(i)
393
394 if TESTING == True:
395     class StorageTestCase (unittest.TestCase):
396         """Test cases for base Storage class."""
397
398         Class = Storage
399
400         def __init__(self, *args, **kwargs):
401             super(StorageTestCase, self).__init__(*args, **kwargs)
402             self.dirname = None
403
404         def setUp(self):
405             """Set up test fixtures for Storage test case."""
406             super(StorageTestCase, self).setUp()
407             self.dir = Dir()
408             self.dirname = self.dir.path
409             self.s = self.Class(repo=os.path.join(self.dirname, 'repo.pkl'))
410             self.assert_failed_connect()
411             self.s.init()
412             self.s.connect()
413
414         def tearDown(self):
415             super(StorageTestCase, self).tearDown()
416             self.s.disconnect()
417             self.s.destroy()
418             self.assert_failed_connect()
419
420         def assert_failed_connect(self):
421             try:
422                 self.s.connect()
423                 self.fail(
424                     "Connected to %(name)s repository before initialising"
425                     % vars(self.Class))
426             except ConnectionError:
427                 pass
428
429     class Storage_init_TestCase (StorageTestCase):
430         """Test cases for Storage.init method."""
431
432         def test_connect_should_succeed_after_init(self):
433             """Should connect after initialization."""
434             self.s.connect()
435
436     class Storage_add_remove_TestCase (StorageTestCase):
437         """Test cases for Storage.add, .remove, and .recursive_remove methods."""
438
439         def test_initially_empty(self):
440             """New repository should be empty."""
441             self.failUnless(len(self.s.children()) == 0, self.s.children())
442
443         def test_add_rooted(self):
444             """
445             Adding entries with the same ID should not increase the number of children.
446             """
447             for i in range(10):
448                 self.s.add('some id')
449                 s = sorted(self.s.children())
450                 self.failUnless(s == ['some id'], s)
451
452         def test_add_rooted(self):
453             """
454             Adding entries should increase the number of children (rooted).
455             """
456             ids = []
457             for i in range(10):
458                 ids.append(str(i))
459                 self.s.add(ids[-1])
460                 s = sorted(self.s.children())
461                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
462
463         def test_add_nonrooted(self):
464             """
465             Adding entries should increase the number of children (nonrooted).
466             """
467             self.s.add('parent')
468             ids = []
469             for i in range(10):
470                 ids.append(str(i))
471                 self.s.add(ids[-1], 'parent')
472                 s = sorted(self.s.children('parent'))
473                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
474                 s = self.s.children()
475                 self.failUnless(s == ['parent'], s)
476                 
477         def test_remove_rooted(self):
478             """
479             Removing entries should decrease the number of children (rooted).
480             """
481             ids = []
482             for i in range(10):
483                 ids.append(str(i))
484                 self.s.add(ids[-1])
485             for i in range(10):
486                 self.s.remove(ids.pop())
487                 s = sorted(self.s.children())
488                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
489
490         def test_remove_nonrooted(self):
491             """
492             Removing entries should decrease the number of children (nonrooted).
493             """
494             self.s.add('parent')
495             ids = []
496             for i in range(10):
497                 ids.append(str(i))
498                 self.s.add(ids[-1], 'parent')
499             for i in range(10):
500                 self.s.remove(ids.pop())
501                 s = sorted(self.s.children('parent'))
502                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
503                 s = self.s.children()
504                 self.failUnless(s == ['parent'], s)
505
506         def test_recursive_remove(self):
507             """
508             Recursive remove should empty the tree.
509             """
510             self.s.add('parent')
511             ids = []
512             for i in range(10):
513                 ids.append(str(i))
514                 self.s.add(ids[-1], 'parent')
515                 for j in range(10): # add some grandkids
516                     self.s.add(str(20*i+j), ids[-i])
517             self.s.recursive_remove('parent')
518             s = sorted(self.s.children())
519             self.failUnless(s == [], s)
520
521     class Storage_get_set_TestCase (StorageTestCase):
522         """Test cases for Storage.get and .set methods."""
523
524         id = 'unlikely id'
525         val = 'unlikely value'
526
527         def test_get_default(self):
528             """
529             Get should return specified default if id not in Storage.
530             """
531             ret = self.s.get(self.id, default=self.val)
532             self.failUnless(ret == self.val,
533                     "%s.get() returned %s not %s"
534                     % (vars(self.Class)['name'], ret, self.val))
535
536         def test_get_default_exception(self):
537             """
538             Get should raise exception if id not in Storage and no default.
539             """
540             try:
541                 ret = self.s.get(self.id)
542                 self.fail(
543                     "%s.get() returned %s instead of raising InvalidID"
544                     % (vars(self.Class)['name'], ret))
545             except InvalidID:
546                 pass
547
548         def test_get_initial_value(self):
549             """
550             Data value should be None before any value has been set.
551             """
552             self.s.add(self.id)
553             ret = self.s.get(self.id)
554             self.failUnless(ret == None,
555                     "%s.get() returned %s not None"
556                     % (vars(self.Class)['name'], ret))
557
558         def test_set_exception(self):
559             """
560             Set should raise exception if id not in Storage.
561             """
562             try:
563                 self.s.set(self.id, self.val)
564                 self.fail(
565                     "%(name)s.set() did not raise InvalidID"
566                     % vars(self.Class))
567             except InvalidID:
568                 pass
569
570         def test_set(self):
571             """
572             Set should define the value returned by get.
573             """
574             self.s.add(self.id)
575             self.s.set(self.id, self.val)
576             ret = self.s.get(self.id)
577             self.failUnless(ret == self.val,
578                     "%s.get() returned %s not %s"
579                     % (vars(self.Class)['name'], ret, self.val))
580
581         def test_unicode_set(self):
582             """
583             Set should define the value returned by get.
584             """
585             val = u'Fran\xe7ois'
586             self.s.add(self.id)
587             self.s.set(self.id, val)
588             ret = self.s.get(self.id, decode=True)
589             self.failUnless(type(ret) == types.UnicodeType,
590                     "%s.get() returned %s not UnicodeType"
591                     % (vars(self.Class)['name'], type(ret)))
592             self.failUnless(ret == val,
593                     "%s.get() returned %s not %s"
594                     % (vars(self.Class)['name'], ret, self.val))
595             ret = self.s.get(self.id)
596             self.failUnless(type(ret) == types.StringType,
597                     "%s.get() returned %s not StringType"
598                     % (vars(self.Class)['name'], type(ret)))
599             s = unicode(ret, self.s.encoding)
600             self.failUnless(s == val,
601                     "%s.get() returned %s not %s"
602                     % (vars(self.Class)['name'], s, self.val))
603             
604
605     class Storage_persistence_TestCase (StorageTestCase):
606         """Test cases for Storage.disconnect and .connect methods."""
607
608         id = 'unlikely id'
609         val = 'unlikely value'
610
611         def test_get_set_persistence(self):
612             """
613             Set should define the value returned by get after reconnect.
614             """
615             self.s.add(self.id)
616             self.s.set(self.id, self.val)
617             self.s.disconnect()
618             self.s.connect()
619             ret = self.s.get(self.id)
620             self.failUnless(ret == self.val,
621                     "%s.get() returned %s not %s"
622                     % (vars(self.Class)['name'], ret, self.val))
623
624         def test_add_nonrooted_persistence(self):
625             """
626             Adding entries should increase the number of children after reconnect.
627             """
628             self.s.add('parent')
629             ids = []
630             for i in range(10):
631                 ids.append(str(i))
632                 self.s.add(ids[-1], 'parent')
633             self.s.disconnect()
634             self.s.connect()
635             s = sorted(self.s.children('parent'))
636             self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
637             s = self.s.children()
638             self.failUnless(s == ['parent'], s)
639
640     class VersionedStorageTestCase (StorageTestCase):
641         """Test cases for base VersionedStorage class."""
642
643         Class = VersionedStorage
644
645     class VersionedStorage_commit_TestCase (VersionedStorageTestCase):
646         """Test cases for VersionedStorage methods."""
647
648         id = 'I' #unlikely id'
649         val = 'X'
650         commit_msg = 'C' #ommitting something interesting'
651         commit_body = 'B' #ome\nlonger\ndescription\n'
652
653         def test_revision_id_exception(self):
654             """
655             Invalid revision id should raise InvalidRevision.
656             """
657             try:
658                 rev = self.s.revision_id('highly unlikely revision id')
659                 self.fail(
660                     "%s.revision_id() didn't raise InvalidRevision, returned %s."
661                     % (vars(self.Class)['name'], rev))
662             except InvalidRevision:
663                 pass
664
665         def test_empty_commit_raises_exception(self):
666             """
667             Empty commit should raise exception.
668             """
669             try:
670                 self.s.commit(self.commit_msg, self.commit_body)
671                 self.fail(
672                     "Empty %(name)s.commit() didn't raise EmptyCommit."
673                     % vars(self.Class))
674             except EmptyCommit:
675                 pass
676
677         def test_empty_commit_allowed(self):
678             """
679             Empty commit should _not_ raise exception if allow_empty=True.
680             """
681             self.s.commit(self.commit_msg, self.commit_body,
682                           allow_empty=True)
683
684         def test_commit_revision_ids(self):
685             """
686             Commit / revision_id should agree on revision ids.
687             """
688             revs = []
689             for s in range(10):
690                 revs.append(self.s.commit(self.commit_msg,
691                                           self.commit_body,
692                                           allow_empty=True))
693             for i in range(10):
694                 rev = self.s.revision_id(i+1) 
695                 self.failUnless(rev == revs[i],
696                                 "%s.revision_id(%d) returned %s not %s"
697                                 % (vars(self.Class)['name'], i+1, rev, revs[i]))
698             for i in range(-1, -9, -1):
699                 rev = self.s.revision_id(i)
700                 self.failUnless(rev == revs[i],
701                                 "%s.revision_id(%d) returned %s not %s"
702                                 % (vars(self.Class)['name'], i, rev, revs[i]))
703
704         def test_get_previous_version(self):
705             """
706             Get should be able to return the previous version.
707             """
708             def val(i):
709                 return '%s:%d' % (self.val, i+1)
710             self.s.add(self.id)
711             revs = []
712             for i in range(10):
713                 self.s.set(self.id, val(i))
714                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
715                                           self.commit_body))
716             for i in range(10):
717                 ret = self.s.get(self.id, revision=revs[i])
718                 self.failUnless(ret == val(i),
719                                 "%s.get() returned %s not %s for revision %d"
720                                 % (vars(self.Class)['name'], ret, val(i), revs[i]))
721         
722     def make_storage_testcase_subclasses(storage_class, namespace):
723         """Make StorageTestCase subclasses for storage_class in namespace."""
724         storage_testcase_classes = [
725             c for c in (
726                 ob for ob in globals().values() if isinstance(ob, type))
727             if issubclass(c, StorageTestCase) \
728                 and not issubclass(c, VersionedStorageTestCase)]
729
730         for base_class in storage_testcase_classes:
731             testcase_class_name = storage_class.__name__ + base_class.__name__
732             testcase_class_bases = (base_class,)
733             testcase_class_dict = dict(base_class.__dict__)
734             testcase_class_dict['Class'] = storage_class
735             testcase_class = type(
736                 testcase_class_name, testcase_class_bases, testcase_class_dict)
737             setattr(namespace, testcase_class_name, testcase_class)
738
739     def make_versioned_storage_testcase_subclasses(storage_class, namespace):
740         """Make VersionedStorageTestCase subclasses for storage_class in namespace."""
741         storage_testcase_classes = [
742             c for c in (
743                 ob for ob in globals().values() if isinstance(ob, type))
744             if issubclass(c, StorageTestCase)]
745
746         for base_class in storage_testcase_classes:
747             testcase_class_name = storage_class.__name__ + base_class.__name__
748             testcase_class_bases = (base_class,)
749             testcase_class_dict = dict(base_class.__dict__)
750             testcase_class_dict['Class'] = storage_class
751             testcase_class = type(
752                 testcase_class_name, testcase_class_bases, testcase_class_dict)
753             setattr(namespace, testcase_class_name, testcase_class)
754
755     make_storage_testcase_subclasses(VersionedStorage, sys.modules[__name__])
756
757     unitsuite =unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
758     suite = unittest.TestSuite([unitsuite, doctest.DocTestSuite()])