Merged clarifications requested by Ben Finney
[be.git] / libbe / storage / base.py
1 # Copyright (C) 2009-2010 W. Trevor King <wking@drexel.edu>
2 #
3 # This program is free software; you can redistribute it and/or modify
4 # it under the terms of the GNU General Public License as published by
5 # the Free Software Foundation; either version 2 of the License, or
6 # (at your option) any later version.
7 #
8 # This program is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 # GNU General Public License for more details.
12 #
13 # You should have received a copy of the GNU General Public License along
14 # with this program; if not, write to the Free Software Foundation, Inc.,
15 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
16
17 """
18 Abstract bug repository data storage to easily support multiple backends.
19 """
20
21 import copy
22 import os
23 import pickle
24 import types
25
26 from libbe.error import NotSupported
27 import libbe.storage
28 from libbe.util.tree import Tree
29 from libbe.util import InvalidObject
30 import libbe.version
31 from libbe import TESTING
32
33 if TESTING == True:
34     import doctest
35     import os.path
36     import sys
37     import unittest
38
39     from libbe.util.utility import Dir
40
41 class ConnectionError (Exception):
42     pass
43
44 class InvalidStorageVersion(ConnectionError):
45     def __init__(self, active_version, expected_version=None):
46         if expected_version == None:
47             expected_version = libbe.storage.STORAGE_VERSION
48         msg = 'Storage in "%s" not the expected "%s"' \
49             % (active_version, expected_version)
50         Exception.__init__(self, msg)
51         self.active_version = active_version
52         self.expected_version = expected_version
53
54 class InvalidID (KeyError):
55     def __init__(self, id=None, revision=None, msg=None):
56         KeyError.__init__(self, id)
57         self.msg = msg
58         self.id = id
59         self.revision = revision
60     def __str__(self):
61         if self.msg == None:
62             return '%s in revision %s' % (self.id, self.revision)
63         return self.msg
64
65
66 class InvalidRevision (KeyError):
67     pass
68
69 class InvalidDirectory (Exception):
70     pass
71
72 class DirectoryNotEmpty (InvalidDirectory):
73     pass
74
75 class NotWriteable (NotSupported):
76     def __init__(self, msg):
77         NotSupported.__init__(self, 'write', msg)
78
79 class NotReadable (NotSupported):
80     def __init__(self, msg):
81         NotSupported.__init__(self, 'read', msg)
82
83 class EmptyCommit(Exception):
84     def __init__(self):
85         Exception.__init__(self, 'No changes to commit')
86
87 class _EMPTY (object):
88     """Entry has been added but has no user-set value."""
89     pass
90
91 class Entry (Tree):
92     def __init__(self, id, value=_EMPTY, parent=None, directory=False,
93                  children=None):
94         if children == None:
95             Tree.__init__(self)
96         else:
97             Tree.__init__(self, children)
98         self.id = id
99         self.value = value
100         self.parent = parent
101         if self.parent != None:
102             if self.parent.directory == False:
103                 raise InvalidDirectory(
104                     'Non-directory %s cannot have children' % self.parent)
105             parent.append(self)
106         self.directory = directory
107
108     def __str__(self):
109         return '<Entry %s: %s>' % (self.id, self.value)
110
111     def __repr__(self):
112         return str(self)
113
114     def __cmp__(self, other, local=False):
115         if other == None:
116             return cmp(1, None)
117         if cmp(self.id, other.id) != 0:
118             return cmp(self.id, other.id)
119         if cmp(self.value, other.value) != 0:
120             return cmp(self.value, other.value)
121         if local == False:
122             if self.parent == None:
123                 if cmp(self.parent, other.parent) != 0:
124                     return cmp(self.parent, other.parent)
125             elif self.parent.__cmp__(other.parent, local=True) != 0:
126                 return self.parent.__cmp__(other.parent, local=True)
127             for sc,oc in zip(self, other):
128                 if sc.__cmp__(oc, local=True) != 0:
129                     return sc.__cmp__(oc, local=True)
130         return 0
131
132     def _objects_to_ids(self):
133         if self.parent != None:
134             self.parent = self.parent.id
135         for i,c in enumerate(self):
136             self[i] = c.id
137         return self
138
139     def _ids_to_objects(self, dict):
140         if self.parent != None:
141             self.parent = dict[self.parent]
142         for i,c in enumerate(self):
143             self[i] = dict[c]
144         return self
145
146 class Storage (object):
147     """
148     This class declares all the methods required by a Storage
149     interface.  This implementation just keeps the data in a
150     dictionary and uses pickle for persistent storage.
151     """
152     name = 'Storage'
153
154     def __init__(self, repo='/', encoding='utf-8', options=None):
155         self.repo = repo
156         self.encoding = encoding
157         self.options = options
158         self.readable = True  # soft limit (user choice)
159         self._readable = True # hard limit (backend choice)
160         self.writeable = True  # soft limit (user choice)
161         self._writeable = True # hard limit (backend choice)
162         self.versioned = False
163         self.can_init = True
164         self.connected = False
165
166     def __str__(self):
167         return '<%s %s %s>' % (self.__class__.__name__, id(self), self.repo)
168
169     def __repr__(self):
170         return str(self)
171
172     def version(self):
173         """Return a version string for this backend."""
174         return libbe.version.version()
175
176     def storage_version(self, revision=None):
177         """Return the storage format for this backend."""
178         return libbe.storage.STORAGE_VERSION
179
180     def is_readable(self):
181         return self.readable and self._readable
182
183     def is_writeable(self):
184         return self.writeable and self._writeable
185
186     def init(self):
187         """Create a new storage repository."""
188         if self.can_init == False:
189             raise NotSupported('init',
190                                'Cannot initialize this repository format.')
191         if self.is_writeable() == False:
192             raise NotWriteable('Cannot initialize unwriteable storage.')
193         return self._init()
194
195     def _init(self):
196         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
197         root = Entry(id='__ROOT__', directory=True)
198         d = {root.id:root}
199         pickle.dump(dict((k,v._objects_to_ids()) for k,v in d.items()), f, -1)
200         f.close()
201
202     def destroy(self):
203         """Remove the storage repository."""
204         if self.is_writeable() == False:
205             raise NotWriteable('Cannot destroy unwriteable storage.')
206         return self._destroy()
207
208     def _destroy(self):
209         os.remove(os.path.join(self.repo, 'repo.pkl'))
210
211     def connect(self):
212         """Open a connection to the repository."""
213         if self.is_readable() == False:
214             raise NotReadable('Cannot connect to unreadable storage.')
215         self._connect()
216         self.connected = True
217
218     def _connect(self):
219         try:
220             f = open(os.path.join(self.repo, 'repo.pkl'), 'rb')
221         except IOError:
222             raise ConnectionError(self)
223         d = pickle.load(f)
224         self._data = dict((k,v._ids_to_objects(d)) for k,v in d.items())
225         f.close()
226
227     def disconnect(self):
228         """Close the connection to the repository."""
229         if self.is_writeable() == False:
230             return
231         if self.connected == False:
232             return
233         self._disconnect()
234         self.connected = False
235
236     def _disconnect(self):
237         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
238         pickle.dump(dict((k,v._objects_to_ids())
239                          for k,v in self._data.items()), f, -1)
240         f.close()
241         self._data = None
242
243     def add(self, id, *args, **kwargs):
244         """Add an entry"""
245         if self.is_writeable() == False:
246             raise NotWriteable('Cannot add entry to unwriteable storage.')
247         if not self.exists(id):
248             self._add(id, *args, **kwargs)
249
250     def _add(self, id, parent=None, directory=False):
251         if parent == None:
252             parent = '__ROOT__'
253         p = self._data[parent]
254         self._data[id] = Entry(id, parent=p, directory=directory)
255
256     def exists(self, *args, **kwargs):
257         """Check an entry's existence"""
258         if self.is_readable() == False:
259             raise NotReadable('Cannot check entry existence in unreadable storage.')
260         return self._exists(*args, **kwargs)
261
262     def _exists(self, id, revision=None):
263         return id in self._data
264
265     def remove(self, *args, **kwargs):
266         """Remove an entry."""
267         if self.is_writeable() == False:
268             raise NotSupported('write',
269                                'Cannot remove entry from unwriteable storage.')
270         self._remove(*args, **kwargs)
271
272     def _remove(self, id):
273         if self._data[id].directory == True \
274                 and len(self.children(id)) > 0:
275             raise DirectoryNotEmpty(id)
276         e = self._data.pop(id)
277         e.parent.remove(e)
278
279     def recursive_remove(self, *args, **kwargs):
280         """Remove an entry and all its decendents."""
281         if self.is_writeable() == False:
282             raise NotSupported('write',
283                                'Cannot remove entries from unwriteable storage.')
284         self._recursive_remove(*args, **kwargs)
285
286     def _recursive_remove(self, id):
287         for entry in reversed(list(self._data[id].traverse())):
288             self._remove(entry.id)
289
290     def ancestors(self, *args, **kwargs):
291         """Return a list of the specified entry's ancestors' ids."""
292         if self.is_readable() == False:
293             raise NotReadable('Cannot list parents with unreadable storage.')
294         return self._ancestors(*args, **kwargs)
295
296     def _ancestors(self, id=None, revision=None):
297         if id == None:
298             return []
299         ancestors = []
300         stack = [id]
301         while len(stack) > 0:
302             id = stack.pop(0)
303             parent = self._data[id].parent
304             if parent != None and not parent.id.startswith('__'):
305                 ancestor = parent.id
306                 ancestors.append(ancestor)
307                 stack.append(ancestor)
308         return ancestors
309
310     def children(self, *args, **kwargs):
311         """Return a list of specified entry's children's ids."""
312         if self.is_readable() == False:
313             raise NotReadable('Cannot list children with unreadable storage.')
314         return self._children(*args, **kwargs)
315
316     def _children(self, id=None, revision=None):
317         if id == None:
318             id = '__ROOT__'
319         return [c.id for c in self._data[id] if not c.id.startswith('__')]
320
321     def get(self, *args, **kwargs):
322         """
323         Get contents of and entry as they were in a given revision.
324         revision==None specifies the current revision.
325
326         If there is no id, return default, unless default is not
327         given, in which case raise InvalidID.
328         """
329         if self.is_readable() == False:
330             raise NotReadable('Cannot get entry with unreadable storage.')
331         if 'decode' in kwargs:
332             decode = kwargs.pop('decode')
333         else:
334             decode = False
335         value = self._get(*args, **kwargs)
336         if value != None:
337             if decode == True and type(value) != types.UnicodeType:
338                 return unicode(value, self.encoding)
339             elif decode == False and type(value) != types.StringType:
340                 return value.encode(self.encoding)
341         return value
342
343     def _get(self, id, default=InvalidObject, revision=None):
344         if id in self._data and self._data[id].value != _EMPTY:
345             return self._data[id].value
346         elif default == InvalidObject:
347             raise InvalidID(id)
348         return default
349
350     def set(self, id, value, *args, **kwargs):
351         """
352         Set the entry contents.
353         """
354         if self.is_writeable() == False:
355             raise NotWriteable('Cannot set entry in unwriteable storage.')
356         if type(value) == types.UnicodeType:
357             value = value.encode(self.encoding)
358         self._set(id, value, *args, **kwargs)
359
360     def _set(self, id, value):
361         if id not in self._data:
362             raise InvalidID(id)
363         if self._data[id].directory == True:
364             raise InvalidDirectory(
365                 'Directory %s cannot have data' % self.parent)
366         self._data[id].value = value
367
368 class VersionedStorage (Storage):
369     """
370     This class declares all the methods required by a Storage
371     interface that supports versioning.  This implementation just
372     keeps the data in a list and uses pickle for persistent
373     storage.
374     """
375     name = 'VersionedStorage'
376
377     def __init__(self, *args, **kwargs):
378         Storage.__init__(self, *args, **kwargs)
379         self.versioned = True
380
381     def _init(self):
382         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
383         root = Entry(id='__ROOT__', directory=True)
384         summary = Entry(id='__COMMIT__SUMMARY__', value='Initial commit')
385         body = Entry(id='__COMMIT__BODY__')
386         initial_commit = {root.id:root, summary.id:summary, body.id:body}
387         d = dict((k,v._objects_to_ids()) for k,v in initial_commit.items())
388         pickle.dump([d, copy.deepcopy(d)], f, -1) # [inital tree, working tree]
389         f.close()
390
391     def _connect(self):
392         try:
393             f = open(os.path.join(self.repo, 'repo.pkl'), 'rb')
394         except IOError:
395             raise ConnectionError(self)
396         d = pickle.load(f)
397         self._data = [dict((k,v._ids_to_objects(t)) for k,v in t.items())
398                       for t in d]
399         f.close()
400
401     def _disconnect(self):
402         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
403         pickle.dump([dict((k,v._objects_to_ids())
404                           for k,v in t.items()) for t in self._data], f, -1)
405         f.close()
406         self._data = None
407
408     def _add(self, id, parent=None, directory=False):
409         if parent == None:
410             parent = '__ROOT__'
411         p = self._data[-1][parent]
412         self._data[-1][id] = Entry(id, parent=p, directory=directory)
413
414     def _exists(self, id, revision=None):
415         if revision == None:
416             revision = -1
417         else:
418             revision = int(revision)
419         return id in self._data[revision]
420
421     def _remove(self, id):
422         if self._data[-1][id].directory == True \
423                 and len(self.children(id)) > 0:
424             raise DirectoryNotEmpty(id)
425         e = self._data[-1].pop(id)
426         e.parent.remove(e)
427
428     def _recursive_remove(self, id):
429         for entry in reversed(list(self._data[-1][id].traverse())):
430             self._remove(entry.id)
431
432     def _ancestors(self, id=None, revision=None):
433         if id == None:
434             return []
435         if revision == None:
436             revision = -1
437         else:
438             revision = int(revision)
439         ancestors = []
440         stack = [id]
441         while len(stack) > 0:
442             id = stack.pop(0)
443             parent = self._data[revision][id].parent
444             if parent != None and not parent.id.startswith('__'):
445                 ancestor = parent.id
446                 ancestors.append(ancestor)
447                 stack.append(ancestor)
448         return ancestors
449
450     def _children(self, id=None, revision=None):
451         if id == None:
452             id = '__ROOT__'
453         if revision == None:
454             revision = -1
455         else:
456             revision = int(revision)
457         return [c.id for c in self._data[revision][id]
458                 if not c.id.startswith('__')]
459
460     def _get(self, id, default=InvalidObject, revision=None):
461         if revision == None:
462             revision = -1
463         else:
464             revision = int(revision)
465         if id in self._data[revision] \
466                 and self._data[revision][id].value != _EMPTY:
467             return self._data[revision][id].value
468         elif default == InvalidObject:
469             raise InvalidID(id)
470         return default
471
472     def _set(self, id, value):
473         if id not in self._data[-1]:
474             raise InvalidID(id)
475         self._data[-1][id].value = value
476
477     def commit(self, *args, **kwargs):
478         """
479         Commit the current repository, with a commit message string
480         summary and body.  Return the name of the new revision.
481
482         If allow_empty == False (the default), raise EmptyCommit if
483         there are no changes to commit.
484         """
485         if self.is_writeable() == False:
486             raise NotWriteable('Cannot commit to unwriteable storage.')
487         return self._commit(*args, **kwargs)
488
489     def _commit(self, summary, body=None, allow_empty=False):
490         if self._data[-1] == self._data[-2] and allow_empty == False:
491             raise EmptyCommit
492         self._data[-1]["__COMMIT__SUMMARY__"].value = summary
493         self._data[-1]["__COMMIT__BODY__"].value = body
494         rev = str(len(self._data)-1)
495         self._data.append(copy.deepcopy(self._data[-1]))
496         return rev
497
498     def revision_id(self, index=None):
499         """
500         Return the name of the <index>th revision.  The choice of
501         which branch to follow when crossing branches/merges is not
502         defined.  Revision indices start at 1; ID 0 is the blank
503         repository.
504
505         Return None if index==None.
506
507         If the specified revision does not exist, raise InvalidRevision.
508         """
509         if index == None:
510             return None
511         try:
512             if int(index) != index:
513                 raise InvalidRevision(index)
514         except ValueError:
515             raise InvalidRevision(index)
516         L = len(self._data) - 1  # -1 b/c of initial commit
517         if index >= -L and index <= L:
518             return str(index % L)
519         raise InvalidRevision(i)
520
521     def changed(self, revision):
522         """
523         Return a tuple of lists of ids
524           (new, modified, removed)
525         from the specified revision to the current situation.
526         """
527         new = []
528         modified = []
529         removed = []
530         for id,value in self._data[int(revision)].items():
531             if id.startswith('__'):
532                 continue
533             if not id in self._data[-1]:
534                 removed.append(id)
535             elif value.value != self._data[-1][id].value:
536                 modified.append(id)
537         for id in self._data[-1]:
538             if not id in self._data[int(revision)]:
539                 new.append(id)
540         return (new, modified, removed)
541
542
543 if TESTING == True:
544     class StorageTestCase (unittest.TestCase):
545         """Test cases for Storage class."""
546
547         Class = Storage
548
549         def __init__(self, *args, **kwargs):
550             super(StorageTestCase, self).__init__(*args, **kwargs)
551             self.dirname = None
552
553         # this class will be the basis of tests for several classes,
554         # so make sure we print the name of the class we're dealing with.
555         def _classname(self):
556             version = '?'
557             try:
558                 if hasattr(self, 's'):
559                     version = self.s.version()
560             except:
561                 pass
562             return '%s:%s' % (self.Class.__name__, version)
563
564         def fail(self, msg=None):
565             """Fail immediately, with the given message."""
566             raise self.failureException, \
567                 '(%s) %s' % (self._classname(), msg)
568
569         def failIf(self, expr, msg=None):
570             "Fail the test if the expression is true."
571             if expr: raise self.failureException, \
572                 '(%s) %s' % (self.classname(), msg)
573
574         def failUnless(self, expr, msg=None):
575             """Fail the test unless the expression is true."""
576             if not expr: raise self.failureException, \
577                 '(%s) %s' % (self.classname(), msg)
578
579         def setUp(self):
580             """Set up test fixtures for Storage test case."""
581             super(StorageTestCase, self).setUp()
582             self.dir = Dir()
583             self.dirname = self.dir.path
584             self.s = self.Class(repo=self.dirname)
585             self.assert_failed_connect()
586             self.s.init()
587             self.s.connect()
588
589         def tearDown(self):
590             super(StorageTestCase, self).tearDown()
591             self.s.disconnect()
592             self.s.destroy()
593             self.assert_failed_connect()
594             self.dir.cleanup()
595
596         def assert_failed_connect(self):
597             try:
598                 self.s.connect()
599                 self.fail(
600                     "Connected to %(name)s repository before initialising"
601                     % vars(self.Class))
602             except ConnectionError:
603                 pass
604
605     class Storage_init_TestCase (StorageTestCase):
606         """Test cases for Storage.init method."""
607
608         def test_connect_should_succeed_after_init(self):
609             """Should connect after initialization."""
610             self.s.connect()
611
612     class Storage_connect_disconnect_TestCase (StorageTestCase):
613         """Test cases for Storage.connect and .disconnect methods."""
614
615         def test_multiple_disconnects(self):
616             """Should be able to call .disconnect multiple times."""
617             self.s.disconnect()
618             self.s.disconnect()
619
620     class Storage_add_remove_TestCase (StorageTestCase):
621         """Test cases for Storage.add, .remove, and .recursive_remove methods."""
622
623         def test_initially_empty(self):
624             """New repository should be empty."""
625             self.failUnless(len(self.s.children()) == 0, self.s.children())
626
627         def test_add_identical_rooted(self):
628             """Adding entries with the same ID should not increase the number of children.
629             """
630             for i in range(10):
631                 self.s.add('some id', directory=False)
632                 s = sorted(self.s.children())
633                 self.failUnless(s == ['some id'], s)
634
635         def test_add_rooted(self):
636             """Adding entries should increase the number of children (rooted).
637             """
638             ids = []
639             for i in range(10):
640                 ids.append(str(i))
641                 self.s.add(ids[-1], directory=(i % 2 == 0))
642                 s = sorted(self.s.children())
643                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
644
645         def test_add_nonrooted(self):
646             """Adding entries should increase the number of children (nonrooted).
647             """
648             self.s.add('parent', directory=True)
649             ids = []
650             for i in range(10):
651                 ids.append(str(i))
652                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
653                 s = sorted(self.s.children('parent'))
654                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
655                 s = self.s.children()
656                 self.failUnless(s == ['parent'], s)
657
658         def test_ancestors(self):
659             """Check ancestors lists.
660             """
661             self.s.add('parent', directory=True)
662             for i in range(10):
663                 i_id = str(i)
664                 self.s.add(i_id, 'parent', directory=True)
665                 for j in range(10): # add some grandkids
666                     j_id = str(20*(i+1)+j)
667                     self.s.add(j_id, i_id, directory=(i%2 == 0))
668                     ancestors = sorted(self.s.ancestors(j_id))
669                     self.failUnless(ancestors == [i_id, 'parent'],
670                         'Unexpected ancestors for %s/%s, "%s"'
671                         % (i_id, j_id, ancestors))
672
673         def test_children(self):
674             """Non-UUID ids should be returned as such.
675             """
676             self.s.add('parent', directory=True)
677             ids = []
678             for i in range(10):
679                 ids.append('parent/%s' % str(i))
680                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
681                 s = sorted(self.s.children('parent'))
682                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
683
684         def test_add_invalid_directory(self):
685             """Should not be able to add children to non-directories.
686             """
687             self.s.add('parent', directory=False)
688             try:
689                 self.s.add('child', 'parent', directory=False)
690                 self.fail(
691                     '%s.add() succeeded instead of raising InvalidDirectory'
692                     % (vars(self.Class)['name']))
693             except InvalidDirectory:
694                 pass
695             try:
696                 self.s.add('child', 'parent', directory=True)
697                 self.fail(
698                     '%s.add() succeeded instead of raising InvalidDirectory'
699                     % (vars(self.Class)['name']))
700             except InvalidDirectory:
701                 pass
702             self.failUnless(len(self.s.children('parent')) == 0,
703                             self.s.children('parent'))
704
705         def test_remove_rooted(self):
706             """Removing entries should decrease the number of children (rooted).
707             """
708             ids = []
709             for i in range(10):
710                 ids.append(str(i))
711                 self.s.add(ids[-1], directory=(i % 2 == 0))
712             for i in range(10):
713                 self.s.remove(ids.pop())
714                 s = sorted(self.s.children())
715                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
716
717         def test_remove_nonrooted(self):
718             """Removing entries should decrease the number of children (nonrooted).
719             """
720             self.s.add('parent', directory=True)
721             ids = []
722             for i in range(10):
723                 ids.append(str(i))
724                 self.s.add(ids[-1], 'parent', directory=False)#(i % 2 == 0))
725             for i in range(10):
726                 self.s.remove(ids.pop())
727                 s = sorted(self.s.children('parent'))
728                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
729                 if len(s) > 0:
730                     s = self.s.children()
731                     self.failUnless(s == ['parent'], s)
732
733         def test_remove_directory_not_empty(self):
734             """Removing a non-empty directory entry should raise exception.
735             """
736             self.s.add('parent', directory=True)
737             ids = []
738             for i in range(10):
739                 ids.append(str(i))
740                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
741             self.s.remove(ids.pop()) # empty directory removal succeeds
742             try:
743                 self.s.remove('parent') # empty directory removal succeeds
744                 self.fail(
745                     "%s.remove() didn't raise DirectoryNotEmpty"
746                     % (vars(self.Class)['name']))
747             except DirectoryNotEmpty:
748                 pass
749
750         def test_recursive_remove(self):
751             """Recursive remove should empty the tree."""
752             self.s.add('parent', directory=True)
753             ids = []
754             for i in range(10):
755                 ids.append(str(i))
756                 self.s.add(ids[-1], 'parent', directory=True)
757                 for j in range(10): # add some grandkids
758                     self.s.add(str(20*(i+1)+j), ids[-1], directory=(i%2 == 0))
759             self.s.recursive_remove('parent')
760             s = sorted(self.s.children())
761             self.failUnless(s == [], s)
762
763     class Storage_get_set_TestCase (StorageTestCase):
764         """Test cases for Storage.get and .set methods."""
765
766         id = 'unlikely id'
767         val = 'unlikely value'
768
769         def test_get_default(self):
770             """Get should return specified default if id not in Storage.
771             """
772             ret = self.s.get(self.id, default=self.val)
773             self.failUnless(ret == self.val,
774                     "%s.get() returned %s not %s"
775                     % (vars(self.Class)['name'], ret, self.val))
776
777         def test_get_default_exception(self):
778             """Get should raise exception if id not in Storage and no default.
779             """
780             try:
781                 ret = self.s.get(self.id)
782                 self.fail(
783                     "%s.get() returned %s instead of raising InvalidID"
784                     % (vars(self.Class)['name'], ret))
785             except InvalidID:
786                 pass
787
788         def test_get_initial_value(self):
789             """Data value should be default before any value has been set.
790             """
791             self.s.add(self.id, directory=False)
792             val = 'UNLIKELY DEFAULT'
793             ret = self.s.get(self.id, default=val)
794             self.failUnless(ret == val,
795                     "%s.get() returned %s not %s"
796                     % (vars(self.Class)['name'], ret, val))
797
798         def test_set_exception(self):
799             """Set should raise exception if id not in Storage.
800             """
801             try:
802                 self.s.set(self.id, self.val)
803                 self.fail(
804                     "%(name)s.set() did not raise InvalidID"
805                     % vars(self.Class))
806             except InvalidID:
807                 pass
808
809         def test_set(self):
810             """Set should define the value returned by get.
811             """
812             self.s.add(self.id, directory=False)
813             self.s.set(self.id, self.val)
814             ret = self.s.get(self.id)
815             self.failUnless(ret == self.val,
816                     "%s.get() returned %s not %s"
817                     % (vars(self.Class)['name'], ret, self.val))
818
819         def test_unicode_set(self):
820             """Set should define the value returned by get.
821             """
822             val = u'Fran\xe7ois'
823             self.s.add(self.id, directory=False)
824             self.s.set(self.id, val)
825             ret = self.s.get(self.id, decode=True)
826             self.failUnless(type(ret) == types.UnicodeType,
827                     "%s.get() returned %s not UnicodeType"
828                     % (vars(self.Class)['name'], type(ret)))
829             self.failUnless(ret == val,
830                     "%s.get() returned %s not %s"
831                     % (vars(self.Class)['name'], ret, self.val))
832             ret = self.s.get(self.id)
833             self.failUnless(type(ret) == types.StringType,
834                     "%s.get() returned %s not StringType"
835                     % (vars(self.Class)['name'], type(ret)))
836             s = unicode(ret, self.s.encoding)
837             self.failUnless(s == val,
838                     "%s.get() returned %s not %s"
839                     % (vars(self.Class)['name'], s, self.val))
840
841
842     class Storage_persistence_TestCase (StorageTestCase):
843         """Test cases for Storage.disconnect and .connect methods."""
844
845         id = 'unlikely id'
846         val = 'unlikely value'
847
848         def test_get_set_persistence(self):
849             """Set should define the value returned by get after reconnect.
850             """
851             self.s.add(self.id, directory=False)
852             self.s.set(self.id, self.val)
853             self.s.disconnect()
854             self.s.connect()
855             ret = self.s.get(self.id)
856             self.failUnless(ret == self.val,
857                     "%s.get() returned %s not %s"
858                     % (vars(self.Class)['name'], ret, self.val))
859
860         def test_empty_get_set_persistence(self):
861             """After empty set, get may return either an empty string or default.
862             """
863             self.s.add(self.id, directory=False)
864             self.s.set(self.id, '')
865             self.s.disconnect()
866             self.s.connect()
867             default = 'UNLIKELY DEFAULT'
868             ret = self.s.get(self.id, default=default)
869             self.failUnless(ret in ['', default],
870                     "%s.get() returned %s not in %s"
871                     % (vars(self.Class)['name'], ret, ['', default]))
872
873         def test_add_nonrooted_persistence(self):
874             """Adding entries should increase the number of children after reconnect.
875             """
876             self.s.add('parent', directory=True)
877             ids = []
878             for i in range(10):
879                 ids.append(str(i))
880                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
881             self.s.disconnect()
882             self.s.connect()
883             s = sorted(self.s.children('parent'))
884             self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
885             s = self.s.children()
886             self.failUnless(s == ['parent'], s)
887
888     class VersionedStorageTestCase (StorageTestCase):
889         """Test cases for VersionedStorage methods."""
890
891         Class = VersionedStorage
892
893     class VersionedStorage_commit_TestCase (VersionedStorageTestCase):
894         """Test cases for VersionedStorage.commit and revision_ids methods."""
895
896         id = 'unlikely id'
897         val = 'Some value'
898         commit_msg = 'Committing something interesting'
899         commit_body = 'Some\nlonger\ndescription\n'
900
901         def _setup_for_empty_commit(self):
902             """
903             Initialization might add some files to version control, so
904             commit those first, before testing the empty commit
905             functionality.
906             """
907             try:
908                 self.s.commit('Added initialization files')
909             except EmptyCommit:
910                 pass
911                 
912         def test_revision_id_exception(self):
913             """Invalid revision id should raise InvalidRevision.
914             """
915             try:
916                 rev = self.s.revision_id('highly unlikely revision id')
917                 self.fail(
918                     "%s.revision_id() didn't raise InvalidRevision, returned %s."
919                     % (vars(self.Class)['name'], rev))
920             except InvalidRevision:
921                 pass
922
923         def test_empty_commit_raises_exception(self):
924             """Empty commit should raise exception.
925             """
926             self._setup_for_empty_commit()
927             try:
928                 self.s.commit(self.commit_msg, self.commit_body)
929                 self.fail(
930                     "Empty %(name)s.commit() didn't raise EmptyCommit."
931                     % vars(self.Class))
932             except EmptyCommit:
933                 pass
934
935         def test_empty_commit_allowed(self):
936             """Empty commit should _not_ raise exception if allow_empty=True.
937             """
938             self._setup_for_empty_commit()
939             self.s.commit(self.commit_msg, self.commit_body,
940                           allow_empty=True)
941
942         def test_commit_revision_ids(self):
943             """Commit / revision_id should agree on revision ids.
944             """
945             def val(i):
946                 return '%s:%d' % (self.val, i+1)
947             self.s.add(self.id, directory=False)
948             revs = []
949             for i in range(10):
950                 self.s.set(self.id, val(i))
951                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
952                                           self.commit_body))
953             for i in range(10):
954                 rev = self.s.revision_id(i+1)
955                 self.failUnless(rev == revs[i],
956                                 "%s.revision_id(%d) returned %s not %s"
957                                 % (vars(self.Class)['name'], i+1, rev, revs[i]))
958             for i in range(-1, -9, -1):
959                 rev = self.s.revision_id(i)
960                 self.failUnless(rev == revs[i],
961                                 "%s.revision_id(%d) returned %s not %s"
962                                 % (vars(self.Class)['name'], i, rev, revs[i]))
963
964         def test_get_previous_version(self):
965             """Get should be able to return the previous version.
966             """
967             def val(i):
968                 return '%s:%d' % (self.val, i+1)
969             self.s.add(self.id, directory=False)
970             revs = []
971             for i in range(10):
972                 self.s.set(self.id, val(i))
973                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
974                                           self.commit_body))
975             for i in range(10):
976                 ret = self.s.get(self.id, revision=revs[i])
977                 self.failUnless(ret == val(i),
978                                 "%s.get() returned %s not %s for revision %s"
979                                 % (vars(self.Class)['name'], ret, val(i), revs[i]))
980
981         def test_get_previous_children(self):
982             """Children list should be revision dependent.
983             """
984             self.s.add('parent', directory=True)
985             revs = []
986             cur_children = []
987             children = []
988             for i in range(10):
989                 new_child = str(i)
990                 self.s.add(new_child, 'parent')
991                 self.s.set(new_child, self.val)
992                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
993                                           self.commit_body))
994                 cur_children.append(new_child)
995                 children.append(list(cur_children))
996             for i in range(10):
997                 ret = self.s.children('parent', revision=revs[i])
998                 self.failUnless(ret == children[i],
999                                 "%s.get() returned %s not %s for revision %s"
1000                                 % (vars(self.Class)['name'], ret,
1001                                    children[i], revs[i]))
1002
1003     class VersionedStorage_changed_TestCase (VersionedStorageTestCase):
1004         """Test cases for VersionedStorage.changed() method."""
1005
1006         def test_changed(self):
1007             """Changed lists should reflect past activity"""
1008             self.s.add('dir', directory=True)
1009             self.s.add('modified', parent='dir')
1010             self.s.set('modified', 'some value to be modified')
1011             self.s.add('moved', parent='dir')
1012             self.s.set('moved', 'this entry will be moved')
1013             self.s.add('removed', parent='dir')
1014             self.s.set('removed', 'this entry will be deleted')
1015             revA = self.s.commit('Initial state')
1016             self.s.add('new', parent='dir')
1017             self.s.set('new', 'this entry is new')
1018             self.s.set('modified', 'a new value')
1019             self.s.remove('moved')
1020             self.s.add('moved2', parent='dir')
1021             self.s.set('moved2', 'this entry will be moved')
1022             self.s.remove('removed')
1023             revB = self.s.commit('Final state')
1024             new,mod,rem = self.s.changed(revA)
1025             self.failUnless(sorted(new) == ['moved2', 'new'],
1026                             'Unexpected new: %s' % new)
1027             self.failUnless(mod == ['modified'],
1028                             'Unexpected modified: %s' % mod)
1029             self.failUnless(sorted(rem) == ['moved', 'removed'],
1030                             'Unexpected removed: %s' % rem)
1031
1032     def make_storage_testcase_subclasses(storage_class, namespace):
1033         """Make StorageTestCase subclasses for storage_class in namespace."""
1034         storage_testcase_classes = [
1035             c for c in (
1036                 ob for ob in globals().values() if isinstance(ob, type))
1037             if issubclass(c, StorageTestCase) \
1038                 and c.Class == Storage]
1039
1040         for base_class in storage_testcase_classes:
1041             testcase_class_name = storage_class.__name__ + base_class.__name__
1042             testcase_class_bases = (base_class,)
1043             testcase_class_dict = dict(base_class.__dict__)
1044             testcase_class_dict['Class'] = storage_class
1045             testcase_class = type(
1046                 testcase_class_name, testcase_class_bases, testcase_class_dict)
1047             setattr(namespace, testcase_class_name, testcase_class)
1048
1049     def make_versioned_storage_testcase_subclasses(storage_class, namespace):
1050         """Make VersionedStorageTestCase subclasses for storage_class in namespace."""
1051         storage_testcase_classes = [
1052             c for c in (
1053                 ob for ob in globals().values() if isinstance(ob, type))
1054             if ((issubclass(c, StorageTestCase) \
1055                      and c.Class == Storage)
1056                 or
1057                 (issubclass(c, VersionedStorageTestCase) \
1058                      and c.Class == VersionedStorage))]
1059
1060         for base_class in storage_testcase_classes:
1061             testcase_class_name = storage_class.__name__ + base_class.__name__
1062             testcase_class_bases = (base_class,)
1063             testcase_class_dict = dict(base_class.__dict__)
1064             testcase_class_dict['Class'] = storage_class
1065             testcase_class = type(
1066                 testcase_class_name, testcase_class_bases, testcase_class_dict)
1067             setattr(namespace, testcase_class_name, testcase_class)
1068
1069     make_storage_testcase_subclasses(VersionedStorage, sys.modules[__name__])
1070
1071     unitsuite =unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
1072     suite = unittest.TestSuite([unitsuite, doctest.DocTestSuite()])