Merged Trevor's tree
[be.git] / libbe / storage / base.py
1 # Copyright (C) 2009-2010 W. Trevor King <wking@drexel.edu>
2 #
3 # This program is free software; you can redistribute it and/or modify
4 # it under the terms of the GNU General Public License as published by
5 # the Free Software Foundation; either version 2 of the License, or
6 # (at your option) any later version.
7 #
8 # This program is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 # GNU General Public License for more details.
12 #
13 # You should have received a copy of the GNU General Public License along
14 # with this program; if not, write to the Free Software Foundation, Inc.,
15 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
16
17 """
18 Abstract bug repository data storage to easily support multiple backends.
19 """
20
21 import copy
22 import os
23 import pickle
24 import types
25
26 from libbe.error import NotSupported
27 import libbe.storage
28 from libbe.util.tree import Tree
29 from libbe.util import InvalidObject
30 import libbe.version
31 from libbe import TESTING
32
33 if TESTING == True:
34     import doctest
35     import os.path
36     import sys
37     import unittest
38
39     from libbe.util.utility import Dir
40
41 class ConnectionError (Exception):
42     pass
43
44 class InvalidStorageVersion(ConnectionError):
45     def __init__(self, active_version, expected_version=None):
46         if expected_version == None:
47             expected_version = libbe.storage.STORAGE_VERSION
48         msg = 'Storage in "%s" not the expected "%s"' \
49             % (active_version, expected_version)
50         Exception.__init__(self, msg)
51         self.active_version = active_version
52         self.expected_version = expected_version
53
54 class InvalidID (KeyError):
55     def __init__(self, id=None, revision=None, msg=None):
56         KeyError.__init__(self, id)
57         self.msg = msg
58         self.id = id
59         self.revision = revision
60     def __str__(self):
61         if self.msg == None:
62             return '%s in revision %s' % (self.id, self.revision)
63         return self.msg
64
65
66 class InvalidRevision (KeyError):
67     pass
68
69 class InvalidDirectory (Exception):
70     pass
71
72 class DirectoryNotEmpty (InvalidDirectory):
73     pass
74
75 class NotWriteable (NotSupported):
76     def __init__(self, msg):
77         NotSupported.__init__(self, 'write', msg)
78
79 class NotReadable (NotSupported):
80     def __init__(self, msg):
81         NotSupported.__init__(self, 'read', msg)
82
83 class EmptyCommit(Exception):
84     def __init__(self):
85         Exception.__init__(self, 'No changes to commit')
86
87 class _EMPTY (object):
88     """Entry has been added but has no user-set value."""
89     pass
90
91 class Entry (Tree):
92     def __init__(self, id, value=_EMPTY, parent=None, directory=False,
93                  children=None):
94         if children == None:
95             Tree.__init__(self)
96         else:
97             Tree.__init__(self, children)
98         self.id = id
99         self.value = value
100         self.parent = parent
101         if self.parent != None:
102             if self.parent.directory == False:
103                 raise InvalidDirectory(
104                     'Non-directory %s cannot have children' % self.parent)
105             parent.append(self)
106         self.directory = directory
107
108     def __str__(self):
109         return '<Entry %s: %s>' % (self.id, self.value)
110
111     def __repr__(self):
112         return str(self)
113
114     def __cmp__(self, other, local=False):
115         if other == None:
116             return cmp(1, None)
117         if cmp(self.id, other.id) != 0:
118             return cmp(self.id, other.id)
119         if cmp(self.value, other.value) != 0:
120             return cmp(self.value, other.value)
121         if local == False:
122             if self.parent == None:
123                 if cmp(self.parent, other.parent) != 0:
124                     return cmp(self.parent, other.parent)
125             elif self.parent.__cmp__(other.parent, local=True) != 0:
126                 return self.parent.__cmp__(other.parent, local=True)
127             for sc,oc in zip(self, other):
128                 if sc.__cmp__(oc, local=True) != 0:
129                     return sc.__cmp__(oc, local=True)
130         return 0
131
132     def _objects_to_ids(self):
133         if self.parent != None:
134             self.parent = self.parent.id
135         for i,c in enumerate(self):
136             self[i] = c.id
137         return self
138
139     def _ids_to_objects(self, dict):
140         if self.parent != None:
141             self.parent = dict[self.parent]
142         for i,c in enumerate(self):
143             self[i] = dict[c]
144         return self
145
146 class Storage (object):
147     """
148     This class declares all the methods required by a Storage
149     interface.  This implementation just keeps the data in a
150     dictionary and uses pickle for persistent storage.
151     """
152     name = 'Storage'
153
154     def __init__(self, repo='/', encoding='utf-8', options=None):
155         self.repo = repo
156         self.encoding = encoding
157         self.options = options
158         self.readable = True  # soft limit (user choice)
159         self._readable = True # hard limit (backend choice)
160         self.writeable = True  # soft limit (user choice)
161         self._writeable = True # hard limit (backend choice)
162         self.versioned = False
163         self.can_init = True
164         self.connected = False
165
166     def __str__(self):
167         return '<%s %s %s>' % (self.__class__.__name__, id(self), self.repo)
168
169     def __repr__(self):
170         return str(self)
171
172     def version(self):
173         """Return a version string for this backend."""
174         return libbe.version.version()
175
176     def storage_version(self, revision=None):
177         """Return the storage format for this backend."""
178         return libbe.storage.STORAGE_VERSION
179
180     def is_readable(self):
181         return self.readable and self._readable
182
183     def is_writeable(self):
184         return self.writeable and self._writeable
185
186     def init(self):
187         """Create a new storage repository."""
188         if self.can_init == False:
189             raise NotSupported('init',
190                                'Cannot initialize this repository format.')
191         if self.is_writeable() == False:
192             raise NotWriteable('Cannot initialize unwriteable storage.')
193         return self._init()
194
195     def _init(self):
196         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
197         root = Entry(id='__ROOT__', directory=True)
198         d = {root.id:root}
199         pickle.dump(dict((k,v._objects_to_ids()) for k,v in d.items()), f, -1)
200         f.close()
201
202     def destroy(self):
203         """Remove the storage repository."""
204         if self.is_writeable() == False:
205             raise NotWriteable('Cannot destroy unwriteable storage.')
206         return self._destroy()
207
208     def _destroy(self):
209         os.remove(os.path.join(self.repo, 'repo.pkl'))
210
211     def connect(self):
212         """Open a connection to the repository."""
213         if self.is_readable() == False:
214             raise NotReadable('Cannot connect to unreadable storage.')
215         self._connect()
216         self.connected = True
217
218     def _connect(self):
219         try:
220             f = open(os.path.join(self.repo, 'repo.pkl'), 'rb')
221         except IOError:
222             raise ConnectionError(self)
223         d = pickle.load(f)
224         self._data = dict((k,v._ids_to_objects(d)) for k,v in d.items())
225         f.close()
226
227     def disconnect(self):
228         """Close the connection to the repository."""
229         if self.is_writeable() == False:
230             return
231         if self.connected == False:
232             return
233         self._disconnect()
234         self.connected = False
235
236     def _disconnect(self):
237         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
238         pickle.dump(dict((k,v._objects_to_ids())
239                          for k,v in self._data.items()), f, -1)
240         f.close()
241         self._data = None
242
243     def add(self, id, *args, **kwargs):
244         """Add an entry"""
245         if self.is_writeable() == False:
246             raise NotWriteable('Cannot add entry to unwriteable storage.')
247         if not self.exists(id):
248             self._add(id, *args, **kwargs)
249
250     def _add(self, id, parent=None, directory=False):
251         if parent == None:
252             parent = '__ROOT__'
253         p = self._data[parent]
254         self._data[id] = Entry(id, parent=p, directory=directory)
255
256     def exists(self, *args, **kwargs):
257         """Check an entry's existence"""
258         if self.is_readable() == False:
259             raise NotReadable('Cannot check entry existence in unreadable storage.')
260         return self._exists(*args, **kwargs)
261
262     def _exists(self, id, revision=None):
263         return id in self._data
264
265     def remove(self, *args, **kwargs):
266         """Remove an entry."""
267         if self.is_writeable() == False:
268             raise NotSupported('write',
269                                'Cannot remove entry from unwriteable storage.')
270         self._remove(*args, **kwargs)
271
272     def _remove(self, id):
273         if self._data[id].directory == True \
274                 and len(self.children(id)) > 0:
275             raise DirectoryNotEmpty(id)
276         e = self._data.pop(id)
277         e.parent.remove(e)
278
279     def recursive_remove(self, *args, **kwargs):
280         """Remove an entry and all its decendents."""
281         if self.is_writeable() == False:
282             raise NotSupported('write',
283                                'Cannot remove entries from unwriteable storage.')
284         self._recursive_remove(*args, **kwargs)
285
286     def _recursive_remove(self, id):
287         for entry in reversed(list(self._data[id].traverse())):
288             self._remove(entry.id)
289
290     def ancestors(self, *args, **kwargs):
291         """Return a list of the specified entry's ancestors' ids."""
292         if self.is_readable() == False:
293             raise NotReadable('Cannot list parents with unreadable storage.')
294         return self._ancestors(*args, **kwargs)
295
296     def _ancestors(self, id=None, revision=None):
297         if id == None:
298             return []
299         ancestors = []
300         stack = [id]
301         while len(stack) > 0:
302             id = stack.pop(0)
303             parent = self._data[id].parent
304             if parent != None and not parent.id.startswith('__'):
305                 ancestor = parent.id
306                 ancestors.append(ancestor)
307                 stack.append(ancestor)
308         return ancestors
309
310     def children(self, *args, **kwargs):
311         """Return a list of specified entry's children's ids."""
312         if self.is_readable() == False:
313             raise NotReadable('Cannot list children with unreadable storage.')
314         return self._children(*args, **kwargs)
315
316     def _children(self, id=None, revision=None):
317         if id == None:
318             id = '__ROOT__'
319         return [c.id for c in self._data[id] if not c.id.startswith('__')]
320
321     def get(self, *args, **kwargs):
322         """
323         Get contents of and entry as they were in a given revision.
324         revision==None specifies the current revision.
325
326         If there is no id, return default, unless default is not
327         given, in which case raise InvalidID.
328         """
329         if self.is_readable() == False:
330             raise NotReadable('Cannot get entry with unreadable storage.')
331         if 'decode' in kwargs:
332             decode = kwargs.pop('decode')
333         else:
334             decode = False
335         value = self._get(*args, **kwargs)
336         if value != None:
337             if decode == True and type(value) != types.UnicodeType:
338                 return unicode(value, self.encoding)
339             elif decode == False and type(value) != types.StringType:
340                 return value.encode(self.encoding)
341         return value
342
343     def _get(self, id, default=InvalidObject, revision=None):
344         if id in self._data and self._data[id].value != _EMPTY:
345             return self._data[id].value
346         elif default == InvalidObject:
347             raise InvalidID(id)
348         return default
349
350     def set(self, id, value, *args, **kwargs):
351         """
352         Set the entry contents.
353         """
354         if self.is_writeable() == False:
355             raise NotWriteable('Cannot set entry in unwriteable storage.')
356         if type(value) == types.UnicodeType:
357             value = value.encode(self.encoding)
358         self._set(id, value, *args, **kwargs)
359
360     def _set(self, id, value):
361         if id not in self._data:
362             raise InvalidID(id)
363         if self._data[id].directory == True:
364             raise InvalidDirectory(
365                 'Directory %s cannot have data' % self.parent)
366         self._data[id].value = value
367
368 class VersionedStorage (Storage):
369     """
370     This class declares all the methods required by a Storage
371     interface that supports versioning.  This implementation just
372     keeps the data in a list and uses pickle for persistent
373     storage.
374     """
375     name = 'VersionedStorage'
376
377     def __init__(self, *args, **kwargs):
378         Storage.__init__(self, *args, **kwargs)
379         self.versioned = True
380
381     def _init(self):
382         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
383         root = Entry(id='__ROOT__', directory=True)
384         summary = Entry(id='__COMMIT__SUMMARY__', value='Initial commit')
385         body = Entry(id='__COMMIT__BODY__')
386         initial_commit = {root.id:root, summary.id:summary, body.id:body}
387         d = dict((k,v._objects_to_ids()) for k,v in initial_commit.items())
388         pickle.dump([d, copy.deepcopy(d)], f, -1) # [inital tree, working tree]
389         f.close()
390
391     def _connect(self):
392         try:
393             f = open(os.path.join(self.repo, 'repo.pkl'), 'rb')
394         except IOError:
395             raise ConnectionError(self)
396         d = pickle.load(f)
397         self._data = [dict((k,v._ids_to_objects(t)) for k,v in t.items())
398                       for t in d]
399         f.close()
400
401     def _disconnect(self):
402         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
403         pickle.dump([dict((k,v._objects_to_ids())
404                           for k,v in t.items()) for t in self._data], f, -1)
405         f.close()
406         self._data = None
407
408     def _add(self, id, parent=None, directory=False):
409         if parent == None:
410             parent = '__ROOT__'
411         p = self._data[-1][parent]
412         self._data[-1][id] = Entry(id, parent=p, directory=directory)
413
414     def _exists(self, id, revision=None):
415         if revision == None:
416             revision = -1
417         else:
418             revision = int(revision)
419         return id in self._data[revision]
420
421     def _remove(self, id):
422         if self._data[-1][id].directory == True \
423                 and len(self.children(id)) > 0:
424             raise DirectoryNotEmpty(id)
425         e = self._data[-1].pop(id)
426         e.parent.remove(e)
427
428     def _recursive_remove(self, id):
429         for entry in reversed(list(self._data[-1][id].traverse())):
430             self._remove(entry.id)
431
432     def _ancestors(self, id=None, revision=None):
433         if id == None:
434             return []
435         if revision == None:
436             revision = -1
437         else:
438             revision = int(revision)
439         ancestors = []
440         stack = [id]
441         while len(stack) > 0:
442             id = stack.pop(0)
443             parent = self._data[revision][id].parent
444             if parent != None and not parent.id.startswith('__'):
445                 ancestor = parent.id
446                 ancestors.append(ancestor)
447                 stack.append(ancestor)
448         return ancestors
449
450     def _children(self, id=None, revision=None):
451         if id == None:
452             id = '__ROOT__'
453         if revision == None:
454             revision = -1
455         else:
456             revision = int(revision)
457         return [c.id for c in self._data[revision][id]
458                 if not c.id.startswith('__')]
459
460     def _get(self, id, default=InvalidObject, revision=None):
461         if revision == None:
462             revision = -1
463         else:
464             revision = int(revision)
465         if id in self._data[revision] \
466                 and self._data[revision][id].value != _EMPTY:
467             return self._data[revision][id].value
468         elif default == InvalidObject:
469             raise InvalidID(id)
470         return default
471
472     def _set(self, id, value):
473         if id not in self._data[-1]:
474             raise InvalidID(id)
475         self._data[-1][id].value = value
476
477     def commit(self, *args, **kwargs):
478         """
479         Commit the current repository, with a commit message string
480         summary and body.  Return the name of the new revision.
481
482         If allow_empty == False (the default), raise EmptyCommit if
483         there are no changes to commit.
484         """
485         if self.is_writeable() == False:
486             raise NotWriteable('Cannot commit to unwriteable storage.')
487         return self._commit(*args, **kwargs)
488
489     def _commit(self, summary, body=None, allow_empty=False):
490         if self._data[-1] == self._data[-2] and allow_empty == False:
491             raise EmptyCommit
492         self._data[-1]["__COMMIT__SUMMARY__"].value = summary
493         self._data[-1]["__COMMIT__BODY__"].value = body
494         rev = str(len(self._data)-1)
495         self._data.append(copy.deepcopy(self._data[-1]))
496         return rev
497
498     def revision_id(self, index=None):
499         """
500         Return the name of the <index>th revision.  The choice of
501         which branch to follow when crossing branches/merges is not
502         defined.  Revision indices start at 1; ID 0 is the blank
503         repository.
504
505         Return None if index==None.
506
507         If the specified revision does not exist, raise InvalidRevision.
508         """
509         if index == None:
510             return None
511         try:
512             if int(index) != index:
513                 raise InvalidRevision(index)
514         except ValueError:
515             raise InvalidRevision(index)
516         L = len(self._data) - 1  # -1 b/c of initial commit
517         if index >= -L and index <= L:
518             return str(index % L)
519         raise InvalidRevision(i)
520
521     def changed(self, revision):
522         """Return a tuple of lists of ids `(new, modified, removed)` from the
523         specified revision to the current situation.
524         """
525         new = []
526         modified = []
527         removed = []
528         for id,value in self._data[int(revision)].items():
529             if id.startswith('__'):
530                 continue
531             if not id in self._data[-1]:
532                 removed.append(id)
533             elif value.value != self._data[-1][id].value:
534                 modified.append(id)
535         for id in self._data[-1]:
536             if not id in self._data[int(revision)]:
537                 new.append(id)
538         return (new, modified, removed)
539
540
541 if TESTING == True:
542     class StorageTestCase (unittest.TestCase):
543         """Test cases for Storage class."""
544
545         Class = Storage
546
547         def __init__(self, *args, **kwargs):
548             super(StorageTestCase, self).__init__(*args, **kwargs)
549             self.dirname = None
550
551         # this class will be the basis of tests for several classes,
552         # so make sure we print the name of the class we're dealing with.
553         def _classname(self):
554             version = '?'
555             try:
556                 if hasattr(self, 's'):
557                     version = self.s.version()
558             except:
559                 pass
560             return '%s:%s' % (self.Class.__name__, version)
561
562         def fail(self, msg=None):
563             """Fail immediately, with the given message."""
564             raise self.failureException, \
565                 '(%s) %s' % (self._classname(), msg)
566
567         def failIf(self, expr, msg=None):
568             "Fail the test if the expression is true."
569             if expr: raise self.failureException, \
570                 '(%s) %s' % (self._classname(), msg)
571
572         def failUnless(self, expr, msg=None):
573             """Fail the test unless the expression is true."""
574             if not expr: raise self.failureException, \
575                 '(%s) %s' % (self._classname(), msg)
576
577         def setUp(self):
578             """Set up test fixtures for Storage test case."""
579             super(StorageTestCase, self).setUp()
580             self.dir = Dir()
581             self.dirname = self.dir.path
582             self.s = self.Class(repo=self.dirname)
583             self.assert_failed_connect()
584             self.s.init()
585             self.s.connect()
586
587         def tearDown(self):
588             super(StorageTestCase, self).tearDown()
589             self.s.disconnect()
590             self.s.destroy()
591             self.assert_failed_connect()
592             self.dir.cleanup()
593
594         def assert_failed_connect(self):
595             try:
596                 self.s.connect()
597                 self.fail(
598                     "Connected to %(name)s repository before initialising"
599                     % vars(self.Class))
600             except ConnectionError:
601                 pass
602
603     class Storage_init_TestCase (StorageTestCase):
604         """Test cases for Storage.init method."""
605
606         def test_connect_should_succeed_after_init(self):
607             """Should connect after initialization."""
608             self.s.connect()
609
610     class Storage_connect_disconnect_TestCase (StorageTestCase):
611         """Test cases for Storage.connect and .disconnect methods."""
612
613         def test_multiple_disconnects(self):
614             """Should be able to call .disconnect multiple times."""
615             self.s.disconnect()
616             self.s.disconnect()
617
618     class Storage_add_remove_TestCase (StorageTestCase):
619         """Test cases for Storage.add, .remove, and .recursive_remove methods."""
620
621         def test_initially_empty(self):
622             """New repository should be empty."""
623             self.failUnless(len(self.s.children()) == 0, self.s.children())
624
625         def test_add_identical_rooted(self):
626             """Adding entries with the same ID should not increase the number of children.
627             """
628             for i in range(10):
629                 self.s.add('some id', directory=False)
630                 s = sorted(self.s.children())
631                 self.failUnless(s == ['some id'], s)
632
633         def test_add_rooted(self):
634             """Adding entries should increase the number of children (rooted).
635             """
636             ids = []
637             for i in range(10):
638                 ids.append(str(i))
639                 self.s.add(ids[-1], directory=(i % 2 == 0))
640                 s = sorted(self.s.children())
641                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
642
643         def test_add_nonrooted(self):
644             """Adding entries should increase the number of children (nonrooted).
645             """
646             self.s.add('parent', directory=True)
647             ids = []
648             for i in range(10):
649                 ids.append(str(i))
650                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
651                 s = sorted(self.s.children('parent'))
652                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
653                 s = self.s.children()
654                 self.failUnless(s == ['parent'], s)
655
656         def test_ancestors(self):
657             """Check ancestors lists.
658             """
659             self.s.add('parent', directory=True)
660             for i in range(10):
661                 i_id = str(i)
662                 self.s.add(i_id, 'parent', directory=True)
663                 for j in range(10): # add some grandkids
664                     j_id = str(20*(i+1)+j)
665                     self.s.add(j_id, i_id, directory=(i%2 == 0))
666                     ancestors = sorted(self.s.ancestors(j_id))
667                     self.failUnless(ancestors == [i_id, 'parent'],
668                         'Unexpected ancestors for %s/%s, "%s"'
669                         % (i_id, j_id, ancestors))
670
671         def test_children(self):
672             """Non-UUID ids should be returned as such.
673             """
674             self.s.add('parent', directory=True)
675             ids = []
676             for i in range(10):
677                 ids.append('parent/%s' % str(i))
678                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
679                 s = sorted(self.s.children('parent'))
680                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
681
682         def test_add_invalid_directory(self):
683             """Should not be able to add children to non-directories.
684             """
685             self.s.add('parent', directory=False)
686             try:
687                 self.s.add('child', 'parent', directory=False)
688                 self.fail(
689                     '%s.add() succeeded instead of raising InvalidDirectory'
690                     % (vars(self.Class)['name']))
691             except InvalidDirectory:
692                 pass
693             try:
694                 self.s.add('child', 'parent', directory=True)
695                 self.fail(
696                     '%s.add() succeeded instead of raising InvalidDirectory'
697                     % (vars(self.Class)['name']))
698             except InvalidDirectory:
699                 pass
700             self.failUnless(len(self.s.children('parent')) == 0,
701                             self.s.children('parent'))
702
703         def test_remove_rooted(self):
704             """Removing entries should decrease the number of children (rooted).
705             """
706             ids = []
707             for i in range(10):
708                 ids.append(str(i))
709                 self.s.add(ids[-1], directory=(i % 2 == 0))
710             for i in range(10):
711                 self.s.remove(ids.pop())
712                 s = sorted(self.s.children())
713                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
714
715         def test_remove_nonrooted(self):
716             """Removing entries should decrease the number of children (nonrooted).
717             """
718             self.s.add('parent', directory=True)
719             ids = []
720             for i in range(10):
721                 ids.append(str(i))
722                 self.s.add(ids[-1], 'parent', directory=False)#(i % 2 == 0))
723             for i in range(10):
724                 self.s.remove(ids.pop())
725                 s = sorted(self.s.children('parent'))
726                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
727                 if len(s) > 0:
728                     s = self.s.children()
729                     self.failUnless(s == ['parent'], s)
730
731         def test_remove_directory_not_empty(self):
732             """Removing a non-empty directory entry should raise exception.
733             """
734             self.s.add('parent', directory=True)
735             ids = []
736             for i in range(10):
737                 ids.append(str(i))
738                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
739             self.s.remove(ids.pop()) # empty directory removal succeeds
740             try:
741                 self.s.remove('parent') # empty directory removal succeeds
742                 self.fail(
743                     "%s.remove() didn't raise DirectoryNotEmpty"
744                     % (vars(self.Class)['name']))
745             except DirectoryNotEmpty:
746                 pass
747
748         def test_recursive_remove(self):
749             """Recursive remove should empty the tree."""
750             self.s.add('parent', directory=True)
751             ids = []
752             for i in range(10):
753                 ids.append(str(i))
754                 self.s.add(ids[-1], 'parent', directory=True)
755                 for j in range(10): # add some grandkids
756                     self.s.add(str(20*(i+1)+j), ids[-1], directory=(i%2 == 0))
757             self.s.recursive_remove('parent')
758             s = sorted(self.s.children())
759             self.failUnless(s == [], s)
760
761     class Storage_get_set_TestCase (StorageTestCase):
762         """Test cases for Storage.get and .set methods."""
763
764         id = 'unlikely id'
765         val = 'unlikely value'
766
767         def test_get_default(self):
768             """Get should return specified default if id not in Storage.
769             """
770             ret = self.s.get(self.id, default=self.val)
771             self.failUnless(ret == self.val,
772                     "%s.get() returned %s not %s"
773                     % (vars(self.Class)['name'], ret, self.val))
774
775         def test_get_default_exception(self):
776             """Get should raise exception if id not in Storage and no default.
777             """
778             try:
779                 ret = self.s.get(self.id)
780                 self.fail(
781                     "%s.get() returned %s instead of raising InvalidID"
782                     % (vars(self.Class)['name'], ret))
783             except InvalidID:
784                 pass
785
786         def test_get_initial_value(self):
787             """Data value should be default before any value has been set.
788             """
789             self.s.add(self.id, directory=False)
790             val = 'UNLIKELY DEFAULT'
791             ret = self.s.get(self.id, default=val)
792             self.failUnless(ret == val,
793                     "%s.get() returned %s not %s"
794                     % (vars(self.Class)['name'], ret, val))
795
796         def test_set_exception(self):
797             """Set should raise exception if id not in Storage.
798             """
799             try:
800                 self.s.set(self.id, self.val)
801                 self.fail(
802                     "%(name)s.set() did not raise InvalidID"
803                     % vars(self.Class))
804             except InvalidID:
805                 pass
806
807         def test_set(self):
808             """Set should define the value returned by get.
809             """
810             self.s.add(self.id, directory=False)
811             self.s.set(self.id, self.val)
812             ret = self.s.get(self.id)
813             self.failUnless(ret == self.val,
814                     "%s.get() returned %s not %s"
815                     % (vars(self.Class)['name'], ret, self.val))
816
817         def test_unicode_set(self):
818             """Set should define the value returned by get.
819             """
820             val = u'Fran\xe7ois'
821             self.s.add(self.id, directory=False)
822             self.s.set(self.id, val)
823             ret = self.s.get(self.id, decode=True)
824             self.failUnless(type(ret) == types.UnicodeType,
825                     "%s.get() returned %s not UnicodeType"
826                     % (vars(self.Class)['name'], type(ret)))
827             self.failUnless(ret == val,
828                     "%s.get() returned %s not %s"
829                     % (vars(self.Class)['name'], ret, self.val))
830             ret = self.s.get(self.id)
831             self.failUnless(type(ret) == types.StringType,
832                     "%s.get() returned %s not StringType"
833                     % (vars(self.Class)['name'], type(ret)))
834             s = unicode(ret, self.s.encoding)
835             self.failUnless(s == val,
836                     "%s.get() returned %s not %s"
837                     % (vars(self.Class)['name'], s, self.val))
838
839
840     class Storage_persistence_TestCase (StorageTestCase):
841         """Test cases for Storage.disconnect and .connect methods."""
842
843         id = 'unlikely id'
844         val = 'unlikely value'
845
846         def test_get_set_persistence(self):
847             """Set should define the value returned by get after reconnect.
848             """
849             self.s.add(self.id, directory=False)
850             self.s.set(self.id, self.val)
851             self.s.disconnect()
852             self.s.connect()
853             ret = self.s.get(self.id)
854             self.failUnless(ret == self.val,
855                     "%s.get() returned %s not %s"
856                     % (vars(self.Class)['name'], ret, self.val))
857
858         def test_empty_get_set_persistence(self):
859             """After empty set, get may return either an empty string or default.
860             """
861             self.s.add(self.id, directory=False)
862             self.s.set(self.id, '')
863             self.s.disconnect()
864             self.s.connect()
865             default = 'UNLIKELY DEFAULT'
866             ret = self.s.get(self.id, default=default)
867             self.failUnless(ret in ['', default],
868                     "%s.get() returned %s not in %s"
869                     % (vars(self.Class)['name'], ret, ['', default]))
870
871         def test_add_nonrooted_persistence(self):
872             """Adding entries should increase the number of children after reconnect.
873             """
874             self.s.add('parent', directory=True)
875             ids = []
876             for i in range(10):
877                 ids.append(str(i))
878                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
879             self.s.disconnect()
880             self.s.connect()
881             s = sorted(self.s.children('parent'))
882             self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
883             s = self.s.children()
884             self.failUnless(s == ['parent'], s)
885
886     class VersionedStorageTestCase (StorageTestCase):
887         """Test cases for VersionedStorage methods."""
888
889         Class = VersionedStorage
890
891     class VersionedStorage_commit_TestCase (VersionedStorageTestCase):
892         """Test cases for VersionedStorage.commit and revision_ids methods."""
893
894         id = 'unlikely id'
895         val = 'Some value'
896         commit_msg = 'Committing something interesting'
897         commit_body = 'Some\nlonger\ndescription\n'
898
899         def _setup_for_empty_commit(self):
900             """
901             Initialization might add some files to version control, so
902             commit those first, before testing the empty commit
903             functionality.
904             """
905             try:
906                 self.s.commit('Added initialization files')
907             except EmptyCommit:
908                 pass
909                 
910         def test_revision_id_exception(self):
911             """Invalid revision id should raise InvalidRevision.
912             """
913             try:
914                 rev = self.s.revision_id('highly unlikely revision id')
915                 self.fail(
916                     "%s.revision_id() didn't raise InvalidRevision, returned %s."
917                     % (vars(self.Class)['name'], rev))
918             except InvalidRevision:
919                 pass
920
921         def test_empty_commit_raises_exception(self):
922             """Empty commit should raise exception.
923             """
924             self._setup_for_empty_commit()
925             try:
926                 self.s.commit(self.commit_msg, self.commit_body)
927                 self.fail(
928                     "Empty %(name)s.commit() didn't raise EmptyCommit."
929                     % vars(self.Class))
930             except EmptyCommit:
931                 pass
932
933         def test_empty_commit_allowed(self):
934             """Empty commit should _not_ raise exception if allow_empty=True.
935             """
936             self._setup_for_empty_commit()
937             self.s.commit(self.commit_msg, self.commit_body,
938                           allow_empty=True)
939
940         def test_commit_revision_ids(self):
941             """Commit / revision_id should agree on revision ids.
942             """
943             def val(i):
944                 return '%s:%d' % (self.val, i+1)
945             self.s.add(self.id, directory=False)
946             revs = []
947             for i in range(10):
948                 self.s.set(self.id, val(i))
949                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
950                                           self.commit_body))
951             for i in range(10):
952                 rev = self.s.revision_id(i+1)
953                 self.failUnless(rev == revs[i],
954                                 "%s.revision_id(%d) returned %s not %s"
955                                 % (vars(self.Class)['name'], i+1, rev, revs[i]))
956             for i in range(-1, -9, -1):
957                 rev = self.s.revision_id(i)
958                 self.failUnless(rev == revs[i],
959                                 "%s.revision_id(%d) returned %s not %s"
960                                 % (vars(self.Class)['name'], i, rev, revs[i]))
961
962         def test_get_previous_version(self):
963             """Get should be able to return the previous version.
964             """
965             def val(i):
966                 return '%s:%d' % (self.val, i+1)
967             self.s.add(self.id, directory=False)
968             revs = []
969             for i in range(10):
970                 self.s.set(self.id, val(i))
971                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
972                                           self.commit_body))
973             for i in range(10):
974                 ret = self.s.get(self.id, revision=revs[i])
975                 self.failUnless(ret == val(i),
976                                 "%s.get() returned %s not %s for revision %s"
977                                 % (vars(self.Class)['name'], ret, val(i), revs[i]))
978
979         def test_get_previous_children(self):
980             """Children list should be revision dependent.
981             """
982             self.s.add('parent', directory=True)
983             revs = []
984             cur_children = []
985             children = []
986             for i in range(10):
987                 new_child = str(i)
988                 self.s.add(new_child, 'parent')
989                 self.s.set(new_child, self.val)
990                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
991                                           self.commit_body))
992                 cur_children.append(new_child)
993                 children.append(list(cur_children))
994             for i in range(10):
995                 ret = sorted(self.s.children('parent', revision=revs[i]))
996                 self.failUnless(ret == children[i],
997                                 "%s.get() returned %s not %s for revision %s"
998                                 % (vars(self.Class)['name'], ret,
999                                    children[i], revs[i]))
1000
1001     class VersionedStorage_changed_TestCase (VersionedStorageTestCase):
1002         """Test cases for VersionedStorage.changed() method."""
1003
1004         def test_changed(self):
1005             """Changed lists should reflect past activity"""
1006             self.s.add('dir', directory=True)
1007             self.s.add('modified', parent='dir')
1008             self.s.set('modified', 'some value to be modified')
1009             self.s.add('moved', parent='dir')
1010             self.s.set('moved', 'this entry will be moved')
1011             self.s.add('removed', parent='dir')
1012             self.s.set('removed', 'this entry will be deleted')
1013             revA = self.s.commit('Initial state')
1014             self.s.add('new', parent='dir')
1015             self.s.set('new', 'this entry is new')
1016             self.s.set('modified', 'a new value')
1017             self.s.remove('moved')
1018             self.s.add('moved2', parent='dir')
1019             self.s.set('moved2', 'this entry will be moved')
1020             self.s.remove('removed')
1021             revB = self.s.commit('Final state')
1022             new,mod,rem = self.s.changed(revA)
1023             self.failUnless(sorted(new) == ['moved2', 'new'],
1024                             'Unexpected new: %s' % new)
1025             self.failUnless(mod == ['modified'],
1026                             'Unexpected modified: %s' % mod)
1027             self.failUnless(sorted(rem) == ['moved', 'removed'],
1028                             'Unexpected removed: %s' % rem)
1029
1030     def make_storage_testcase_subclasses(storage_class, namespace):
1031         """Make StorageTestCase subclasses for storage_class in namespace."""
1032         storage_testcase_classes = [
1033             c for c in (
1034                 ob for ob in globals().values() if isinstance(ob, type))
1035             if issubclass(c, StorageTestCase) \
1036                 and c.Class == Storage]
1037
1038         for base_class in storage_testcase_classes:
1039             testcase_class_name = storage_class.__name__ + base_class.__name__
1040             testcase_class_bases = (base_class,)
1041             testcase_class_dict = dict(base_class.__dict__)
1042             testcase_class_dict['Class'] = storage_class
1043             testcase_class = type(
1044                 testcase_class_name, testcase_class_bases, testcase_class_dict)
1045             setattr(namespace, testcase_class_name, testcase_class)
1046
1047     def make_versioned_storage_testcase_subclasses(storage_class, namespace):
1048         """Make VersionedStorageTestCase subclasses for storage_class in namespace."""
1049         storage_testcase_classes = [
1050             c for c in (
1051                 ob for ob in globals().values() if isinstance(ob, type))
1052             if ((issubclass(c, StorageTestCase) \
1053                      and c.Class == Storage)
1054                 or
1055                 (issubclass(c, VersionedStorageTestCase) \
1056                      and c.Class == VersionedStorage))]
1057
1058         for base_class in storage_testcase_classes:
1059             testcase_class_name = storage_class.__name__ + base_class.__name__
1060             testcase_class_bases = (base_class,)
1061             testcase_class_dict = dict(base_class.__dict__)
1062             testcase_class_dict['Class'] = storage_class
1063             testcase_class = type(
1064                 testcase_class_name, testcase_class_bases, testcase_class_dict)
1065             setattr(namespace, testcase_class_name, testcase_class)
1066
1067     make_storage_testcase_subclasses(VersionedStorage, sys.modules[__name__])
1068
1069     unitsuite =unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
1070     suite = unittest.TestSuite([unitsuite, doctest.DocTestSuite()])