Merge commit 'refs/merge-requests/3' of git://gitorious.org/be/be
[be.git] / libbe / storage / base.py
1 # Copyright (C) 2009-2011 W. Trevor King <wking@drexel.edu>
2 #
3 # This file is part of Bugs Everywhere.
4 #
5 # Bugs Everywhere is free software; you can redistribute it and/or modify it
6 # under the terms of the GNU General Public License as published by the
7 # Free Software Foundation, either version 2 of the License, or (at your
8 # option) any later version.
9 #
10 # Bugs Everywhere is distributed in the hope that it will be useful, but
11 # WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 # General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with Bugs Everywhere.  If not, see <http://www.gnu.org/licenses/>.
17
18 """
19 Abstract bug repository data storage to easily support multiple backends.
20 """
21
22 import copy
23 import os
24 import pickle
25 import types
26
27 from libbe.error import NotSupported
28 import libbe.storage
29 from libbe.util.tree import Tree
30 from libbe.util import InvalidObject
31 import libbe.version
32 from libbe import TESTING
33
34 if TESTING == True:
35     import doctest
36     import os.path
37     import sys
38     import unittest
39
40     from libbe.util.utility import Dir
41
42 class ConnectionError (Exception):
43     pass
44
45 class InvalidStorageVersion(ConnectionError):
46     def __init__(self, active_version, expected_version=None):
47         if expected_version == None:
48             expected_version = libbe.storage.STORAGE_VERSION
49         msg = 'Storage in "%s" not the expected "%s"' \
50             % (active_version, expected_version)
51         Exception.__init__(self, msg)
52         self.active_version = active_version
53         self.expected_version = expected_version
54
55 class InvalidID (KeyError):
56     def __init__(self, id=None, revision=None, msg=None):
57         KeyError.__init__(self, id)
58         self.msg = msg
59         self.id = id
60         self.revision = revision
61     def __str__(self):
62         if self.msg == None:
63             return '%s in revision %s' % (self.id, self.revision)
64         return self.msg
65
66
67 class InvalidRevision (KeyError):
68     pass
69
70 class InvalidDirectory (Exception):
71     pass
72
73 class DirectoryNotEmpty (InvalidDirectory):
74     pass
75
76 class NotWriteable (NotSupported):
77     def __init__(self, msg):
78         NotSupported.__init__(self, 'write', msg)
79
80 class NotReadable (NotSupported):
81     def __init__(self, msg):
82         NotSupported.__init__(self, 'read', msg)
83
84 class EmptyCommit(Exception):
85     def __init__(self):
86         Exception.__init__(self, 'No changes to commit')
87
88 class _EMPTY (object):
89     """Entry has been added but has no user-set value."""
90     pass
91
92 class Entry (Tree):
93     def __init__(self, id, value=_EMPTY, parent=None, directory=False,
94                  children=None):
95         if children == None:
96             Tree.__init__(self)
97         else:
98             Tree.__init__(self, children)
99         self.id = id
100         self.value = value
101         self.parent = parent
102         if self.parent != None:
103             if self.parent.directory == False:
104                 raise InvalidDirectory(
105                     'Non-directory %s cannot have children' % self.parent)
106             parent.append(self)
107         self.directory = directory
108
109     def __str__(self):
110         return '<Entry %s: %s>' % (self.id, self.value)
111
112     def __repr__(self):
113         return str(self)
114
115     def __cmp__(self, other, local=False):
116         if other == None:
117             return cmp(1, None)
118         if cmp(self.id, other.id) != 0:
119             return cmp(self.id, other.id)
120         if cmp(self.value, other.value) != 0:
121             return cmp(self.value, other.value)
122         if local == False:
123             if self.parent == None:
124                 if cmp(self.parent, other.parent) != 0:
125                     return cmp(self.parent, other.parent)
126             elif self.parent.__cmp__(other.parent, local=True) != 0:
127                 return self.parent.__cmp__(other.parent, local=True)
128             for sc,oc in zip(self, other):
129                 if sc.__cmp__(oc, local=True) != 0:
130                     return sc.__cmp__(oc, local=True)
131         return 0
132
133     def _objects_to_ids(self):
134         if self.parent != None:
135             self.parent = self.parent.id
136         for i,c in enumerate(self):
137             self[i] = c.id
138         return self
139
140     def _ids_to_objects(self, dict):
141         if self.parent != None:
142             self.parent = dict[self.parent]
143         for i,c in enumerate(self):
144             self[i] = dict[c]
145         return self
146
147 class Storage (object):
148     """
149     This class declares all the methods required by a Storage
150     interface.  This implementation just keeps the data in a
151     dictionary and uses pickle for persistent storage.
152     """
153     name = 'Storage'
154
155     def __init__(self, repo='/', encoding='utf-8', options=None):
156         self.repo = repo
157         self.encoding = encoding
158         self.options = options
159         self.readable = True  # soft limit (user choice)
160         self._readable = True # hard limit (backend choice)
161         self.writeable = True  # soft limit (user choice)
162         self._writeable = True # hard limit (backend choice)
163         self.versioned = False
164         self.can_init = True
165         self.connected = False
166
167     def __str__(self):
168         return '<%s %s %s>' % (self.__class__.__name__, id(self), self.repo)
169
170     def __repr__(self):
171         return str(self)
172
173     def version(self):
174         """Return a version string for this backend."""
175         return libbe.version.version()
176
177     def storage_version(self, revision=None):
178         """Return the storage format for this backend."""
179         return libbe.storage.STORAGE_VERSION
180
181     def is_readable(self):
182         return self.readable and self._readable
183
184     def is_writeable(self):
185         return self.writeable and self._writeable
186
187     def init(self):
188         """Create a new storage repository."""
189         if self.can_init == False:
190             raise NotSupported('init',
191                                'Cannot initialize this repository format.')
192         if self.is_writeable() == False:
193             raise NotWriteable('Cannot initialize unwriteable storage.')
194         return self._init()
195
196     def _init(self):
197         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
198         root = Entry(id='__ROOT__', directory=True)
199         d = {root.id:root}
200         pickle.dump(dict((k,v._objects_to_ids()) for k,v in d.items()), f, -1)
201         f.close()
202
203     def destroy(self):
204         """Remove the storage repository."""
205         if self.is_writeable() == False:
206             raise NotWriteable('Cannot destroy unwriteable storage.')
207         return self._destroy()
208
209     def _destroy(self):
210         os.remove(os.path.join(self.repo, 'repo.pkl'))
211
212     def connect(self):
213         """Open a connection to the repository."""
214         if self.is_readable() == False:
215             raise NotReadable('Cannot connect to unreadable storage.')
216         self._connect()
217         self.connected = True
218
219     def _connect(self):
220         try:
221             f = open(os.path.join(self.repo, 'repo.pkl'), 'rb')
222         except IOError:
223             raise ConnectionError(self)
224         d = pickle.load(f)
225         self._data = dict((k,v._ids_to_objects(d)) for k,v in d.items())
226         f.close()
227
228     def disconnect(self):
229         """Close the connection to the repository."""
230         if self.is_writeable() == False:
231             return
232         if self.connected == False:
233             return
234         self._disconnect()
235         self.connected = False
236
237     def _disconnect(self):
238         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
239         pickle.dump(dict((k,v._objects_to_ids())
240                          for k,v in self._data.items()), f, -1)
241         f.close()
242         self._data = None
243
244     def add(self, id, *args, **kwargs):
245         """Add an entry"""
246         if self.is_writeable() == False:
247             raise NotWriteable('Cannot add entry to unwriteable storage.')
248         if not self.exists(id):
249             self._add(id, *args, **kwargs)
250
251     def _add(self, id, parent=None, directory=False):
252         if parent == None:
253             parent = '__ROOT__'
254         p = self._data[parent]
255         self._data[id] = Entry(id, parent=p, directory=directory)
256
257     def exists(self, *args, **kwargs):
258         """Check an entry's existence"""
259         if self.is_readable() == False:
260             raise NotReadable('Cannot check entry existence in unreadable storage.')
261         return self._exists(*args, **kwargs)
262
263     def _exists(self, id, revision=None):
264         return id in self._data
265
266     def remove(self, *args, **kwargs):
267         """Remove an entry."""
268         if self.is_writeable() == False:
269             raise NotSupported('write',
270                                'Cannot remove entry from unwriteable storage.')
271         self._remove(*args, **kwargs)
272
273     def _remove(self, id):
274         if self._data[id].directory == True \
275                 and len(self.children(id)) > 0:
276             raise DirectoryNotEmpty(id)
277         e = self._data.pop(id)
278         e.parent.remove(e)
279
280     def recursive_remove(self, *args, **kwargs):
281         """Remove an entry and all its decendents."""
282         if self.is_writeable() == False:
283             raise NotSupported('write',
284                                'Cannot remove entries from unwriteable storage.')
285         self._recursive_remove(*args, **kwargs)
286
287     def _recursive_remove(self, id):
288         for entry in reversed(list(self._data[id].traverse())):
289             self._remove(entry.id)
290
291     def ancestors(self, *args, **kwargs):
292         """Return a list of the specified entry's ancestors' ids."""
293         if self.is_readable() == False:
294             raise NotReadable('Cannot list parents with unreadable storage.')
295         return self._ancestors(*args, **kwargs)
296
297     def _ancestors(self, id=None, revision=None):
298         if id == None:
299             return []
300         ancestors = []
301         stack = [id]
302         while len(stack) > 0:
303             id = stack.pop(0)
304             parent = self._data[id].parent
305             if parent != None and not parent.id.startswith('__'):
306                 ancestor = parent.id
307                 ancestors.append(ancestor)
308                 stack.append(ancestor)
309         return ancestors
310
311     def children(self, *args, **kwargs):
312         """Return a list of specified entry's children's ids."""
313         if self.is_readable() == False:
314             raise NotReadable('Cannot list children with unreadable storage.')
315         return self._children(*args, **kwargs)
316
317     def _children(self, id=None, revision=None):
318         if id == None:
319             id = '__ROOT__'
320         return [c.id for c in self._data[id] if not c.id.startswith('__')]
321
322     def get(self, *args, **kwargs):
323         """
324         Get contents of and entry as they were in a given revision.
325         revision==None specifies the current revision.
326
327         If there is no id, return default, unless default is not
328         given, in which case raise InvalidID.
329         """
330         if self.is_readable() == False:
331             raise NotReadable('Cannot get entry with unreadable storage.')
332         if 'decode' in kwargs:
333             decode = kwargs.pop('decode')
334         else:
335             decode = False
336         value = self._get(*args, **kwargs)
337         if value != None:
338             if decode == True and type(value) != types.UnicodeType:
339                 return unicode(value, self.encoding)
340             elif decode == False and type(value) != types.StringType:
341                 return value.encode(self.encoding)
342         return value
343
344     def _get(self, id, default=InvalidObject, revision=None):
345         if id in self._data and self._data[id].value != _EMPTY:
346             return self._data[id].value
347         elif default == InvalidObject:
348             raise InvalidID(id)
349         return default
350
351     def set(self, id, value, *args, **kwargs):
352         """
353         Set the entry contents.
354         """
355         if self.is_writeable() == False:
356             raise NotWriteable('Cannot set entry in unwriteable storage.')
357         if type(value) == types.UnicodeType:
358             value = value.encode(self.encoding)
359         self._set(id, value, *args, **kwargs)
360
361     def _set(self, id, value):
362         if id not in self._data:
363             raise InvalidID(id)
364         if self._data[id].directory == True:
365             raise InvalidDirectory(
366                 'Directory %s cannot have data' % self.parent)
367         self._data[id].value = value
368
369 class VersionedStorage (Storage):
370     """
371     This class declares all the methods required by a Storage
372     interface that supports versioning.  This implementation just
373     keeps the data in a list and uses pickle for persistent
374     storage.
375     """
376     name = 'VersionedStorage'
377
378     def __init__(self, *args, **kwargs):
379         Storage.__init__(self, *args, **kwargs)
380         self.versioned = True
381
382     def _init(self):
383         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
384         root = Entry(id='__ROOT__', directory=True)
385         summary = Entry(id='__COMMIT__SUMMARY__', value='Initial commit')
386         body = Entry(id='__COMMIT__BODY__')
387         initial_commit = {root.id:root, summary.id:summary, body.id:body}
388         d = dict((k,v._objects_to_ids()) for k,v in initial_commit.items())
389         pickle.dump([d, copy.deepcopy(d)], f, -1) # [inital tree, working tree]
390         f.close()
391
392     def _connect(self):
393         try:
394             f = open(os.path.join(self.repo, 'repo.pkl'), 'rb')
395         except IOError:
396             raise ConnectionError(self)
397         d = pickle.load(f)
398         self._data = [dict((k,v._ids_to_objects(t)) for k,v in t.items())
399                       for t in d]
400         f.close()
401
402     def _disconnect(self):
403         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
404         pickle.dump([dict((k,v._objects_to_ids())
405                           for k,v in t.items()) for t in self._data], f, -1)
406         f.close()
407         self._data = None
408
409     def _add(self, id, parent=None, directory=False):
410         if parent == None:
411             parent = '__ROOT__'
412         p = self._data[-1][parent]
413         self._data[-1][id] = Entry(id, parent=p, directory=directory)
414
415     def _exists(self, id, revision=None):
416         if revision == None:
417             revision = -1
418         else:
419             revision = int(revision)
420         return id in self._data[revision]
421
422     def _remove(self, id):
423         if self._data[-1][id].directory == True \
424                 and len(self.children(id)) > 0:
425             raise DirectoryNotEmpty(id)
426         e = self._data[-1].pop(id)
427         e.parent.remove(e)
428
429     def _recursive_remove(self, id):
430         for entry in reversed(list(self._data[-1][id].traverse())):
431             self._remove(entry.id)
432
433     def _ancestors(self, id=None, revision=None):
434         if id == None:
435             return []
436         if revision == None:
437             revision = -1
438         else:
439             revision = int(revision)
440         ancestors = []
441         stack = [id]
442         while len(stack) > 0:
443             id = stack.pop(0)
444             parent = self._data[revision][id].parent
445             if parent != None and not parent.id.startswith('__'):
446                 ancestor = parent.id
447                 ancestors.append(ancestor)
448                 stack.append(ancestor)
449         return ancestors
450
451     def _children(self, id=None, revision=None):
452         if id == None:
453             id = '__ROOT__'
454         if revision == None:
455             revision = -1
456         else:
457             revision = int(revision)
458         return [c.id for c in self._data[revision][id]
459                 if not c.id.startswith('__')]
460
461     def _get(self, id, default=InvalidObject, revision=None):
462         if revision == None:
463             revision = -1
464         else:
465             revision = int(revision)
466         if id in self._data[revision] \
467                 and self._data[revision][id].value != _EMPTY:
468             return self._data[revision][id].value
469         elif default == InvalidObject:
470             raise InvalidID(id)
471         return default
472
473     def _set(self, id, value):
474         if id not in self._data[-1]:
475             raise InvalidID(id)
476         self._data[-1][id].value = value
477
478     def commit(self, *args, **kwargs):
479         """
480         Commit the current repository, with a commit message string
481         summary and body.  Return the name of the new revision.
482
483         If allow_empty == False (the default), raise EmptyCommit if
484         there are no changes to commit.
485         """
486         if self.is_writeable() == False:
487             raise NotWriteable('Cannot commit to unwriteable storage.')
488         return self._commit(*args, **kwargs)
489
490     def _commit(self, summary, body=None, allow_empty=False):
491         if self._data[-1] == self._data[-2] and allow_empty == False:
492             raise EmptyCommit
493         self._data[-1]["__COMMIT__SUMMARY__"].value = summary
494         self._data[-1]["__COMMIT__BODY__"].value = body
495         rev = str(len(self._data)-1)
496         self._data.append(copy.deepcopy(self._data[-1]))
497         return rev
498
499     def revision_id(self, index=None):
500         """
501         Return the name of the <index>th revision.  The choice of
502         which branch to follow when crossing branches/merges is not
503         defined.  Revision indices start at 1; ID 0 is the blank
504         repository.
505
506         Return None if index==None.
507
508         If the specified revision does not exist, raise InvalidRevision.
509         """
510         if index == None:
511             return None
512         try:
513             if int(index) != index:
514                 raise InvalidRevision(index)
515         except ValueError:
516             raise InvalidRevision(index)
517         L = len(self._data) - 1  # -1 b/c of initial commit
518         if index >= -L and index <= L:
519             return str(index % L)
520         raise InvalidRevision(i)
521
522     def changed(self, revision):
523         """Return a tuple of lists of ids `(new, modified, removed)` from the
524         specified revision to the current situation.
525         """
526         new = []
527         modified = []
528         removed = []
529         for id,value in self._data[int(revision)].items():
530             if id.startswith('__'):
531                 continue
532             if not id in self._data[-1]:
533                 removed.append(id)
534             elif value.value != self._data[-1][id].value:
535                 modified.append(id)
536         for id in self._data[-1]:
537             if not id in self._data[int(revision)]:
538                 new.append(id)
539         return (new, modified, removed)
540
541
542 if TESTING == True:
543     class StorageTestCase (unittest.TestCase):
544         """Test cases for Storage class."""
545
546         Class = Storage
547
548         def __init__(self, *args, **kwargs):
549             super(StorageTestCase, self).__init__(*args, **kwargs)
550             self.dirname = None
551
552         # this class will be the basis of tests for several classes,
553         # so make sure we print the name of the class we're dealing with.
554         def _classname(self):
555             version = '?'
556             try:
557                 if hasattr(self, 's'):
558                     version = self.s.version()
559             except:
560                 pass
561             return '%s:%s' % (self.Class.__name__, version)
562
563         def fail(self, msg=None):
564             """Fail immediately, with the given message."""
565             raise self.failureException, \
566                 '(%s) %s' % (self._classname(), msg)
567
568         def failIf(self, expr, msg=None):
569             "Fail the test if the expression is true."
570             if expr: raise self.failureException, \
571                 '(%s) %s' % (self._classname(), msg)
572
573         def failUnless(self, expr, msg=None):
574             """Fail the test unless the expression is true."""
575             if not expr: raise self.failureException, \
576                 '(%s) %s' % (self._classname(), msg)
577
578         def setUp(self):
579             """Set up test fixtures for Storage test case."""
580             super(StorageTestCase, self).setUp()
581             self.dir = Dir()
582             self.dirname = self.dir.path
583             self.s = self.Class(repo=self.dirname)
584             self.assert_failed_connect()
585             self.s.init()
586             self.s.connect()
587
588         def tearDown(self):
589             super(StorageTestCase, self).tearDown()
590             self.s.disconnect()
591             self.s.destroy()
592             self.assert_failed_connect()
593             self.dir.cleanup()
594
595         def assert_failed_connect(self):
596             try:
597                 self.s.connect()
598                 self.fail(
599                     "Connected to %(name)s repository before initialising"
600                     % vars(self.Class))
601             except ConnectionError:
602                 pass
603
604     class Storage_init_TestCase (StorageTestCase):
605         """Test cases for Storage.init method."""
606
607         def test_connect_should_succeed_after_init(self):
608             """Should connect after initialization."""
609             self.s.connect()
610
611     class Storage_connect_disconnect_TestCase (StorageTestCase):
612         """Test cases for Storage.connect and .disconnect methods."""
613
614         def test_multiple_disconnects(self):
615             """Should be able to call .disconnect multiple times."""
616             self.s.disconnect()
617             self.s.disconnect()
618
619     class Storage_add_remove_TestCase (StorageTestCase):
620         """Test cases for Storage.add, .remove, and .recursive_remove methods."""
621
622         def test_initially_empty(self):
623             """New repository should be empty."""
624             self.failUnless(len(self.s.children()) == 0, self.s.children())
625
626         def test_add_identical_rooted(self):
627             """Adding entries with the same ID should not increase the number of children.
628             """
629             for i in range(10):
630                 self.s.add('some id', directory=False)
631                 s = sorted(self.s.children())
632                 self.failUnless(s == ['some id'], s)
633
634         def test_add_rooted(self):
635             """Adding entries should increase the number of children (rooted).
636             """
637             ids = []
638             for i in range(10):
639                 ids.append(str(i))
640                 self.s.add(ids[-1], directory=(i % 2 == 0))
641                 s = sorted(self.s.children())
642                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
643
644         def test_add_nonrooted(self):
645             """Adding entries should increase the number of children (nonrooted).
646             """
647             self.s.add('parent', directory=True)
648             ids = []
649             for i in range(10):
650                 ids.append(str(i))
651                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
652                 s = sorted(self.s.children('parent'))
653                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
654                 s = self.s.children()
655                 self.failUnless(s == ['parent'], s)
656
657         def test_ancestors(self):
658             """Check ancestors lists.
659             """
660             self.s.add('parent', directory=True)
661             for i in range(10):
662                 i_id = str(i)
663                 self.s.add(i_id, 'parent', directory=True)
664                 for j in range(10): # add some grandkids
665                     j_id = str(20*(i+1)+j)
666                     self.s.add(j_id, i_id, directory=(i%2 == 0))
667                     ancestors = sorted(self.s.ancestors(j_id))
668                     self.failUnless(ancestors == [i_id, 'parent'],
669                         'Unexpected ancestors for %s/%s, "%s"'
670                         % (i_id, j_id, ancestors))
671
672         def test_children(self):
673             """Non-UUID ids should be returned as such.
674             """
675             self.s.add('parent', directory=True)
676             ids = []
677             for i in range(10):
678                 ids.append('parent/%s' % str(i))
679                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
680                 s = sorted(self.s.children('parent'))
681                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
682
683         def test_add_invalid_directory(self):
684             """Should not be able to add children to non-directories.
685             """
686             self.s.add('parent', directory=False)
687             try:
688                 self.s.add('child', 'parent', directory=False)
689                 self.fail(
690                     '%s.add() succeeded instead of raising InvalidDirectory'
691                     % (vars(self.Class)['name']))
692             except InvalidDirectory:
693                 pass
694             try:
695                 self.s.add('child', 'parent', directory=True)
696                 self.fail(
697                     '%s.add() succeeded instead of raising InvalidDirectory'
698                     % (vars(self.Class)['name']))
699             except InvalidDirectory:
700                 pass
701             self.failUnless(len(self.s.children('parent')) == 0,
702                             self.s.children('parent'))
703
704         def test_remove_rooted(self):
705             """Removing entries should decrease the number of children (rooted).
706             """
707             ids = []
708             for i in range(10):
709                 ids.append(str(i))
710                 self.s.add(ids[-1], directory=(i % 2 == 0))
711             for i in range(10):
712                 self.s.remove(ids.pop())
713                 s = sorted(self.s.children())
714                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
715
716         def test_remove_nonrooted(self):
717             """Removing entries should decrease the number of children (nonrooted).
718             """
719             self.s.add('parent', directory=True)
720             ids = []
721             for i in range(10):
722                 ids.append(str(i))
723                 self.s.add(ids[-1], 'parent', directory=False)#(i % 2 == 0))
724             for i in range(10):
725                 self.s.remove(ids.pop())
726                 s = sorted(self.s.children('parent'))
727                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
728                 if len(s) > 0:
729                     s = self.s.children()
730                     self.failUnless(s == ['parent'], s)
731
732         def test_remove_directory_not_empty(self):
733             """Removing a non-empty directory entry should raise exception.
734             """
735             self.s.add('parent', directory=True)
736             ids = []
737             for i in range(10):
738                 ids.append(str(i))
739                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
740             self.s.remove(ids.pop()) # empty directory removal succeeds
741             try:
742                 self.s.remove('parent') # empty directory removal succeeds
743                 self.fail(
744                     "%s.remove() didn't raise DirectoryNotEmpty"
745                     % (vars(self.Class)['name']))
746             except DirectoryNotEmpty:
747                 pass
748
749         def test_recursive_remove(self):
750             """Recursive remove should empty the tree."""
751             self.s.add('parent', directory=True)
752             ids = []
753             for i in range(10):
754                 ids.append(str(i))
755                 self.s.add(ids[-1], 'parent', directory=True)
756                 for j in range(10): # add some grandkids
757                     self.s.add(str(20*(i+1)+j), ids[-1], directory=(i%2 == 0))
758             self.s.recursive_remove('parent')
759             s = sorted(self.s.children())
760             self.failUnless(s == [], s)
761
762     class Storage_get_set_TestCase (StorageTestCase):
763         """Test cases for Storage.get and .set methods."""
764
765         id = 'unlikely id'
766         val = 'unlikely value'
767
768         def test_get_default(self):
769             """Get should return specified default if id not in Storage.
770             """
771             ret = self.s.get(self.id, default=self.val)
772             self.failUnless(ret == self.val,
773                     "%s.get() returned %s not %s"
774                     % (vars(self.Class)['name'], ret, self.val))
775
776         def test_get_default_exception(self):
777             """Get should raise exception if id not in Storage and no default.
778             """
779             try:
780                 ret = self.s.get(self.id)
781                 self.fail(
782                     "%s.get() returned %s instead of raising InvalidID"
783                     % (vars(self.Class)['name'], ret))
784             except InvalidID:
785                 pass
786
787         def test_get_initial_value(self):
788             """Data value should be default before any value has been set.
789             """
790             self.s.add(self.id, directory=False)
791             val = 'UNLIKELY DEFAULT'
792             ret = self.s.get(self.id, default=val)
793             self.failUnless(ret == val,
794                     "%s.get() returned %s not %s"
795                     % (vars(self.Class)['name'], ret, val))
796
797         def test_set_exception(self):
798             """Set should raise exception if id not in Storage.
799             """
800             try:
801                 self.s.set(self.id, self.val)
802                 self.fail(
803                     "%(name)s.set() did not raise InvalidID"
804                     % vars(self.Class))
805             except InvalidID:
806                 pass
807
808         def test_set(self):
809             """Set should define the value returned by get.
810             """
811             self.s.add(self.id, directory=False)
812             self.s.set(self.id, self.val)
813             ret = self.s.get(self.id)
814             self.failUnless(ret == self.val,
815                     "%s.get() returned %s not %s"
816                     % (vars(self.Class)['name'], ret, self.val))
817
818         def test_unicode_set(self):
819             """Set should define the value returned by get.
820             """
821             val = u'Fran\xe7ois'
822             self.s.add(self.id, directory=False)
823             self.s.set(self.id, val)
824             ret = self.s.get(self.id, decode=True)
825             self.failUnless(type(ret) == types.UnicodeType,
826                     "%s.get() returned %s not UnicodeType"
827                     % (vars(self.Class)['name'], type(ret)))
828             self.failUnless(ret == val,
829                     "%s.get() returned %s not %s"
830                     % (vars(self.Class)['name'], ret, self.val))
831             ret = self.s.get(self.id)
832             self.failUnless(type(ret) == types.StringType,
833                     "%s.get() returned %s not StringType"
834                     % (vars(self.Class)['name'], type(ret)))
835             s = unicode(ret, self.s.encoding)
836             self.failUnless(s == val,
837                     "%s.get() returned %s not %s"
838                     % (vars(self.Class)['name'], s, self.val))
839
840
841     class Storage_persistence_TestCase (StorageTestCase):
842         """Test cases for Storage.disconnect and .connect methods."""
843
844         id = 'unlikely id'
845         val = 'unlikely value'
846
847         def test_get_set_persistence(self):
848             """Set should define the value returned by get after reconnect.
849             """
850             self.s.add(self.id, directory=False)
851             self.s.set(self.id, self.val)
852             self.s.disconnect()
853             self.s.connect()
854             ret = self.s.get(self.id)
855             self.failUnless(ret == self.val,
856                     "%s.get() returned %s not %s"
857                     % (vars(self.Class)['name'], ret, self.val))
858
859         def test_empty_get_set_persistence(self):
860             """After empty set, get may return either an empty string or default.
861             """
862             self.s.add(self.id, directory=False)
863             self.s.set(self.id, '')
864             self.s.disconnect()
865             self.s.connect()
866             default = 'UNLIKELY DEFAULT'
867             ret = self.s.get(self.id, default=default)
868             self.failUnless(ret in ['', default],
869                     "%s.get() returned %s not in %s"
870                     % (vars(self.Class)['name'], ret, ['', default]))
871
872         def test_add_nonrooted_persistence(self):
873             """Adding entries should increase the number of children after reconnect.
874             """
875             self.s.add('parent', directory=True)
876             ids = []
877             for i in range(10):
878                 ids.append(str(i))
879                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
880             self.s.disconnect()
881             self.s.connect()
882             s = sorted(self.s.children('parent'))
883             self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
884             s = self.s.children()
885             self.failUnless(s == ['parent'], s)
886
887     class VersionedStorageTestCase (StorageTestCase):
888         """Test cases for VersionedStorage methods."""
889
890         Class = VersionedStorage
891
892     class VersionedStorage_commit_TestCase (VersionedStorageTestCase):
893         """Test cases for VersionedStorage.commit and revision_ids methods."""
894
895         id = 'unlikely id'
896         val = 'Some value'
897         commit_msg = 'Committing something interesting'
898         commit_body = 'Some\nlonger\ndescription\n'
899
900         def _setup_for_empty_commit(self):
901             """
902             Initialization might add some files to version control, so
903             commit those first, before testing the empty commit
904             functionality.
905             """
906             try:
907                 self.s.commit('Added initialization files')
908             except EmptyCommit:
909                 pass
910                 
911         def test_revision_id_exception(self):
912             """Invalid revision id should raise InvalidRevision.
913             """
914             try:
915                 rev = self.s.revision_id('highly unlikely revision id')
916                 self.fail(
917                     "%s.revision_id() didn't raise InvalidRevision, returned %s."
918                     % (vars(self.Class)['name'], rev))
919             except InvalidRevision:
920                 pass
921
922         def test_empty_commit_raises_exception(self):
923             """Empty commit should raise exception.
924             """
925             self._setup_for_empty_commit()
926             try:
927                 self.s.commit(self.commit_msg, self.commit_body)
928                 self.fail(
929                     "Empty %(name)s.commit() didn't raise EmptyCommit."
930                     % vars(self.Class))
931             except EmptyCommit:
932                 pass
933
934         def test_empty_commit_allowed(self):
935             """Empty commit should _not_ raise exception if allow_empty=True.
936             """
937             self._setup_for_empty_commit()
938             self.s.commit(self.commit_msg, self.commit_body,
939                           allow_empty=True)
940
941         def test_commit_revision_ids(self):
942             """Commit / revision_id should agree on revision ids.
943             """
944             def val(i):
945                 return '%s:%d' % (self.val, i+1)
946             self.s.add(self.id, directory=False)
947             revs = []
948             for i in range(10):
949                 self.s.set(self.id, val(i))
950                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
951                                           self.commit_body))
952             for i in range(10):
953                 rev = self.s.revision_id(i+1)
954                 self.failUnless(rev == revs[i],
955                                 "%s.revision_id(%d) returned %s not %s"
956                                 % (vars(self.Class)['name'], i+1, rev, revs[i]))
957             for i in range(-1, -9, -1):
958                 rev = self.s.revision_id(i)
959                 self.failUnless(rev == revs[i],
960                                 "%s.revision_id(%d) returned %s not %s"
961                                 % (vars(self.Class)['name'], i, rev, revs[i]))
962
963         def test_get_previous_version(self):
964             """Get should be able to return the previous version.
965             """
966             def val(i):
967                 return '%s:%d' % (self.val, i+1)
968             self.s.add(self.id, directory=False)
969             revs = []
970             for i in range(10):
971                 self.s.set(self.id, val(i))
972                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
973                                           self.commit_body))
974             for i in range(10):
975                 ret = self.s.get(self.id, revision=revs[i])
976                 self.failUnless(ret == val(i),
977                                 "%s.get() returned %s not %s for revision %s"
978                                 % (vars(self.Class)['name'], ret, val(i), revs[i]))
979
980         def test_get_previous_children(self):
981             """Children list should be revision dependent.
982             """
983             self.s.add('parent', directory=True)
984             revs = []
985             cur_children = []
986             children = []
987             for i in range(10):
988                 new_child = str(i)
989                 self.s.add(new_child, 'parent')
990                 self.s.set(new_child, self.val)
991                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
992                                           self.commit_body))
993                 cur_children.append(new_child)
994                 children.append(list(cur_children))
995             for i in range(10):
996                 ret = sorted(self.s.children('parent', revision=revs[i]))
997                 self.failUnless(ret == children[i],
998                                 "%s.children() returned %s not %s for revision %s"
999                                 % (vars(self.Class)['name'], ret,
1000                                    children[i], revs[i]))
1001
1002     class VersionedStorage_changed_TestCase (VersionedStorageTestCase):
1003         """Test cases for VersionedStorage.changed() method."""
1004
1005         def test_changed(self):
1006             """Changed lists should reflect past activity"""
1007             self.s.add('dir', directory=True)
1008             self.s.add('modified', parent='dir')
1009             self.s.set('modified', 'some value to be modified')
1010             self.s.add('moved', parent='dir')
1011             self.s.set('moved', 'this entry will be moved')
1012             self.s.add('removed', parent='dir')
1013             self.s.set('removed', 'this entry will be deleted')
1014             revA = self.s.commit('Initial state')
1015             self.s.add('new', parent='dir')
1016             self.s.set('new', 'this entry is new')
1017             self.s.set('modified', 'a new value')
1018             self.s.remove('moved')
1019             self.s.add('moved2', parent='dir')
1020             self.s.set('moved2', 'this entry will be moved')
1021             self.s.remove('removed')
1022             revB = self.s.commit('Final state')
1023             new,mod,rem = self.s.changed(revA)
1024             self.failUnless(sorted(new) == ['moved2', 'new'],
1025                             'Unexpected new: %s' % new)
1026             self.failUnless(mod == ['modified'],
1027                             'Unexpected modified: %s' % mod)
1028             self.failUnless(sorted(rem) == ['moved', 'removed'],
1029                             'Unexpected removed: %s' % rem)
1030
1031     def make_storage_testcase_subclasses(storage_class, namespace):
1032         """Make StorageTestCase subclasses for storage_class in namespace."""
1033         storage_testcase_classes = [
1034             c for c in (
1035                 ob for ob in globals().values() if isinstance(ob, type))
1036             if issubclass(c, StorageTestCase) \
1037                 and c.Class == Storage]
1038
1039         for base_class in storage_testcase_classes:
1040             testcase_class_name = storage_class.__name__ + base_class.__name__
1041             testcase_class_bases = (base_class,)
1042             testcase_class_dict = dict(base_class.__dict__)
1043             testcase_class_dict['Class'] = storage_class
1044             testcase_class = type(
1045                 testcase_class_name, testcase_class_bases, testcase_class_dict)
1046             setattr(namespace, testcase_class_name, testcase_class)
1047
1048     def make_versioned_storage_testcase_subclasses(storage_class, namespace):
1049         """Make VersionedStorageTestCase subclasses for storage_class in namespace."""
1050         storage_testcase_classes = [
1051             c for c in (
1052                 ob for ob in globals().values() if isinstance(ob, type))
1053             if ((issubclass(c, StorageTestCase) \
1054                      and c.Class == Storage)
1055                 or
1056                 (issubclass(c, VersionedStorageTestCase) \
1057                      and c.Class == VersionedStorage))]
1058
1059         for base_class in storage_testcase_classes:
1060             testcase_class_name = storage_class.__name__ + base_class.__name__
1061             testcase_class_bases = (base_class,)
1062             testcase_class_dict = dict(base_class.__dict__)
1063             testcase_class_dict['Class'] = storage_class
1064             testcase_class = type(
1065                 testcase_class_name, testcase_class_bases, testcase_class_dict)
1066             setattr(namespace, testcase_class_name, testcase_class)
1067
1068     make_storage_testcase_subclasses(VersionedStorage, sys.modules[__name__])
1069
1070     unitsuite =unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
1071     suite = unittest.TestSuite([unitsuite, doctest.DocTestSuite()])