Merged be.faster-diff branch, fixing #bea/ed5#.
[be.git] / libbe / storage / base.py
1 # Copyright (C) 2009-2010 W. Trevor King <wking@drexel.edu>
2 #
3 # This program is free software; you can redistribute it and/or modify
4 # it under the terms of the GNU General Public License as published by
5 # the Free Software Foundation; either version 2 of the License, or
6 # (at your option) any later version.
7 #
8 # This program is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 # GNU General Public License for more details.
12 #
13 # You should have received a copy of the GNU General Public License along
14 # with this program; if not, write to the Free Software Foundation, Inc.,
15 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
16
17 """
18 Abstract bug repository data storage to easily support multiple backends.
19 """
20
21 import copy
22 import os
23 import pickle
24 import types
25
26 from libbe.error import NotSupported
27 import libbe.storage
28 from libbe.util.tree import Tree
29 from libbe.util import InvalidObject
30 import libbe.version
31 from libbe import TESTING
32
33 if TESTING == True:
34     import doctest
35     import os.path
36     import sys
37     import unittest
38
39     from libbe.util.utility import Dir
40
41 class ConnectionError (Exception):
42     pass
43
44 class InvalidStorageVersion(ConnectionError):
45     def __init__(self, active_version, expected_version=None):
46         if expected_version == None:
47             expected_version = libbe.storage.STORAGE_VERSION
48         msg = 'Storage in "%s" not the expected "%s"' \
49             % (active_version, expected_version)
50         Exception.__init__(self, msg)
51         self.active_version = active_version
52         self.expected_version = expected_version
53
54 class InvalidID (KeyError):
55     def __init__(self, id=None, revision=None, msg=None):
56         KeyError.__init__(self, id)
57         self.msg = msg
58         self.id = id
59         self.revision = revision
60     def __str__(self):
61         if self.msg == None:
62             return '%s in revision %s' % (self.id, self.revision)
63         return self.msg
64
65
66 class InvalidRevision (KeyError):
67     pass
68
69 class InvalidDirectory (Exception):
70     pass
71
72 class DirectoryNotEmpty (InvalidDirectory):
73     pass
74
75 class NotWriteable (NotSupported):
76     def __init__(self, msg):
77         NotSupported.__init__(self, 'write', msg)
78
79 class NotReadable (NotSupported):
80     def __init__(self, msg):
81         NotSupported.__init__(self, 'read', msg)
82
83 class EmptyCommit(Exception):
84     def __init__(self):
85         Exception.__init__(self, 'No changes to commit')
86
87
88 class Entry (Tree):
89     def __init__(self, id, value=None, parent=None, directory=False,
90                  children=None):
91         if children == None:
92             Tree.__init__(self)
93         else:
94             Tree.__init__(self, children)
95         self.id = id
96         self.value = value
97         self.parent = parent
98         if self.parent != None:
99             if self.parent.directory == False:
100                 raise InvalidDirectory(
101                     'Non-directory %s cannot have children' % self.parent)
102             parent.append(self)
103         self.directory = directory
104
105     def __str__(self):
106         return '<Entry %s: %s>' % (self.id, self.value)
107
108     def __repr__(self):
109         return str(self)
110
111     def __cmp__(self, other, local=False):
112         if other == None:
113             return cmp(1, None)
114         if cmp(self.id, other.id) != 0:
115             return cmp(self.id, other.id)
116         if cmp(self.value, other.value) != 0:
117             return cmp(self.value, other.value)
118         if local == False:
119             if self.parent == None:
120                 if cmp(self.parent, other.parent) != 0:
121                     return cmp(self.parent, other.parent)
122             elif self.parent.__cmp__(other.parent, local=True) != 0:
123                 return self.parent.__cmp__(other.parent, local=True)
124             for sc,oc in zip(self, other):
125                 if sc.__cmp__(oc, local=True) != 0:
126                     return sc.__cmp__(oc, local=True)
127         return 0
128
129     def _objects_to_ids(self):
130         if self.parent != None:
131             self.parent = self.parent.id
132         for i,c in enumerate(self):
133             self[i] = c.id
134         return self
135
136     def _ids_to_objects(self, dict):
137         if self.parent != None:
138             self.parent = dict[self.parent]
139         for i,c in enumerate(self):
140             self[i] = dict[c]
141         return self
142
143 class Storage (object):
144     """
145     This class declares all the methods required by a Storage
146     interface.  This implementation just keeps the data in a
147     dictionary and uses pickle for persistent storage.
148     """
149     name = 'Storage'
150
151     def __init__(self, repo='/', encoding='utf-8', options=None):
152         self.repo = repo
153         self.encoding = encoding
154         self.options = options
155         self.readable = True  # soft limit (user choice)
156         self._readable = True # hard limit (backend choice)
157         self.writeable = True  # soft limit (user choice)
158         self._writeable = True # hard limit (backend choice)
159         self.versioned = False
160         self.can_init = True
161         self.connected = False
162
163     def __str__(self):
164         return '<%s %s %s>' % (self.__class__.__name__, id(self), self.repo)
165
166     def __repr__(self):
167         return str(self)
168
169     def version(self):
170         """Return a version string for this backend."""
171         return libbe.version.version()
172
173     def storage_version(self, revision=None):
174         """Return the storage format for this backend."""
175         return libbe.storage.STORAGE_VERSION
176
177     def is_readable(self):
178         return self.readable and self._readable
179
180     def is_writeable(self):
181         return self.writeable and self._writeable
182
183     def init(self):
184         """Create a new storage repository."""
185         if self.can_init == False:
186             raise NotSupported('init',
187                                'Cannot initialize this repository format.')
188         if self.is_writeable() == False:
189             raise NotWriteable('Cannot initialize unwriteable storage.')
190         return self._init()
191
192     def _init(self):
193         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
194         root = Entry(id='__ROOT__', directory=True)
195         d = {root.id:root}
196         pickle.dump(dict((k,v._objects_to_ids()) for k,v in d.items()), f, -1)
197         f.close()
198
199     def destroy(self):
200         """Remove the storage repository."""
201         if self.is_writeable() == False:
202             raise NotWriteable('Cannot destroy unwriteable storage.')
203         return self._destroy()
204
205     def _destroy(self):
206         os.remove(os.path.join(self.repo, 'repo.pkl'))
207
208     def connect(self):
209         """Open a connection to the repository."""
210         if self.is_readable() == False:
211             raise NotReadable('Cannot connect to unreadable storage.')
212         self._connect()
213         self.connected = True
214
215     def _connect(self):
216         try:
217             f = open(os.path.join(self.repo, 'repo.pkl'), 'rb')
218         except IOError:
219             raise ConnectionError(self)
220         d = pickle.load(f)
221         self._data = dict((k,v._ids_to_objects(d)) for k,v in d.items())
222         f.close()
223
224     def disconnect(self):
225         """Close the connection to the repository."""
226         if self.is_writeable() == False:
227             return
228         if self.connected == False:
229             return
230         self._disconnect()
231         self.connected = False
232
233     def _disconnect(self):
234         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
235         pickle.dump(dict((k,v._objects_to_ids())
236                          for k,v in self._data.items()), f, -1)
237         f.close()
238         self._data = None
239
240     def add(self, id, *args, **kwargs):
241         """Add an entry"""
242         if self.is_writeable() == False:
243             raise NotWriteable('Cannot add entry to unwriteable storage.')
244         try:  # Maybe we've already added that id?
245             self.get(id)
246             pass # yup, no need to add another
247         except InvalidID:
248             self._add(id, *args, **kwargs)
249
250     def _add(self, id, parent=None, directory=False):
251         if parent == None:
252             parent = '__ROOT__'
253         p = self._data[parent]
254         self._data[id] = Entry(id, parent=p, directory=directory)
255
256     def remove(self, *args, **kwargs):
257         """Remove an entry."""
258         if self.is_writeable() == False:
259             raise NotSupported('write',
260                                'Cannot remove entry from unwriteable storage.')
261         self._remove(*args, **kwargs)
262
263     def _remove(self, id):
264         if self._data[id].directory == True \
265                 and len(self.children(id)) > 0:
266             raise DirectoryNotEmpty(id)
267         e = self._data.pop(id)
268         e.parent.remove(e)
269
270     def recursive_remove(self, *args, **kwargs):
271         """Remove an entry and all its decendents."""
272         if self.is_writeable() == False:
273             raise NotSupported('write',
274                                'Cannot remove entries from unwriteable storage.')
275         self._recursive_remove(*args, **kwargs)
276
277     def _recursive_remove(self, id):
278         for entry in reversed(list(self._data[id].traverse())):
279             self._remove(entry.id)
280
281     def ancestors(self, *args, **kwargs):
282         """Return a list of the specified entry's ancestors' ids."""
283         if self.is_readable() == False:
284             raise NotReadable('Cannot list parents with unreadable storage.')
285         return self._ancestors(*args, **kwargs)
286
287     def _ancestors(self, id=None, revision=None):
288         if id == None:
289             return []
290         ancestors = []
291         stack = [id]
292         while len(stack) > 0:
293             id = stack.pop(0)
294             parent = self._data[id].parent
295             if parent != None and not parent.id.startswith('__'):
296                 ancestor = parent.id
297                 ancestors.append(ancestor)
298                 stack.append(ancestor)
299         return ancestors
300
301     def children(self, *args, **kwargs):
302         """Return a list of specified entry's children's ids."""
303         if self.is_readable() == False:
304             raise NotReadable('Cannot list children with unreadable storage.')
305         return self._children(*args, **kwargs)
306
307     def _children(self, id=None, revision=None):
308         if id == None:
309             id = '__ROOT__'
310         return [c.id for c in self._data[id] if not c.id.startswith('__')]
311
312     def get(self, *args, **kwargs):
313         """
314         Get contents of and entry as they were in a given revision.
315         revision==None specifies the current revision.
316
317         If there is no id, return default, unless default is not
318         given, in which case raise InvalidID.
319         """
320         if self.is_readable() == False:
321             raise NotReadable('Cannot get entry with unreadable storage.')
322         if 'decode' in kwargs:
323             decode = kwargs.pop('decode')
324         else:
325             decode = False
326         value = self._get(*args, **kwargs)
327         if value != None:
328             if decode == True and type(value) != types.UnicodeType:
329                 return unicode(value, self.encoding)
330             elif decode == False and type(value) != types.StringType:
331                 return value.encode(self.encoding)
332         return value
333
334     def _get(self, id, default=InvalidObject, revision=None):
335         if id in self._data:
336             return self._data[id].value
337         elif default == InvalidObject:
338             raise InvalidID(id)
339         return default
340
341     def set(self, id, value, *args, **kwargs):
342         """
343         Set the entry contents.
344         """
345         if self.is_writeable() == False:
346             raise NotWriteable('Cannot set entry in unwriteable storage.')
347         if type(value) == types.UnicodeType:
348             value = value.encode(self.encoding)
349         self._set(id, value, *args, **kwargs)
350
351     def _set(self, id, value):
352         if id not in self._data:
353             raise InvalidID(id)
354         if self._data[id].directory == True:
355             raise InvalidDirectory(
356                 'Directory %s cannot have data' % self.parent)
357         self._data[id].value = value
358
359 class VersionedStorage (Storage):
360     """
361     This class declares all the methods required by a Storage
362     interface that supports versioning.  This implementation just
363     keeps the data in a list and uses pickle for persistent
364     storage.
365     """
366     name = 'VersionedStorage'
367
368     def __init__(self, *args, **kwargs):
369         Storage.__init__(self, *args, **kwargs)
370         self.versioned = True
371
372     def _init(self):
373         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
374         root = Entry(id='__ROOT__', directory=True)
375         summary = Entry(id='__COMMIT__SUMMARY__', value='Initial commit')
376         body = Entry(id='__COMMIT__BODY__')
377         initial_commit = {root.id:root, summary.id:summary, body.id:body}
378         d = dict((k,v._objects_to_ids()) for k,v in initial_commit.items())
379         pickle.dump([d, copy.deepcopy(d)], f, -1) # [inital tree, working tree]
380         f.close()
381
382     def _connect(self):
383         try:
384             f = open(os.path.join(self.repo, 'repo.pkl'), 'rb')
385         except IOError:
386             raise ConnectionError(self)
387         d = pickle.load(f)
388         self._data = [dict((k,v._ids_to_objects(t)) for k,v in t.items())
389                       for t in d]
390         f.close()
391
392     def _disconnect(self):
393         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
394         pickle.dump([dict((k,v._objects_to_ids())
395                           for k,v in t.items()) for t in self._data], f, -1)
396         f.close()
397         self._data = None
398
399     def _add(self, id, parent=None, directory=False):
400         if parent == None:
401             parent = '__ROOT__'
402         p = self._data[-1][parent]
403         self._data[-1][id] = Entry(id, parent=p, directory=directory)
404
405     def _remove(self, id):
406         if self._data[-1][id].directory == True \
407                 and len(self.children(id)) > 0:
408             raise DirectoryNotEmpty(id)
409         e = self._data[-1].pop(id)
410         e.parent.remove(e)
411
412     def _recursive_remove(self, id):
413         for entry in reversed(list(self._data[-1][id].traverse())):
414             self._remove(entry.id)
415
416     def _ancestors(self, id=None, revision=None):
417         if id == None:
418             return []
419         if revision == None:
420             revision = -1
421         else:
422             revision = int(revision)
423         ancestors = []
424         stack = [id]
425         while len(stack) > 0:
426             id = stack.pop(0)
427             parent = self._data[revision][id].parent
428             if parent != None and not parent.id.startswith('__'):
429                 ancestor = parent.id
430                 ancestors.append(ancestor)
431                 stack.append(ancestor)
432         return ancestors
433
434     def _children(self, id=None, revision=None):
435         if id == None:
436             id = '__ROOT__'
437         if revision == None:
438             revision = -1
439         else:
440             revision = int(revision)
441         return [c.id for c in self._data[revision][id]
442                 if not c.id.startswith('__')]
443
444     def _get(self, id, default=InvalidObject, revision=None):
445         if revision == None:
446             revision = -1
447         else:
448             revision = int(revision)
449         if id in self._data[revision]:
450             return self._data[revision][id].value
451         elif default == InvalidObject:
452             raise InvalidID(id)
453         return default
454
455     def _set(self, id, value):
456         if id not in self._data[-1]:
457             raise InvalidID(id)
458         self._data[-1][id].value = value
459
460     def commit(self, *args, **kwargs):
461         """
462         Commit the current repository, with a commit message string
463         summary and body.  Return the name of the new revision.
464
465         If allow_empty == False (the default), raise EmptyCommit if
466         there are no changes to commit.
467         """
468         if self.is_writeable() == False:
469             raise NotWriteable('Cannot commit to unwriteable storage.')
470         return self._commit(*args, **kwargs)
471
472     def _commit(self, summary, body=None, allow_empty=False):
473         if self._data[-1] == self._data[-2] and allow_empty == False:
474             raise EmptyCommit
475         self._data[-1]["__COMMIT__SUMMARY__"].value = summary
476         self._data[-1]["__COMMIT__BODY__"].value = body
477         rev = str(len(self._data)-1)
478         self._data.append(copy.deepcopy(self._data[-1]))
479         return rev
480
481     def revision_id(self, index=None):
482         """
483         Return the name of the <index>th revision.  The choice of
484         which branch to follow when crossing branches/merges is not
485         defined.  Revision indices start at 1; ID 0 is the blank
486         repository.
487
488         Return None if index==None.
489
490         If the specified revision does not exist, raise InvalidRevision.
491         """
492         if index == None:
493             return None
494         try:
495             if int(index) != index:
496                 raise InvalidRevision(index)
497         except ValueError:
498             raise InvalidRevision(index)
499         L = len(self._data) - 1  # -1 b/c of initial commit
500         if index >= -L and index <= L:
501             return str(index % L)
502         raise InvalidRevision(i)
503
504     def changed(self, revision):
505         """
506         Return a tuple of lists of ids
507           (new, modified, removed)
508         from the specified revision to the current situation.
509         """
510         new = []
511         modified = []
512         removed = []
513         for id,value in self._data[int(revision)].items():
514             if id.startswith('__'):
515                 continue
516             if not id in self._data[-1]:
517                 removed.append(id)
518             elif value.value != self._data[-1][id].value:
519                 modified.append(id)
520         for id in self._data[-1]:
521             if not id in self._data[int(revision)]:
522                 new.append(id)
523         return (new, modified, removed)
524
525
526 if TESTING == True:
527     class StorageTestCase (unittest.TestCase):
528         """Test cases for Storage class."""
529
530         Class = Storage
531
532         def __init__(self, *args, **kwargs):
533             super(StorageTestCase, self).__init__(*args, **kwargs)
534             self.dirname = None
535
536         # this class will be the basis of tests for several classes,
537         # so make sure we print the name of the class we're dealing with.
538         def fail(self, msg=None):
539             """Fail immediately, with the given message."""
540             raise self.failureException, \
541                 '(%s) %s' % (self.Class.__name__, msg)
542
543         def failIf(self, expr, msg=None):
544             "Fail the test if the expression is true."
545             if expr: raise self.failureException, \
546                 '(%s) %s' % (self.Class.__name__, msg)
547
548         def failUnless(self, expr, msg=None):
549             """Fail the test unless the expression is true."""
550             if not expr: raise self.failureException, \
551                 '(%s) %s' % (self.Class.__name__, msg)
552
553         def setUp(self):
554             """Set up test fixtures for Storage test case."""
555             super(StorageTestCase, self).setUp()
556             self.dir = Dir()
557             self.dirname = self.dir.path
558             self.s = self.Class(repo=self.dirname)
559             self.assert_failed_connect()
560             self.s.init()
561             self.s.connect()
562
563         def tearDown(self):
564             super(StorageTestCase, self).tearDown()
565             self.s.disconnect()
566             self.s.destroy()
567             self.assert_failed_connect()
568             self.dir.cleanup()
569
570         def assert_failed_connect(self):
571             try:
572                 self.s.connect()
573                 self.fail(
574                     "Connected to %(name)s repository before initialising"
575                     % vars(self.Class))
576             except ConnectionError:
577                 pass
578
579     class Storage_init_TestCase (StorageTestCase):
580         """Test cases for Storage.init method."""
581
582         def test_connect_should_succeed_after_init(self):
583             """Should connect after initialization."""
584             self.s.connect()
585
586     class Storage_connect_disconnect_TestCase (StorageTestCase):
587         """Test cases for Storage.connect and .disconnect methods."""
588
589         def test_multiple_disconnects(self):
590             """Should be able to call .disconnect multiple times."""
591             self.s.disconnect()
592             self.s.disconnect()
593
594     class Storage_add_remove_TestCase (StorageTestCase):
595         """Test cases for Storage.add, .remove, and .recursive_remove methods."""
596
597         def test_initially_empty(self):
598             """New repository should be empty."""
599             self.failUnless(len(self.s.children()) == 0, self.s.children())
600
601         def test_add_identical_rooted(self):
602             """Adding entries with the same ID should not increase the number of children.
603             """
604             for i in range(10):
605                 self.s.add('some id', directory=False)
606                 s = sorted(self.s.children())
607                 self.failUnless(s == ['some id'], s)
608
609         def test_add_rooted(self):
610             """Adding entries should increase the number of children (rooted).
611             """
612             ids = []
613             for i in range(10):
614                 ids.append(str(i))
615                 self.s.add(ids[-1], directory=(i % 2 == 0))
616                 s = sorted(self.s.children())
617                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
618
619         def test_add_nonrooted(self):
620             """Adding entries should increase the number of children (nonrooted).
621             """
622             self.s.add('parent', directory=True)
623             ids = []
624             for i in range(10):
625                 ids.append(str(i))
626                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
627                 s = sorted(self.s.children('parent'))
628                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
629                 s = self.s.children()
630                 self.failUnless(s == ['parent'], s)
631
632         def test_ancestors(self):
633             """Check ancestors lists.
634             """
635             self.s.add('parent', directory=True)
636             for i in range(10):
637                 i_id = str(i)
638                 self.s.add(i_id, 'parent', directory=True)
639                 for j in range(10): # add some grandkids
640                     j_id = str(20*(i+1)+j)
641                     self.s.add(j_id, i_id, directory=(i%2 == 0))
642                     ancestors = sorted(self.s.ancestors(j_id))
643                     self.failUnless(ancestors == [i_id, 'parent'],
644                         'Unexpected ancestors for %s/%s, "%s"'
645                         % (i_id, j_id, ancestors))
646
647         def test_children(self):
648             """Non-UUID ids should be returned as such.
649             """
650             self.s.add('parent', directory=True)
651             ids = []
652             for i in range(10):
653                 ids.append('parent/%s' % str(i))
654                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
655                 s = sorted(self.s.children('parent'))
656                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
657
658         def test_add_invalid_directory(self):
659             """Should not be able to add children to non-directories.
660             """
661             self.s.add('parent', directory=False)
662             try:
663                 self.s.add('child', 'parent', directory=False)
664                 self.fail(
665                     '%s.add() succeeded instead of raising InvalidDirectory'
666                     % (vars(self.Class)['name']))
667             except InvalidDirectory:
668                 pass
669             try:
670                 self.s.add('child', 'parent', directory=True)
671                 self.fail(
672                     '%s.add() succeeded instead of raising InvalidDirectory'
673                     % (vars(self.Class)['name']))
674             except InvalidDirectory:
675                 pass
676             self.failUnless(len(self.s.children('parent')) == 0,
677                             self.s.children('parent'))
678
679         def test_remove_rooted(self):
680             """Removing entries should decrease the number of children (rooted).
681             """
682             ids = []
683             for i in range(10):
684                 ids.append(str(i))
685                 self.s.add(ids[-1], directory=(i % 2 == 0))
686             for i in range(10):
687                 self.s.remove(ids.pop())
688                 s = sorted(self.s.children())
689                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
690
691         def test_remove_nonrooted(self):
692             """Removing entries should decrease the number of children (nonrooted).
693             """
694             self.s.add('parent', directory=True)
695             ids = []
696             for i in range(10):
697                 ids.append(str(i))
698                 self.s.add(ids[-1], 'parent', directory=False)#(i % 2 == 0))
699             for i in range(10):
700                 self.s.remove(ids.pop())
701                 s = sorted(self.s.children('parent'))
702                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
703                 if len(s) > 0:
704                     s = self.s.children()
705                     self.failUnless(s == ['parent'], s)
706
707         def test_remove_directory_not_empty(self):
708             """Removing a non-empty directory entry should raise exception.
709             """
710             self.s.add('parent', directory=True)
711             ids = []
712             for i in range(10):
713                 ids.append(str(i))
714                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
715             self.s.remove(ids.pop()) # empty directory removal succeeds
716             try:
717                 self.s.remove('parent') # empty directory removal succeeds
718                 self.fail(
719                     "%s.remove() didn't raise DirectoryNotEmpty"
720                     % (vars(self.Class)['name']))
721             except DirectoryNotEmpty:
722                 pass
723
724         def test_recursive_remove(self):
725             """Recursive remove should empty the tree."""
726             self.s.add('parent', directory=True)
727             ids = []
728             for i in range(10):
729                 ids.append(str(i))
730                 self.s.add(ids[-1], 'parent', directory=True)
731                 for j in range(10): # add some grandkids
732                     self.s.add(str(20*(i+1)+j), ids[-1], directory=(i%2 == 0))
733             self.s.recursive_remove('parent')
734             s = sorted(self.s.children())
735             self.failUnless(s == [], s)
736
737     class Storage_get_set_TestCase (StorageTestCase):
738         """Test cases for Storage.get and .set methods."""
739
740         id = 'unlikely id'
741         val = 'unlikely value'
742
743         def test_get_default(self):
744             """Get should return specified default if id not in Storage.
745             """
746             ret = self.s.get(self.id, default=self.val)
747             self.failUnless(ret == self.val,
748                     "%s.get() returned %s not %s"
749                     % (vars(self.Class)['name'], ret, self.val))
750
751         def test_get_default_exception(self):
752             """Get should raise exception if id not in Storage and no default.
753             """
754             try:
755                 ret = self.s.get(self.id)
756                 self.fail(
757                     "%s.get() returned %s instead of raising InvalidID"
758                     % (vars(self.Class)['name'], ret))
759             except InvalidID:
760                 pass
761
762         def test_get_initial_value(self):
763             """Data value should be None before any value has been set.
764             """
765             self.s.add(self.id, directory=False)
766             ret = self.s.get(self.id)
767             self.failUnless(ret == None,
768                     "%s.get() returned %s not None"
769                     % (vars(self.Class)['name'], ret))
770
771         def test_set_exception(self):
772             """Set should raise exception if id not in Storage.
773             """
774             try:
775                 self.s.set(self.id, self.val)
776                 self.fail(
777                     "%(name)s.set() did not raise InvalidID"
778                     % vars(self.Class))
779             except InvalidID:
780                 pass
781
782         def test_set(self):
783             """Set should define the value returned by get.
784             """
785             self.s.add(self.id, directory=False)
786             self.s.set(self.id, self.val)
787             ret = self.s.get(self.id)
788             self.failUnless(ret == self.val,
789                     "%s.get() returned %s not %s"
790                     % (vars(self.Class)['name'], ret, self.val))
791
792         def test_unicode_set(self):
793             """Set should define the value returned by get.
794             """
795             val = u'Fran\xe7ois'
796             self.s.add(self.id, directory=False)
797             self.s.set(self.id, val)
798             ret = self.s.get(self.id, decode=True)
799             self.failUnless(type(ret) == types.UnicodeType,
800                     "%s.get() returned %s not UnicodeType"
801                     % (vars(self.Class)['name'], type(ret)))
802             self.failUnless(ret == val,
803                     "%s.get() returned %s not %s"
804                     % (vars(self.Class)['name'], ret, self.val))
805             ret = self.s.get(self.id)
806             self.failUnless(type(ret) == types.StringType,
807                     "%s.get() returned %s not StringType"
808                     % (vars(self.Class)['name'], type(ret)))
809             s = unicode(ret, self.s.encoding)
810             self.failUnless(s == val,
811                     "%s.get() returned %s not %s"
812                     % (vars(self.Class)['name'], s, self.val))
813
814
815     class Storage_persistence_TestCase (StorageTestCase):
816         """Test cases for Storage.disconnect and .connect methods."""
817
818         id = 'unlikely id'
819         val = 'unlikely value'
820
821         def test_get_set_persistence(self):
822             """Set should define the value returned by get after reconnect.
823             """
824             self.s.add(self.id, directory=False)
825             self.s.set(self.id, self.val)
826             self.s.disconnect()
827             self.s.connect()
828             ret = self.s.get(self.id)
829             self.failUnless(ret == self.val,
830                     "%s.get() returned %s not %s"
831                     % (vars(self.Class)['name'], ret, self.val))
832
833         def test_add_nonrooted_persistence(self):
834             """Adding entries should increase the number of children after reconnect.
835             """
836             self.s.add('parent', directory=True)
837             ids = []
838             for i in range(10):
839                 ids.append(str(i))
840                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
841             self.s.disconnect()
842             self.s.connect()
843             s = sorted(self.s.children('parent'))
844             self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
845             s = self.s.children()
846             self.failUnless(s == ['parent'], s)
847
848     class VersionedStorageTestCase (StorageTestCase):
849         """Test cases for VersionedStorage methods."""
850
851         Class = VersionedStorage
852
853     class VersionedStorage_commit_TestCase (VersionedStorageTestCase):
854         """Test cases for VersionedStorage.commit and revision_ids methods."""
855
856         id = 'unlikely id'
857         val = 'Some value'
858         commit_msg = 'Committing something interesting'
859         commit_body = 'Some\nlonger\ndescription\n'
860
861         def _setup_for_empty_commit(self):
862             """
863             Initialization might add some files to version control, so
864             commit those first, before testing the empty commit
865             functionality.
866             """
867             try:
868                 self.s.commit('Added initialization files')
869             except EmptyCommit:
870                 pass
871                 
872         def test_revision_id_exception(self):
873             """Invalid revision id should raise InvalidRevision.
874             """
875             try:
876                 rev = self.s.revision_id('highly unlikely revision id')
877                 self.fail(
878                     "%s.revision_id() didn't raise InvalidRevision, returned %s."
879                     % (vars(self.Class)['name'], rev))
880             except InvalidRevision:
881                 pass
882
883         def test_empty_commit_raises_exception(self):
884             """Empty commit should raise exception.
885             """
886             self._setup_for_empty_commit()
887             try:
888                 self.s.commit(self.commit_msg, self.commit_body)
889                 self.fail(
890                     "Empty %(name)s.commit() didn't raise EmptyCommit."
891                     % vars(self.Class))
892             except EmptyCommit:
893                 pass
894
895         def test_empty_commit_allowed(self):
896             """Empty commit should _not_ raise exception if allow_empty=True.
897             """
898             self._setup_for_empty_commit()
899             self.s.commit(self.commit_msg, self.commit_body,
900                           allow_empty=True)
901
902         def test_commit_revision_ids(self):
903             """Commit / revision_id should agree on revision ids.
904             """
905             def val(i):
906                 return '%s:%d' % (self.val, i+1)
907             self.s.add(self.id, directory=False)
908             revs = []
909             for i in range(10):
910                 self.s.set(self.id, val(i))
911                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
912                                           self.commit_body))
913             for i in range(10):
914                 rev = self.s.revision_id(i+1)
915                 self.failUnless(rev == revs[i],
916                                 "%s.revision_id(%d) returned %s not %s"
917                                 % (vars(self.Class)['name'], i+1, rev, revs[i]))
918             for i in range(-1, -9, -1):
919                 rev = self.s.revision_id(i)
920                 self.failUnless(rev == revs[i],
921                                 "%s.revision_id(%d) returned %s not %s"
922                                 % (vars(self.Class)['name'], i, rev, revs[i]))
923
924         def test_get_previous_version(self):
925             """Get should be able to return the previous version.
926             """
927             def val(i):
928                 return '%s:%d' % (self.val, i+1)
929             self.s.add(self.id, directory=False)
930             revs = []
931             for i in range(10):
932                 self.s.set(self.id, val(i))
933                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
934                                           self.commit_body))
935             for i in range(10):
936                 ret = self.s.get(self.id, revision=revs[i])
937                 self.failUnless(ret == val(i),
938                                 "%s.get() returned %s not %s for revision %s"
939                                 % (vars(self.Class)['name'], ret, val(i), revs[i]))
940
941         def test_get_previous_children(self):
942             """Children list should be revision dependent.
943             """
944             self.s.add('parent', directory=True)
945             revs = []
946             cur_children = []
947             children = []
948             for i in range(10):
949                 new_child = str(i)
950                 self.s.add(new_child, 'parent')
951                 self.s.set(new_child, self.val)
952                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
953                                           self.commit_body))
954                 cur_children.append(new_child)
955                 children.append(list(cur_children))
956             for i in range(10):
957                 ret = self.s.children('parent', revision=revs[i])
958                 self.failUnless(ret == children[i],
959                                 "%s.get() returned %s not %s for revision %s"
960                                 % (vars(self.Class)['name'], ret,
961                                    children[i], revs[i]))
962
963     class VersionedStorage_changed_TestCase (VersionedStorageTestCase):
964         """Test cases for VersionedStorage.changed() method."""
965
966         def test_changed(self):
967             """Changed lists should reflect past activity"""
968             self.s.add('dir', directory=True)
969             self.s.add('modified', parent='dir')
970             self.s.set('modified', 'some value to be modified')
971             self.s.add('moved', parent='dir')
972             self.s.set('moved', 'this entry will be moved')
973             self.s.add('removed', parent='dir')
974             self.s.set('removed', 'this entry will be deleted')
975             revA = self.s.commit('Initial state')
976             self.s.add('new', parent='dir')
977             self.s.set('new', 'this entry is new')
978             self.s.set('modified', 'a new value')
979             self.s.remove('moved')
980             self.s.add('moved2', parent='dir')
981             self.s.set('moved2', 'this entry will be moved')
982             self.s.remove('removed')
983             revB = self.s.commit('Final state')
984             new,mod,rem = self.s.changed(revA)
985             self.failUnless(sorted(new) == ['moved2', 'new'],
986                             'Unexpected new: %s' % new)
987             self.failUnless(mod == ['modified'],
988                             'Unexpected modified: %s' % mod)
989             self.failUnless(sorted(rem) == ['moved', 'removed'],
990                             'Unexpected removed: %s' % rem)
991
992     def make_storage_testcase_subclasses(storage_class, namespace):
993         """Make StorageTestCase subclasses for storage_class in namespace."""
994         storage_testcase_classes = [
995             c for c in (
996                 ob for ob in globals().values() if isinstance(ob, type))
997             if issubclass(c, StorageTestCase) \
998                 and c.Class == Storage]
999
1000         for base_class in storage_testcase_classes:
1001             testcase_class_name = storage_class.__name__ + base_class.__name__
1002             testcase_class_bases = (base_class,)
1003             testcase_class_dict = dict(base_class.__dict__)
1004             testcase_class_dict['Class'] = storage_class
1005             testcase_class = type(
1006                 testcase_class_name, testcase_class_bases, testcase_class_dict)
1007             setattr(namespace, testcase_class_name, testcase_class)
1008
1009     def make_versioned_storage_testcase_subclasses(storage_class, namespace):
1010         """Make VersionedStorageTestCase subclasses for storage_class in namespace."""
1011         storage_testcase_classes = [
1012             c for c in (
1013                 ob for ob in globals().values() if isinstance(ob, type))
1014             if ((issubclass(c, StorageTestCase) \
1015                      and c.Class == Storage)
1016                 or
1017                 (issubclass(c, VersionedStorageTestCase) \
1018                      and c.Class == VersionedStorage))]
1019
1020         for base_class in storage_testcase_classes:
1021             testcase_class_name = storage_class.__name__ + base_class.__name__
1022             testcase_class_bases = (base_class,)
1023             testcase_class_dict = dict(base_class.__dict__)
1024             testcase_class_dict['Class'] = storage_class
1025             testcase_class = type(
1026                 testcase_class_name, testcase_class_bases, testcase_class_dict)
1027             setattr(namespace, testcase_class_name, testcase_class)
1028
1029     make_storage_testcase_subclasses(VersionedStorage, sys.modules[__name__])
1030
1031     unitsuite =unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
1032     suite = unittest.TestSuite([unitsuite, doctest.DocTestSuite()])