Ran update-copyright.py.
[be.git] / libbe / storage / base.py
1 # Copyright (C) 2009-2012 Chris Ball <cjb@laptop.org>
2 #                         W. Trevor King <wking@drexel.edu>
3 #
4 # This file is part of Bugs Everywhere.
5 #
6 # Bugs Everywhere is free software: you can redistribute it and/or modify it
7 # under the terms of the GNU General Public License as published by the Free
8 # Software Foundation, either version 2 of the License, or (at your option) any
9 # later version.
10 #
11 # Bugs Everywhere is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 # FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
14 # more details.
15 #
16 # You should have received a copy of the GNU General Public License along with
17 # Bugs Everywhere.  If not, see <http://www.gnu.org/licenses/>.
18
19 """
20 Abstract bug repository data storage to easily support multiple backends.
21 """
22
23 import copy
24 import os
25 import pickle
26 import types
27
28 from libbe.error import NotSupported
29 import libbe.storage
30 from libbe.util.tree import Tree
31 from libbe.util import InvalidObject
32 import libbe.version
33 from libbe import TESTING
34
35 if TESTING == True:
36     import doctest
37     import os.path
38     import sys
39     import unittest
40
41     from libbe.util.utility import Dir
42
43 class ConnectionError (Exception):
44     pass
45
46 class InvalidStorageVersion(ConnectionError):
47     def __init__(self, active_version, expected_version=None):
48         if expected_version == None:
49             expected_version = libbe.storage.STORAGE_VERSION
50         msg = 'Storage in "%s" not the expected "%s"' \
51             % (active_version, expected_version)
52         Exception.__init__(self, msg)
53         self.active_version = active_version
54         self.expected_version = expected_version
55
56 class InvalidID (KeyError):
57     def __init__(self, id=None, revision=None, msg=None):
58         KeyError.__init__(self, id)
59         self.msg = msg
60         self.id = id
61         self.revision = revision
62     def __str__(self):
63         if self.msg == None:
64             return '%s in revision %s' % (self.id, self.revision)
65         return self.msg
66
67
68 class InvalidRevision (KeyError):
69     pass
70
71 class InvalidDirectory (Exception):
72     pass
73
74 class DirectoryNotEmpty (InvalidDirectory):
75     pass
76
77 class NotWriteable (NotSupported):
78     def __init__(self, msg):
79         NotSupported.__init__(self, 'write', msg)
80
81 class NotReadable (NotSupported):
82     def __init__(self, msg):
83         NotSupported.__init__(self, 'read', msg)
84
85 class EmptyCommit(Exception):
86     def __init__(self):
87         Exception.__init__(self, 'No changes to commit')
88
89 class _EMPTY (object):
90     """Entry has been added but has no user-set value."""
91     pass
92
93 class Entry (Tree):
94     def __init__(self, id, value=_EMPTY, parent=None, directory=False,
95                  children=None):
96         if children == None:
97             Tree.__init__(self)
98         else:
99             Tree.__init__(self, children)
100         self.id = id
101         self.value = value
102         self.parent = parent
103         if self.parent != None:
104             if self.parent.directory == False:
105                 raise InvalidDirectory(
106                     'Non-directory %s cannot have children' % self.parent)
107             parent.append(self)
108         self.directory = directory
109
110     def __str__(self):
111         return '<Entry %s: %s>' % (self.id, self.value)
112
113     def __repr__(self):
114         return str(self)
115
116     def __cmp__(self, other, local=False):
117         if other == None:
118             return cmp(1, None)
119         if cmp(self.id, other.id) != 0:
120             return cmp(self.id, other.id)
121         if cmp(self.value, other.value) != 0:
122             return cmp(self.value, other.value)
123         if local == False:
124             if self.parent == None:
125                 if cmp(self.parent, other.parent) != 0:
126                     return cmp(self.parent, other.parent)
127             elif self.parent.__cmp__(other.parent, local=True) != 0:
128                 return self.parent.__cmp__(other.parent, local=True)
129             for sc,oc in zip(self, other):
130                 if sc.__cmp__(oc, local=True) != 0:
131                     return sc.__cmp__(oc, local=True)
132         return 0
133
134     def _objects_to_ids(self):
135         if self.parent != None:
136             self.parent = self.parent.id
137         for i,c in enumerate(self):
138             self[i] = c.id
139         return self
140
141     def _ids_to_objects(self, dict):
142         if self.parent != None:
143             self.parent = dict[self.parent]
144         for i,c in enumerate(self):
145             self[i] = dict[c]
146         return self
147
148 class Storage (object):
149     """
150     This class declares all the methods required by a Storage
151     interface.  This implementation just keeps the data in a
152     dictionary and uses pickle for persistent storage.
153     """
154     name = 'Storage'
155
156     def __init__(self, repo='/', encoding='utf-8', options=None):
157         self.repo = repo
158         self.encoding = encoding
159         self.options = options
160         self.readable = True  # soft limit (user choice)
161         self._readable = True # hard limit (backend choice)
162         self.writeable = True  # soft limit (user choice)
163         self._writeable = True # hard limit (backend choice)
164         self.versioned = False
165         self.can_init = True
166         self.connected = False
167
168     def __str__(self):
169         return '<%s %s %s>' % (self.__class__.__name__, id(self), self.repo)
170
171     def __repr__(self):
172         return str(self)
173
174     def version(self):
175         """Return a version string for this backend."""
176         return libbe.version.version()
177
178     def storage_version(self, revision=None):
179         """Return the storage format for this backend."""
180         return libbe.storage.STORAGE_VERSION
181
182     def is_readable(self):
183         return self.readable and self._readable
184
185     def is_writeable(self):
186         return self.writeable and self._writeable
187
188     def init(self):
189         """Create a new storage repository."""
190         if self.can_init == False:
191             raise NotSupported('init',
192                                'Cannot initialize this repository format.')
193         if self.is_writeable() == False:
194             raise NotWriteable('Cannot initialize unwriteable storage.')
195         return self._init()
196
197     def _init(self):
198         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
199         root = Entry(id='__ROOT__', directory=True)
200         d = {root.id:root}
201         pickle.dump(dict((k,v._objects_to_ids()) for k,v in d.items()), f, -1)
202         f.close()
203
204     def destroy(self):
205         """Remove the storage repository."""
206         if self.is_writeable() == False:
207             raise NotWriteable('Cannot destroy unwriteable storage.')
208         return self._destroy()
209
210     def _destroy(self):
211         os.remove(os.path.join(self.repo, 'repo.pkl'))
212
213     def connect(self):
214         """Open a connection to the repository."""
215         if self.is_readable() == False:
216             raise NotReadable('Cannot connect to unreadable storage.')
217         self._connect()
218         self.connected = True
219
220     def _connect(self):
221         try:
222             f = open(os.path.join(self.repo, 'repo.pkl'), 'rb')
223         except IOError:
224             raise ConnectionError(self)
225         d = pickle.load(f)
226         self._data = dict((k,v._ids_to_objects(d)) for k,v in d.items())
227         f.close()
228
229     def disconnect(self):
230         """Close the connection to the repository."""
231         if self.is_writeable() == False:
232             return
233         if self.connected == False:
234             return
235         self._disconnect()
236         self.connected = False
237
238     def _disconnect(self):
239         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
240         pickle.dump(dict((k,v._objects_to_ids())
241                          for k,v in self._data.items()), f, -1)
242         f.close()
243         self._data = None
244
245     def add(self, id, *args, **kwargs):
246         """Add an entry"""
247         if self.is_writeable() == False:
248             raise NotWriteable('Cannot add entry to unwriteable storage.')
249         if not self.exists(id):
250             self._add(id, *args, **kwargs)
251
252     def _add(self, id, parent=None, directory=False):
253         if parent == None:
254             parent = '__ROOT__'
255         p = self._data[parent]
256         self._data[id] = Entry(id, parent=p, directory=directory)
257
258     def exists(self, *args, **kwargs):
259         """Check an entry's existence"""
260         if self.is_readable() == False:
261             raise NotReadable('Cannot check entry existence in unreadable storage.')
262         return self._exists(*args, **kwargs)
263
264     def _exists(self, id, revision=None):
265         return id in self._data
266
267     def remove(self, *args, **kwargs):
268         """Remove an entry."""
269         if self.is_writeable() == False:
270             raise NotSupported('write',
271                                'Cannot remove entry from unwriteable storage.')
272         self._remove(*args, **kwargs)
273
274     def _remove(self, id):
275         if self._data[id].directory == True \
276                 and len(self.children(id)) > 0:
277             raise DirectoryNotEmpty(id)
278         e = self._data.pop(id)
279         e.parent.remove(e)
280
281     def recursive_remove(self, *args, **kwargs):
282         """Remove an entry and all its decendents."""
283         if self.is_writeable() == False:
284             raise NotSupported('write',
285                                'Cannot remove entries from unwriteable storage.')
286         self._recursive_remove(*args, **kwargs)
287
288     def _recursive_remove(self, id):
289         for entry in reversed(list(self._data[id].traverse())):
290             self._remove(entry.id)
291
292     def ancestors(self, *args, **kwargs):
293         """Return a list of the specified entry's ancestors' ids."""
294         if self.is_readable() == False:
295             raise NotReadable('Cannot list parents with unreadable storage.')
296         return self._ancestors(*args, **kwargs)
297
298     def _ancestors(self, id=None, revision=None):
299         if id == None:
300             return []
301         ancestors = []
302         stack = [id]
303         while len(stack) > 0:
304             id = stack.pop(0)
305             parent = self._data[id].parent
306             if parent != None and not parent.id.startswith('__'):
307                 ancestor = parent.id
308                 ancestors.append(ancestor)
309                 stack.append(ancestor)
310         return ancestors
311
312     def children(self, *args, **kwargs):
313         """Return a list of specified entry's children's ids."""
314         if self.is_readable() == False:
315             raise NotReadable('Cannot list children with unreadable storage.')
316         return self._children(*args, **kwargs)
317
318     def _children(self, id=None, revision=None):
319         if id == None:
320             id = '__ROOT__'
321         return [c.id for c in self._data[id] if not c.id.startswith('__')]
322
323     def get(self, *args, **kwargs):
324         """
325         Get contents of and entry as they were in a given revision.
326         revision==None specifies the current revision.
327
328         If there is no id, return default, unless default is not
329         given, in which case raise InvalidID.
330         """
331         if self.is_readable() == False:
332             raise NotReadable('Cannot get entry with unreadable storage.')
333         if 'decode' in kwargs:
334             decode = kwargs.pop('decode')
335         else:
336             decode = False
337         value = self._get(*args, **kwargs)
338         if value != None:
339             if decode == True and type(value) != types.UnicodeType:
340                 return unicode(value, self.encoding)
341             elif decode == False and type(value) != types.StringType:
342                 return value.encode(self.encoding)
343         return value
344
345     def _get(self, id, default=InvalidObject, revision=None):
346         if id in self._data and self._data[id].value != _EMPTY:
347             return self._data[id].value
348         elif default == InvalidObject:
349             raise InvalidID(id)
350         return default
351
352     def set(self, id, value, *args, **kwargs):
353         """
354         Set the entry contents.
355         """
356         if self.is_writeable() == False:
357             raise NotWriteable('Cannot set entry in unwriteable storage.')
358         if type(value) == types.UnicodeType:
359             value = value.encode(self.encoding)
360         self._set(id, value, *args, **kwargs)
361
362     def _set(self, id, value):
363         if id not in self._data:
364             raise InvalidID(id)
365         if self._data[id].directory == True:
366             raise InvalidDirectory(
367                 'Directory %s cannot have data' % self.parent)
368         self._data[id].value = value
369
370 class VersionedStorage (Storage):
371     """
372     This class declares all the methods required by a Storage
373     interface that supports versioning.  This implementation just
374     keeps the data in a list and uses pickle for persistent
375     storage.
376     """
377     name = 'VersionedStorage'
378
379     def __init__(self, *args, **kwargs):
380         Storage.__init__(self, *args, **kwargs)
381         self.versioned = True
382
383     def _init(self):
384         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
385         root = Entry(id='__ROOT__', directory=True)
386         summary = Entry(id='__COMMIT__SUMMARY__', value='Initial commit')
387         body = Entry(id='__COMMIT__BODY__')
388         initial_commit = {root.id:root, summary.id:summary, body.id:body}
389         d = dict((k,v._objects_to_ids()) for k,v in initial_commit.items())
390         pickle.dump([d, copy.deepcopy(d)], f, -1) # [inital tree, working tree]
391         f.close()
392
393     def _connect(self):
394         try:
395             f = open(os.path.join(self.repo, 'repo.pkl'), 'rb')
396         except IOError:
397             raise ConnectionError(self)
398         d = pickle.load(f)
399         self._data = [dict((k,v._ids_to_objects(t)) for k,v in t.items())
400                       for t in d]
401         f.close()
402
403     def _disconnect(self):
404         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
405         pickle.dump([dict((k,v._objects_to_ids())
406                           for k,v in t.items()) for t in self._data], f, -1)
407         f.close()
408         self._data = None
409
410     def _add(self, id, parent=None, directory=False):
411         if parent == None:
412             parent = '__ROOT__'
413         p = self._data[-1][parent]
414         self._data[-1][id] = Entry(id, parent=p, directory=directory)
415
416     def _exists(self, id, revision=None):
417         if revision == None:
418             revision = -1
419         else:
420             revision = int(revision)
421         return id in self._data[revision]
422
423     def _remove(self, id):
424         if self._data[-1][id].directory == True \
425                 and len(self.children(id)) > 0:
426             raise DirectoryNotEmpty(id)
427         e = self._data[-1].pop(id)
428         e.parent.remove(e)
429
430     def _recursive_remove(self, id):
431         for entry in reversed(list(self._data[-1][id].traverse())):
432             self._remove(entry.id)
433
434     def _ancestors(self, id=None, revision=None):
435         if id == None:
436             return []
437         if revision == None:
438             revision = -1
439         else:
440             revision = int(revision)
441         ancestors = []
442         stack = [id]
443         while len(stack) > 0:
444             id = stack.pop(0)
445             parent = self._data[revision][id].parent
446             if parent != None and not parent.id.startswith('__'):
447                 ancestor = parent.id
448                 ancestors.append(ancestor)
449                 stack.append(ancestor)
450         return ancestors
451
452     def _children(self, id=None, revision=None):
453         if id == None:
454             id = '__ROOT__'
455         if revision == None:
456             revision = -1
457         else:
458             revision = int(revision)
459         return [c.id for c in self._data[revision][id]
460                 if not c.id.startswith('__')]
461
462     def _get(self, id, default=InvalidObject, revision=None):
463         if revision == None:
464             revision = -1
465         else:
466             revision = int(revision)
467         if id in self._data[revision] \
468                 and self._data[revision][id].value != _EMPTY:
469             return self._data[revision][id].value
470         elif default == InvalidObject:
471             raise InvalidID(id)
472         return default
473
474     def _set(self, id, value):
475         if id not in self._data[-1]:
476             raise InvalidID(id)
477         self._data[-1][id].value = value
478
479     def commit(self, *args, **kwargs):
480         """
481         Commit the current repository, with a commit message string
482         summary and body.  Return the name of the new revision.
483
484         If allow_empty == False (the default), raise EmptyCommit if
485         there are no changes to commit.
486         """
487         if self.is_writeable() == False:
488             raise NotWriteable('Cannot commit to unwriteable storage.')
489         return self._commit(*args, **kwargs)
490
491     def _commit(self, summary, body=None, allow_empty=False):
492         if self._data[-1] == self._data[-2] and allow_empty == False:
493             raise EmptyCommit
494         self._data[-1]["__COMMIT__SUMMARY__"].value = summary
495         self._data[-1]["__COMMIT__BODY__"].value = body
496         rev = str(len(self._data)-1)
497         self._data.append(copy.deepcopy(self._data[-1]))
498         return rev
499
500     def revision_id(self, index=None):
501         """
502         Return the name of the <index>th revision.  The choice of
503         which branch to follow when crossing branches/merges is not
504         defined.  Revision indices start at 1; ID 0 is the blank
505         repository.
506
507         Return None if index==None.
508
509         If the specified revision does not exist, raise InvalidRevision.
510         """
511         if index == None:
512             return None
513         try:
514             if int(index) != index:
515                 raise InvalidRevision(index)
516         except ValueError:
517             raise InvalidRevision(index)
518         L = len(self._data) - 1  # -1 b/c of initial commit
519         if index >= -L and index <= L:
520             return str(index % L)
521         raise InvalidRevision(i)
522
523     def changed(self, revision):
524         """Return a tuple of lists of ids `(new, modified, removed)` from the
525         specified revision to the current situation.
526         """
527         new = []
528         modified = []
529         removed = []
530         for id,value in self._data[int(revision)].items():
531             if id.startswith('__'):
532                 continue
533             if not id in self._data[-1]:
534                 removed.append(id)
535             elif value.value != self._data[-1][id].value:
536                 modified.append(id)
537         for id in self._data[-1]:
538             if not id in self._data[int(revision)]:
539                 new.append(id)
540         return (new, modified, removed)
541
542
543 if TESTING == True:
544     class StorageTestCase (unittest.TestCase):
545         """Test cases for Storage class."""
546
547         Class = Storage
548
549         def __init__(self, *args, **kwargs):
550             super(StorageTestCase, self).__init__(*args, **kwargs)
551             self.dirname = None
552
553         # this class will be the basis of tests for several classes,
554         # so make sure we print the name of the class we're dealing with.
555         def _classname(self):
556             version = '?'
557             try:
558                 if hasattr(self, 's'):
559                     version = self.s.version()
560             except:
561                 pass
562             return '%s:%s' % (self.Class.__name__, version)
563
564         def fail(self, msg=None):
565             """Fail immediately, with the given message."""
566             raise self.failureException, \
567                 '(%s) %s' % (self._classname(), msg)
568
569         def failIf(self, expr, msg=None):
570             "Fail the test if the expression is true."
571             if expr: raise self.failureException, \
572                 '(%s) %s' % (self._classname(), msg)
573
574         def failUnless(self, expr, msg=None):
575             """Fail the test unless the expression is true."""
576             if not expr: raise self.failureException, \
577                 '(%s) %s' % (self._classname(), msg)
578
579         def setUp(self):
580             """Set up test fixtures for Storage test case."""
581             super(StorageTestCase, self).setUp()
582             self.dir = Dir()
583             self.dirname = self.dir.path
584             self.s = self.Class(repo=self.dirname)
585             self.assert_failed_connect()
586             self.s.init()
587             self.s.connect()
588
589         def tearDown(self):
590             super(StorageTestCase, self).tearDown()
591             self.s.disconnect()
592             self.s.destroy()
593             self.assert_failed_connect()
594             self.dir.cleanup()
595
596         def assert_failed_connect(self):
597             try:
598                 self.s.connect()
599                 self.fail(
600                     "Connected to %(name)s repository before initialising"
601                     % vars(self.Class))
602             except ConnectionError:
603                 pass
604
605     class Storage_init_TestCase (StorageTestCase):
606         """Test cases for Storage.init method."""
607
608         def test_connect_should_succeed_after_init(self):
609             """Should connect after initialization."""
610             self.s.connect()
611
612     class Storage_connect_disconnect_TestCase (StorageTestCase):
613         """Test cases for Storage.connect and .disconnect methods."""
614
615         def test_multiple_disconnects(self):
616             """Should be able to call .disconnect multiple times."""
617             self.s.disconnect()
618             self.s.disconnect()
619
620     class Storage_add_remove_TestCase (StorageTestCase):
621         """Test cases for Storage.add, .remove, and .recursive_remove methods."""
622
623         def test_initially_empty(self):
624             """New repository should be empty."""
625             self.failUnless(len(self.s.children()) == 0, self.s.children())
626
627         def test_add_identical_rooted(self):
628             """Adding entries with the same ID should not increase the number of children.
629             """
630             for i in range(10):
631                 self.s.add('some id', directory=False)
632                 s = sorted(self.s.children())
633                 self.failUnless(s == ['some id'], s)
634
635         def test_add_rooted(self):
636             """Adding entries should increase the number of children (rooted).
637             """
638             ids = []
639             for i in range(10):
640                 ids.append(str(i))
641                 self.s.add(ids[-1], directory=(i % 2 == 0))
642                 s = sorted(self.s.children())
643                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
644
645         def test_add_nonrooted(self):
646             """Adding entries should increase the number of children (nonrooted).
647             """
648             self.s.add('parent', directory=True)
649             ids = []
650             for i in range(10):
651                 ids.append(str(i))
652                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
653                 s = sorted(self.s.children('parent'))
654                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
655                 s = self.s.children()
656                 self.failUnless(s == ['parent'], s)
657
658         def test_ancestors(self):
659             """Check ancestors lists.
660             """
661             self.s.add('parent', directory=True)
662             for i in range(10):
663                 i_id = str(i)
664                 self.s.add(i_id, 'parent', directory=True)
665                 for j in range(10): # add some grandkids
666                     j_id = str(20*(i+1)+j)
667                     self.s.add(j_id, i_id, directory=(i%2 == 0))
668                     ancestors = sorted(self.s.ancestors(j_id))
669                     self.failUnless(ancestors == [i_id, 'parent'],
670                         'Unexpected ancestors for %s/%s, "%s"'
671                         % (i_id, j_id, ancestors))
672
673         def test_children(self):
674             """Non-UUID ids should be returned as such.
675             """
676             self.s.add('parent', directory=True)
677             ids = []
678             for i in range(10):
679                 ids.append('parent/%s' % str(i))
680                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
681                 s = sorted(self.s.children('parent'))
682                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
683
684         def test_add_invalid_directory(self):
685             """Should not be able to add children to non-directories.
686             """
687             self.s.add('parent', directory=False)
688             try:
689                 self.s.add('child', 'parent', directory=False)
690                 self.fail(
691                     '%s.add() succeeded instead of raising InvalidDirectory'
692                     % (vars(self.Class)['name']))
693             except InvalidDirectory:
694                 pass
695             try:
696                 self.s.add('child', 'parent', directory=True)
697                 self.fail(
698                     '%s.add() succeeded instead of raising InvalidDirectory'
699                     % (vars(self.Class)['name']))
700             except InvalidDirectory:
701                 pass
702             self.failUnless(len(self.s.children('parent')) == 0,
703                             self.s.children('parent'))
704
705         def test_remove_rooted(self):
706             """Removing entries should decrease the number of children (rooted).
707             """
708             ids = []
709             for i in range(10):
710                 ids.append(str(i))
711                 self.s.add(ids[-1], directory=(i % 2 == 0))
712             for i in range(10):
713                 self.s.remove(ids.pop())
714                 s = sorted(self.s.children())
715                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
716
717         def test_remove_nonrooted(self):
718             """Removing entries should decrease the number of children (nonrooted).
719             """
720             self.s.add('parent', directory=True)
721             ids = []
722             for i in range(10):
723                 ids.append(str(i))
724                 self.s.add(ids[-1], 'parent', directory=False)#(i % 2 == 0))
725             for i in range(10):
726                 self.s.remove(ids.pop())
727                 s = sorted(self.s.children('parent'))
728                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
729                 if len(s) > 0:
730                     s = self.s.children()
731                     self.failUnless(s == ['parent'], s)
732
733         def test_remove_directory_not_empty(self):
734             """Removing a non-empty directory entry should raise exception.
735             """
736             self.s.add('parent', directory=True)
737             ids = []
738             for i in range(10):
739                 ids.append(str(i))
740                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
741             self.s.remove(ids.pop()) # empty directory removal succeeds
742             try:
743                 self.s.remove('parent') # empty directory removal succeeds
744                 self.fail(
745                     "%s.remove() didn't raise DirectoryNotEmpty"
746                     % (vars(self.Class)['name']))
747             except DirectoryNotEmpty:
748                 pass
749
750         def test_recursive_remove(self):
751             """Recursive remove should empty the tree."""
752             self.s.add('parent', directory=True)
753             ids = []
754             for i in range(10):
755                 ids.append(str(i))
756                 self.s.add(ids[-1], 'parent', directory=True)
757                 for j in range(10): # add some grandkids
758                     self.s.add(str(20*(i+1)+j), ids[-1], directory=(i%2 == 0))
759             self.s.recursive_remove('parent')
760             s = sorted(self.s.children())
761             self.failUnless(s == [], s)
762
763     class Storage_get_set_TestCase (StorageTestCase):
764         """Test cases for Storage.get and .set methods."""
765
766         id = 'unlikely id'
767         val = 'unlikely value'
768
769         def test_get_default(self):
770             """Get should return specified default if id not in Storage.
771             """
772             ret = self.s.get(self.id, default=self.val)
773             self.failUnless(ret == self.val,
774                     "%s.get() returned %s not %s"
775                     % (vars(self.Class)['name'], ret, self.val))
776
777         def test_get_default_exception(self):
778             """Get should raise exception if id not in Storage and no default.
779             """
780             try:
781                 ret = self.s.get(self.id)
782                 self.fail(
783                     "%s.get() returned %s instead of raising InvalidID"
784                     % (vars(self.Class)['name'], ret))
785             except InvalidID:
786                 pass
787
788         def test_get_initial_value(self):
789             """Data value should be default before any value has been set.
790             """
791             self.s.add(self.id, directory=False)
792             val = 'UNLIKELY DEFAULT'
793             ret = self.s.get(self.id, default=val)
794             self.failUnless(ret == val,
795                     "%s.get() returned %s not %s"
796                     % (vars(self.Class)['name'], ret, val))
797
798         def test_set_exception(self):
799             """Set should raise exception if id not in Storage.
800             """
801             try:
802                 self.s.set(self.id, self.val)
803                 self.fail(
804                     "%(name)s.set() did not raise InvalidID"
805                     % vars(self.Class))
806             except InvalidID:
807                 pass
808
809         def test_set(self):
810             """Set should define the value returned by get.
811             """
812             self.s.add(self.id, directory=False)
813             self.s.set(self.id, self.val)
814             ret = self.s.get(self.id)
815             self.failUnless(ret == self.val,
816                     "%s.get() returned %s not %s"
817                     % (vars(self.Class)['name'], ret, self.val))
818
819         def test_unicode_set(self):
820             """Set should define the value returned by get.
821             """
822             val = u'Fran\xe7ois'
823             self.s.add(self.id, directory=False)
824             self.s.set(self.id, val)
825             ret = self.s.get(self.id, decode=True)
826             self.failUnless(type(ret) == types.UnicodeType,
827                     "%s.get() returned %s not UnicodeType"
828                     % (vars(self.Class)['name'], type(ret)))
829             self.failUnless(ret == val,
830                     "%s.get() returned %s not %s"
831                     % (vars(self.Class)['name'], ret, self.val))
832             ret = self.s.get(self.id)
833             self.failUnless(type(ret) == types.StringType,
834                     "%s.get() returned %s not StringType"
835                     % (vars(self.Class)['name'], type(ret)))
836             s = unicode(ret, self.s.encoding)
837             self.failUnless(s == val,
838                     "%s.get() returned %s not %s"
839                     % (vars(self.Class)['name'], s, self.val))
840
841
842     class Storage_persistence_TestCase (StorageTestCase):
843         """Test cases for Storage.disconnect and .connect methods."""
844
845         id = 'unlikely id'
846         val = 'unlikely value'
847
848         def test_get_set_persistence(self):
849             """Set should define the value returned by get after reconnect.
850             """
851             self.s.add(self.id, directory=False)
852             self.s.set(self.id, self.val)
853             self.s.disconnect()
854             self.s.connect()
855             ret = self.s.get(self.id)
856             self.failUnless(ret == self.val,
857                     "%s.get() returned %s not %s"
858                     % (vars(self.Class)['name'], ret, self.val))
859
860         def test_empty_get_set_persistence(self):
861             """After empty set, get may return either an empty string or default.
862             """
863             self.s.add(self.id, directory=False)
864             self.s.set(self.id, '')
865             self.s.disconnect()
866             self.s.connect()
867             default = 'UNLIKELY DEFAULT'
868             ret = self.s.get(self.id, default=default)
869             self.failUnless(ret in ['', default],
870                     "%s.get() returned %s not in %s"
871                     % (vars(self.Class)['name'], ret, ['', default]))
872
873         def test_add_nonrooted_persistence(self):
874             """Adding entries should increase the number of children after reconnect.
875             """
876             self.s.add('parent', directory=True)
877             ids = []
878             for i in range(10):
879                 ids.append(str(i))
880                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
881             self.s.disconnect()
882             self.s.connect()
883             s = sorted(self.s.children('parent'))
884             self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
885             s = self.s.children()
886             self.failUnless(s == ['parent'], s)
887
888     class VersionedStorageTestCase (StorageTestCase):
889         """Test cases for VersionedStorage methods."""
890
891         Class = VersionedStorage
892
893     class VersionedStorage_commit_TestCase (VersionedStorageTestCase):
894         """Test cases for VersionedStorage.commit and revision_ids methods."""
895
896         id = 'unlikely id'
897         val = 'Some value'
898         commit_msg = 'Committing something interesting'
899         commit_body = 'Some\nlonger\ndescription\n'
900
901         def _setup_for_empty_commit(self):
902             """
903             Initialization might add some files to version control, so
904             commit those first, before testing the empty commit
905             functionality.
906             """
907             try:
908                 self.s.commit('Added initialization files')
909             except EmptyCommit:
910                 pass
911                 
912         def test_revision_id_exception(self):
913             """Invalid revision id should raise InvalidRevision.
914             """
915             try:
916                 rev = self.s.revision_id('highly unlikely revision id')
917                 self.fail(
918                     "%s.revision_id() didn't raise InvalidRevision, returned %s."
919                     % (vars(self.Class)['name'], rev))
920             except InvalidRevision:
921                 pass
922
923         def test_empty_commit_raises_exception(self):
924             """Empty commit should raise exception.
925             """
926             self._setup_for_empty_commit()
927             try:
928                 self.s.commit(self.commit_msg, self.commit_body)
929                 self.fail(
930                     "Empty %(name)s.commit() didn't raise EmptyCommit."
931                     % vars(self.Class))
932             except EmptyCommit:
933                 pass
934
935         def test_empty_commit_allowed(self):
936             """Empty commit should _not_ raise exception if allow_empty=True.
937             """
938             self._setup_for_empty_commit()
939             self.s.commit(self.commit_msg, self.commit_body,
940                           allow_empty=True)
941
942         def test_commit_revision_ids(self):
943             """Commit / revision_id should agree on revision ids.
944             """
945             def val(i):
946                 return '%s:%d' % (self.val, i+1)
947             self.s.add(self.id, directory=False)
948             revs = []
949             for i in range(10):
950                 self.s.set(self.id, val(i))
951                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
952                                           self.commit_body))
953             for i in range(10):
954                 rev = self.s.revision_id(i+1)
955                 self.failUnless(rev == revs[i],
956                                 "%s.revision_id(%d) returned %s not %s"
957                                 % (vars(self.Class)['name'], i+1, rev, revs[i]))
958             for i in range(-1, -9, -1):
959                 rev = self.s.revision_id(i)
960                 self.failUnless(rev == revs[i],
961                                 "%s.revision_id(%d) returned %s not %s"
962                                 % (vars(self.Class)['name'], i, rev, revs[i]))
963
964         def test_get_previous_version(self):
965             """Get should be able to return the previous version.
966             """
967             def val(i):
968                 return '%s:%d' % (self.val, i+1)
969             self.s.add(self.id, directory=False)
970             revs = []
971             for i in range(10):
972                 self.s.set(self.id, val(i))
973                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
974                                           self.commit_body))
975             for i in range(10):
976                 ret = self.s.get(self.id, revision=revs[i])
977                 self.failUnless(ret == val(i),
978                                 "%s.get() returned %s not %s for revision %s"
979                                 % (vars(self.Class)['name'], ret, val(i), revs[i]))
980
981         def test_get_previous_children(self):
982             """Children list should be revision dependent.
983             """
984             self.s.add('parent', directory=True)
985             revs = []
986             cur_children = []
987             children = []
988             for i in range(10):
989                 new_child = str(i)
990                 self.s.add(new_child, 'parent')
991                 self.s.set(new_child, self.val)
992                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
993                                           self.commit_body))
994                 cur_children.append(new_child)
995                 children.append(list(cur_children))
996             for i in range(10):
997                 ret = sorted(self.s.children('parent', revision=revs[i]))
998                 self.failUnless(ret == children[i],
999                                 "%s.children() returned %s not %s for revision %s"
1000                                 % (vars(self.Class)['name'], ret,
1001                                    children[i], revs[i]))
1002
1003     class VersionedStorage_changed_TestCase (VersionedStorageTestCase):
1004         """Test cases for VersionedStorage.changed() method."""
1005
1006         def test_changed(self):
1007             """Changed lists should reflect past activity"""
1008             self.s.add('dir', directory=True)
1009             self.s.add('modified', parent='dir')
1010             self.s.set('modified', 'some value to be modified')
1011             self.s.add('moved', parent='dir')
1012             self.s.set('moved', 'this entry will be moved')
1013             self.s.add('removed', parent='dir')
1014             self.s.set('removed', 'this entry will be deleted')
1015             revA = self.s.commit('Initial state')
1016             self.s.add('new', parent='dir')
1017             self.s.set('new', 'this entry is new')
1018             self.s.set('modified', 'a new value')
1019             self.s.remove('moved')
1020             self.s.add('moved2', parent='dir')
1021             self.s.set('moved2', 'this entry will be moved')
1022             self.s.remove('removed')
1023             revB = self.s.commit('Final state')
1024             new,mod,rem = self.s.changed(revA)
1025             self.failUnless(sorted(new) == ['moved2', 'new'],
1026                             'Unexpected new: %s' % new)
1027             self.failUnless(mod == ['modified'],
1028                             'Unexpected modified: %s' % mod)
1029             self.failUnless(sorted(rem) == ['moved', 'removed'],
1030                             'Unexpected removed: %s' % rem)
1031
1032     def make_storage_testcase_subclasses(storage_class, namespace):
1033         """Make StorageTestCase subclasses for storage_class in namespace."""
1034         storage_testcase_classes = [
1035             c for c in (
1036                 ob for ob in globals().values() if isinstance(ob, type))
1037             if issubclass(c, StorageTestCase) \
1038                 and c.Class == Storage]
1039
1040         for base_class in storage_testcase_classes:
1041             testcase_class_name = storage_class.__name__ + base_class.__name__
1042             testcase_class_bases = (base_class,)
1043             testcase_class_dict = dict(base_class.__dict__)
1044             testcase_class_dict['Class'] = storage_class
1045             testcase_class = type(
1046                 testcase_class_name, testcase_class_bases, testcase_class_dict)
1047             setattr(namespace, testcase_class_name, testcase_class)
1048
1049     def make_versioned_storage_testcase_subclasses(storage_class, namespace):
1050         """Make VersionedStorageTestCase subclasses for storage_class in namespace."""
1051         storage_testcase_classes = [
1052             c for c in (
1053                 ob for ob in globals().values() if isinstance(ob, type))
1054             if ((issubclass(c, StorageTestCase) \
1055                      and c.Class == Storage)
1056                 or
1057                 (issubclass(c, VersionedStorageTestCase) \
1058                      and c.Class == VersionedStorage))]
1059
1060         for base_class in storage_testcase_classes:
1061             testcase_class_name = storage_class.__name__ + base_class.__name__
1062             testcase_class_bases = (base_class,)
1063             testcase_class_dict = dict(base_class.__dict__)
1064             testcase_class_dict['Class'] = storage_class
1065             testcase_class = type(
1066                 testcase_class_name, testcase_class_bases, testcase_class_dict)
1067             setattr(namespace, testcase_class_name, testcase_class)
1068
1069     make_storage_testcase_subclasses(VersionedStorage, sys.modules[__name__])
1070
1071     unitsuite =unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
1072     suite = unittest.TestSuite([unitsuite, doctest.DocTestSuite()])