Add VersionedStorageTestCases in make_versioned_storage_testcase_subclasses
[be.git] / libbe / storage / base.py
1 # Copyright (C) 2009-2010 W. Trevor King <wking@drexel.edu>
2 #
3 # This program is free software; you can redistribute it and/or modify
4 # it under the terms of the GNU General Public License as published by
5 # the Free Software Foundation; either version 2 of the License, or
6 # (at your option) any later version.
7 #
8 # This program is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 # GNU General Public License for more details.
12 #
13 # You should have received a copy of the GNU General Public License along
14 # with this program; if not, write to the Free Software Foundation, Inc.,
15 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
16
17 """
18 Abstract bug repository data storage to easily support multiple backends.
19 """
20
21 import copy
22 import os
23 import pickle
24 import types
25
26 from libbe.error import NotSupported
27 import libbe.storage
28 from libbe.util.tree import Tree
29 from libbe.util import InvalidObject
30 import libbe.version
31 from libbe import TESTING
32
33 if TESTING == True:
34     import doctest
35     import os.path
36     import sys
37     import unittest
38
39     from libbe.util.utility import Dir
40
41 class ConnectionError (Exception):
42     pass
43
44 class InvalidStorageVersion(ConnectionError):
45     def __init__(self, active_version, expected_version=None):
46         if expected_version == None:
47             expected_version = libbe.storage.STORAGE_VERSION
48         msg = 'Storage in "%s" not the expected "%s"' \
49             % (active_version, expected_version)
50         Exception.__init__(self, msg)
51         self.active_version = active_version
52         self.expected_version = expected_version
53
54 class InvalidID (KeyError):
55     def __init__(self, id=None, revision=None, msg=None):
56         if msg == None and id != None:
57             msg = id
58         KeyError.__init__(self, msg)
59         self.id = id
60         self.revision = revision
61
62 class InvalidRevision (KeyError):
63     pass
64
65 class InvalidDirectory (Exception):
66     pass
67
68 class DirectoryNotEmpty (InvalidDirectory):
69     pass
70
71 class NotWriteable (NotSupported):
72     def __init__(self, msg):
73         NotSupported.__init__(self, 'write', msg)
74
75 class NotReadable (NotSupported):
76     def __init__(self, msg):
77         NotSupported.__init__(self, 'read', msg)
78
79 class EmptyCommit(Exception):
80     def __init__(self):
81         Exception.__init__(self, 'No changes to commit')
82
83
84 class Entry (Tree):
85     def __init__(self, id, value=None, parent=None, directory=False,
86                  children=None):
87         if children == None:
88             Tree.__init__(self)
89         else:
90             Tree.__init__(self, children)
91         self.id = id
92         self.value = value
93         self.parent = parent
94         if self.parent != None:
95             if self.parent.directory == False:
96                 raise InvalidDirectory(
97                     'Non-directory %s cannot have children' % self.parent)
98             parent.append(self)
99         self.directory = directory
100
101     def __str__(self):
102         return '<Entry %s: %s>' % (self.id, self.value)
103
104     def __repr__(self):
105         return str(self)
106
107     def __cmp__(self, other, local=False):
108         if other == None:
109             return cmp(1, None)
110         if cmp(self.id, other.id) != 0:
111             return cmp(self.id, other.id)
112         if cmp(self.value, other.value) != 0:
113             return cmp(self.value, other.value)
114         if local == False:
115             if self.parent == None:
116                 if cmp(self.parent, other.parent) != 0:
117                     return cmp(self.parent, other.parent)
118             elif self.parent.__cmp__(other.parent, local=True) != 0:
119                 return self.parent.__cmp__(other.parent, local=True)
120             for sc,oc in zip(self, other):
121                 if sc.__cmp__(oc, local=True) != 0:
122                     return sc.__cmp__(oc, local=True)
123         return 0
124
125     def _objects_to_ids(self):
126         if self.parent != None:
127             self.parent = self.parent.id
128         for i,c in enumerate(self):
129             self[i] = c.id
130         return self
131
132     def _ids_to_objects(self, dict):
133         if self.parent != None:
134             self.parent = dict[self.parent]
135         for i,c in enumerate(self):
136             self[i] = dict[c]
137         return self
138
139 class Storage (object):
140     """
141     This class declares all the methods required by a Storage
142     interface.  This implementation just keeps the data in a
143     dictionary and uses pickle for persistent storage.
144     """
145     name = 'Storage'
146
147     def __init__(self, repo='/', encoding='utf-8', options=None):
148         self.repo = repo
149         self.encoding = encoding
150         self.options = options
151         self.readable = True  # soft limit (user choice)
152         self._readable = True # hard limit (backend choice)
153         self.writeable = True  # soft limit (user choice)
154         self._writeable = True # hard limit (backend choice)
155         self.versioned = False
156         self.can_init = True
157         self.connected = False
158
159     def __str__(self):
160         return '<%s %s %s>' % (self.__class__.__name__, id(self), self.repo)
161
162     def __repr__(self):
163         return str(self)
164
165     def version(self):
166         """Return a version string for this backend."""
167         return libbe.version.version()
168
169     def storage_version(self, revision=None):
170         """Return the storage format for this backend."""
171         return libbe.storage.STORAGE_VERSION
172
173     def is_readable(self):
174         return self.readable and self._readable
175
176     def is_writeable(self):
177         return self.writeable and self._writeable
178
179     def init(self):
180         """Create a new storage repository."""
181         if self.can_init == False:
182             raise NotSupported('init',
183                                'Cannot initialize this repository format.')
184         if self.is_writeable() == False:
185             raise NotWriteable('Cannot initialize unwriteable storage.')
186         return self._init()
187
188     def _init(self):
189         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
190         root = Entry(id='__ROOT__', directory=True)
191         d = {root.id:root}
192         pickle.dump(dict((k,v._objects_to_ids()) for k,v in d.items()), f, -1)
193         f.close()
194
195     def destroy(self):
196         """Remove the storage repository."""
197         if self.is_writeable() == False:
198             raise NotWriteable('Cannot destroy unwriteable storage.')
199         return self._destroy()
200
201     def _destroy(self):
202         os.remove(os.path.join(self.repo, 'repo.pkl'))
203
204     def connect(self):
205         """Open a connection to the repository."""
206         if self.is_readable() == False:
207             raise NotReadable('Cannot connect to unreadable storage.')
208         self._connect()
209         self.connected = True
210
211     def _connect(self):
212         try:
213             f = open(os.path.join(self.repo, 'repo.pkl'), 'rb')
214         except IOError:
215             raise ConnectionError(self)
216         d = pickle.load(f)
217         self._data = dict((k,v._ids_to_objects(d)) for k,v in d.items())
218         f.close()
219
220     def disconnect(self):
221         """Close the connection to the repository."""
222         if self.is_writeable() == False:
223             return
224         if self.connected == False:
225             return
226         self._disconnect()
227         self.connected = False
228
229     def _disconnect(self):
230         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
231         pickle.dump(dict((k,v._objects_to_ids())
232                          for k,v in self._data.items()), f, -1)
233         f.close()
234         self._data = None
235
236     def add(self, id, *args, **kwargs):
237         """Add an entry"""
238         if self.is_writeable() == False:
239             raise NotWriteable('Cannot add entry to unwriteable storage.')
240         try:  # Maybe we've already added that id?
241             self.get(id)
242             pass # yup, no need to add another
243         except InvalidID:
244             self._add(id, *args, **kwargs)
245
246     def _add(self, id, parent=None, directory=False):
247         if parent == None:
248             parent = '__ROOT__'
249         p = self._data[parent]
250         self._data[id] = Entry(id, parent=p, directory=directory)
251
252     def remove(self, *args, **kwargs):
253         """Remove an entry."""
254         if self.is_writeable() == False:
255             raise NotSupported('write',
256                                'Cannot remove entry from unwriteable storage.')
257         self._remove(*args, **kwargs)
258
259     def _remove(self, id):
260         if self._data[id].directory == True \
261                 and len(self.children(id)) > 0:
262             raise DirectoryNotEmpty(id)
263         e = self._data.pop(id)
264         e.parent.remove(e)
265
266     def recursive_remove(self, *args, **kwargs):
267         """Remove an entry and all its decendents."""
268         if self.is_writeable() == False:
269             raise NotSupported('write',
270                                'Cannot remove entries from unwriteable storage.')
271         self._recursive_remove(*args, **kwargs)
272
273     def _recursive_remove(self, id):
274         for entry in reversed(list(self._data[id].traverse())):
275             self._remove(entry.id)
276
277     def children(self, *args, **kwargs):
278         """Return a list of specified entry's children's ids."""
279         if self.is_readable() == False:
280             raise NotReadable('Cannot list children with unreadable storage.')
281         return self._children(*args, **kwargs)
282
283     def _children(self, id=None, revision=None):
284         if id == None:
285             id = '__ROOT__'
286         return [c.id for c in self._data[id] if not c.id.startswith('__')]
287
288     def get(self, *args, **kwargs):
289         """
290         Get contents of and entry as they were in a given revision.
291         revision==None specifies the current revision.
292
293         If there is no id, return default, unless default is not
294         given, in which case raise InvalidID.
295         """
296         if self.is_readable() == False:
297             raise NotReadable('Cannot get entry with unreadable storage.')
298         if 'decode' in kwargs:
299             decode = kwargs.pop('decode')
300         else:
301             decode = False
302         value = self._get(*args, **kwargs)
303         if value != None:
304             if decode == True and type(value) != types.UnicodeType:
305                 return unicode(value, self.encoding)
306             elif decode == False and type(value) != types.StringType:
307                 return value.encode(self.encoding)
308         return value
309
310     def _get(self, id, default=InvalidObject, revision=None):
311         if id in self._data:
312             return self._data[id].value
313         elif default == InvalidObject:
314             raise InvalidID(id)
315         return default
316
317     def set(self, id, value, *args, **kwargs):
318         """
319         Set the entry contents.
320         """
321         if self.is_writeable() == False:
322             raise NotWriteable('Cannot set entry in unwriteable storage.')
323         if type(value) == types.UnicodeType:
324             value = value.encode(self.encoding)
325         self._set(id, value, *args, **kwargs)
326
327     def _set(self, id, value):
328         if id not in self._data:
329             raise InvalidID(id)
330         if self._data[id].directory == True:
331             raise InvalidDirectory(
332                 'Directory %s cannot have data' % self.parent)
333         self._data[id].value = value
334
335 class VersionedStorage (Storage):
336     """
337     This class declares all the methods required by a Storage
338     interface that supports versioning.  This implementation just
339     keeps the data in a list and uses pickle for persistent
340     storage.
341     """
342     name = 'VersionedStorage'
343
344     def __init__(self, *args, **kwargs):
345         Storage.__init__(self, *args, **kwargs)
346         self.versioned = True
347
348     def _init(self):
349         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
350         root = Entry(id='__ROOT__', directory=True)
351         summary = Entry(id='__COMMIT__SUMMARY__', value='Initial commit')
352         body = Entry(id='__COMMIT__BODY__')
353         initial_commit = {root.id:root, summary.id:summary, body.id:body}
354         d = dict((k,v._objects_to_ids()) for k,v in initial_commit.items())
355         pickle.dump([d, copy.deepcopy(d)], f, -1) # [inital tree, working tree]
356         f.close()
357
358     def _connect(self):
359         try:
360             f = open(os.path.join(self.repo, 'repo.pkl'), 'rb')
361         except IOError:
362             raise ConnectionError(self)
363         d = pickle.load(f)
364         self._data = [dict((k,v._ids_to_objects(t)) for k,v in t.items())
365                       for t in d]
366         f.close()
367
368     def _disconnect(self):
369         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
370         pickle.dump([dict((k,v._objects_to_ids())
371                           for k,v in t.items()) for t in self._data], f, -1)
372         f.close()
373         self._data = None
374
375     def _add(self, id, parent=None, directory=False):
376         if parent == None:
377             parent = '__ROOT__'
378         p = self._data[-1][parent]
379         self._data[-1][id] = Entry(id, parent=p, directory=directory)
380
381     def _remove(self, id):
382         if self._data[-1][id].directory == True \
383                 and len(self.children(id)) > 0:
384             raise DirectoryNotEmpty(id)
385         e = self._data[-1].pop(id)
386         e.parent.remove(e)
387
388     def _recursive_remove(self, id):
389         for entry in reversed(list(self._data[-1][id].traverse())):
390             self._remove(entry.id)
391
392     def _children(self, id=None, revision=None):
393         if id == None:
394             id = '__ROOT__'
395         if revision == None:
396             revision = -1
397         else:
398             revision = int(revision)
399         return [c.id for c in self._data[revision][id]
400                 if not c.id.startswith('__')]
401
402     def _get(self, id, default=InvalidObject, revision=None):
403         if revision == None:
404             revision = -1
405         else:
406             revision = int(revision)
407         if id in self._data[revision]:
408             return self._data[revision][id].value
409         elif default == InvalidObject:
410             raise InvalidID(id)
411         return default
412
413     def _set(self, id, value):
414         if id not in self._data[-1]:
415             raise InvalidID(id)
416         self._data[-1][id].value = value
417
418     def commit(self, *args, **kwargs):
419         """
420         Commit the current repository, with a commit message string
421         summary and body.  Return the name of the new revision.
422
423         If allow_empty == False (the default), raise EmptyCommit if
424         there are no changes to commit.
425         """
426         if self.is_writeable() == False:
427             raise NotWriteable('Cannot commit to unwriteable storage.')
428         return self._commit(*args, **kwargs)
429
430     def _commit(self, summary, body=None, allow_empty=False):
431         if self._data[-1] == self._data[-2] and allow_empty == False:
432             raise EmptyCommit
433         self._data[-1]["__COMMIT__SUMMARY__"].value = summary
434         self._data[-1]["__COMMIT__BODY__"].value = body
435         rev = str(len(self._data)-1)
436         self._data.append(copy.deepcopy(self._data[-1]))
437         return rev
438
439     def revision_id(self, index=None):
440         """
441         Return the name of the <index>th revision.  The choice of
442         which branch to follow when crossing branches/merges is not
443         defined.  Revision indices start at 1; ID 0 is the blank
444         repository.
445
446         Return None if index==None.
447
448         If the specified revision does not exist, raise InvalidRevision.
449         """
450         if index == None:
451             return None
452         try:
453             if int(index) != index:
454                 raise InvalidRevision(index)
455         except ValueError:
456             raise InvalidRevision(index)
457         L = len(self._data) - 1  # -1 b/c of initial commit
458         if index >= -L and index <= L:
459             return str(index % L)
460         raise InvalidRevision(i)
461
462     def changed(self, revision):
463         """
464         Return a tuple of lists of ids
465           (new, modified, removed)
466         from the specified revision to the current situation.
467         """
468         new = []
469         modified = []
470         removed = []
471         for id,value in self._data[int(revision)].items():
472             if id.startswith('__'):
473                 continue
474             if not id in self._data[-1]:
475                 removed.append(id)
476             elif value.value != self._data[-1][id].value:
477                 modified.append(id)
478         for id in self._data[-1]:
479             if not id in self._data[int(revision)]:
480                 new.append(id)
481         return (new, modified, removed)
482
483
484 if TESTING == True:
485     class StorageTestCase (unittest.TestCase):
486         """Test cases for Storage class."""
487
488         Class = Storage
489
490         def __init__(self, *args, **kwargs):
491             super(StorageTestCase, self).__init__(*args, **kwargs)
492             self.dirname = None
493
494         def setUp(self):
495             """Set up test fixtures for Storage test case."""
496             super(StorageTestCase, self).setUp()
497             self.dir = Dir()
498             self.dirname = self.dir.path
499             self.s = self.Class(repo=self.dirname)
500             self.assert_failed_connect()
501             self.s.init()
502             self.s.connect()
503
504         def tearDown(self):
505             super(StorageTestCase, self).tearDown()
506             self.s.disconnect()
507             self.s.destroy()
508             self.assert_failed_connect()
509             self.dir.cleanup()
510
511         def assert_failed_connect(self):
512             try:
513                 self.s.connect()
514                 self.fail(
515                     "Connected to %(name)s repository before initialising"
516                     % vars(self.Class))
517             except ConnectionError:
518                 pass
519
520     class Storage_init_TestCase (StorageTestCase):
521         """Test cases for Storage.init method."""
522
523         def test_connect_should_succeed_after_init(self):
524             """Should connect after initialization."""
525             self.s.connect()
526
527     class Storage_connect_disconnect_TestCase (StorageTestCase):
528         """Test cases for Storage.connect and .disconnect methods."""
529
530         def test_multiple_disconnects(self):
531             """Should be able to call .disconnect multiple times."""
532             self.s.disconnect()
533             self.s.disconnect()
534
535     class Storage_add_remove_TestCase (StorageTestCase):
536         """Test cases for Storage.add, .remove, and .recursive_remove methods."""
537
538         def test_initially_empty(self):
539             """New repository should be empty."""
540             self.failUnless(len(self.s.children()) == 0, self.s.children())
541
542         def test_add_identical_rooted(self):
543             """
544             Adding entries with the same ID should not increase the number of children.
545             """
546             for i in range(10):
547                 self.s.add('some id', directory=False)
548                 s = sorted(self.s.children())
549                 self.failUnless(s == ['some id'], s)
550
551         def test_add_rooted(self):
552             """
553             Adding entries should increase the number of children (rooted).
554             """
555             ids = []
556             for i in range(10):
557                 ids.append(str(i))
558                 self.s.add(ids[-1], directory=(i % 2 == 0))
559                 s = sorted(self.s.children())
560                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
561
562         def test_add_nonrooted(self):
563             """
564             Adding entries should increase the number of children (nonrooted).
565             """
566             self.s.add('parent', directory=True)
567             ids = []
568             for i in range(10):
569                 ids.append(str(i))
570                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
571                 s = sorted(self.s.children('parent'))
572                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
573                 s = self.s.children()
574                 self.failUnless(s == ['parent'], s)
575
576         def test_children(self):
577             """
578             Non-UUID ids should be returned as such.
579             """
580             self.s.add('parent', directory=True)
581             ids = []
582             for i in range(10):
583                 ids.append('parent/%s' % str(i))
584                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
585                 s = sorted(self.s.children('parent'))
586                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
587
588         def test_add_invalid_directory(self):
589             """
590             Should not be able to add children to non-directories.
591             """
592             self.s.add('parent', directory=False)
593             try:
594                 self.s.add('child', 'parent', directory=False)
595                 self.fail(
596                     '%s.add() succeeded instead of raising InvalidDirectory'
597                     % (vars(self.Class)['name']))
598             except InvalidDirectory:
599                 pass
600             try:
601                 self.s.add('child', 'parent', directory=True)
602                 self.fail(
603                     '%s.add() succeeded instead of raising InvalidDirectory'
604                     % (vars(self.Class)['name']))
605             except InvalidDirectory:
606                 pass
607             self.failUnless(len(self.s.children('parent')) == 0,
608                             self.s.children('parent'))
609
610         def test_remove_rooted(self):
611             """
612             Removing entries should decrease the number of children (rooted).
613             """
614             ids = []
615             for i in range(10):
616                 ids.append(str(i))
617                 self.s.add(ids[-1], directory=(i % 2 == 0))
618             for i in range(10):
619                 self.s.remove(ids.pop())
620                 s = sorted(self.s.children())
621                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
622
623         def test_remove_nonrooted(self):
624             """
625             Removing entries should decrease the number of children (nonrooted).
626             """
627             self.s.add('parent', directory=True)
628             ids = []
629             for i in range(10):
630                 ids.append(str(i))
631                 self.s.add(ids[-1], 'parent', directory=False)#(i % 2 == 0))
632             for i in range(10):
633                 self.s.remove(ids.pop())
634                 s = sorted(self.s.children('parent'))
635                 self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
636                 if len(s) > 0:
637                     s = self.s.children()
638                     self.failUnless(s == ['parent'], s)
639
640         def test_remove_directory_not_empty(self):
641             """
642             Removing a non-empty directory entry should raise exception.
643             """
644             self.s.add('parent', directory=True)
645             ids = []
646             for i in range(10):
647                 ids.append(str(i))
648                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
649             self.s.remove(ids.pop()) # empty directory removal succeeds
650             try:
651                 self.s.remove('parent') # empty directory removal succeeds
652                 self.fail(
653                     "%s.remove() didn't raise DirectoryNotEmpty"
654                     % (vars(self.Class)['name']))
655             except DirectoryNotEmpty:
656                 pass
657
658         def test_recursive_remove(self):
659             """
660             Recursive remove should empty the tree.
661             """
662             self.s.add('parent', directory=True)
663             ids = []
664             for i in range(10):
665                 ids.append(str(i))
666                 self.s.add(ids[-1], 'parent', directory=True)
667                 for j in range(10): # add some grandkids
668                     self.s.add(str(20*(i+1)+j), ids[-1], directory=(i%2 == 0))
669             self.s.recursive_remove('parent')
670             s = sorted(self.s.children())
671             self.failUnless(s == [], s)
672
673     class Storage_get_set_TestCase (StorageTestCase):
674         """Test cases for Storage.get and .set methods."""
675
676         id = 'unlikely id'
677         val = 'unlikely value'
678
679         def test_get_default(self):
680             """
681             Get should return specified default if id not in Storage.
682             """
683             ret = self.s.get(self.id, default=self.val)
684             self.failUnless(ret == self.val,
685                     "%s.get() returned %s not %s"
686                     % (vars(self.Class)['name'], ret, self.val))
687
688         def test_get_default_exception(self):
689             """
690             Get should raise exception if id not in Storage and no default.
691             """
692             try:
693                 ret = self.s.get(self.id)
694                 self.fail(
695                     "%s.get() returned %s instead of raising InvalidID"
696                     % (vars(self.Class)['name'], ret))
697             except InvalidID:
698                 pass
699
700         def test_get_initial_value(self):
701             """
702             Data value should be None before any value has been set.
703             """
704             self.s.add(self.id, directory=False)
705             ret = self.s.get(self.id)
706             self.failUnless(ret == None,
707                     "%s.get() returned %s not None"
708                     % (vars(self.Class)['name'], ret))
709
710         def test_set_exception(self):
711             """
712             Set should raise exception if id not in Storage.
713             """
714             try:
715                 self.s.set(self.id, self.val)
716                 self.fail(
717                     "%(name)s.set() did not raise InvalidID"
718                     % vars(self.Class))
719             except InvalidID:
720                 pass
721
722         def test_set(self):
723             """
724             Set should define the value returned by get.
725             """
726             self.s.add(self.id, directory=False)
727             self.s.set(self.id, self.val)
728             ret = self.s.get(self.id)
729             self.failUnless(ret == self.val,
730                     "%s.get() returned %s not %s"
731                     % (vars(self.Class)['name'], ret, self.val))
732
733         def test_unicode_set(self):
734             """
735             Set should define the value returned by get.
736             """
737             val = u'Fran\xe7ois'
738             self.s.add(self.id, directory=False)
739             self.s.set(self.id, val)
740             ret = self.s.get(self.id, decode=True)
741             self.failUnless(type(ret) == types.UnicodeType,
742                     "%s.get() returned %s not UnicodeType"
743                     % (vars(self.Class)['name'], type(ret)))
744             self.failUnless(ret == val,
745                     "%s.get() returned %s not %s"
746                     % (vars(self.Class)['name'], ret, self.val))
747             ret = self.s.get(self.id)
748             self.failUnless(type(ret) == types.StringType,
749                     "%s.get() returned %s not StringType"
750                     % (vars(self.Class)['name'], type(ret)))
751             s = unicode(ret, self.s.encoding)
752             self.failUnless(s == val,
753                     "%s.get() returned %s not %s"
754                     % (vars(self.Class)['name'], s, self.val))
755
756
757     class Storage_persistence_TestCase (StorageTestCase):
758         """Test cases for Storage.disconnect and .connect methods."""
759
760         id = 'unlikely id'
761         val = 'unlikely value'
762
763         def test_get_set_persistence(self):
764             """
765             Set should define the value returned by get after reconnect.
766             """
767             self.s.add(self.id, directory=False)
768             self.s.set(self.id, self.val)
769             self.s.disconnect()
770             self.s.connect()
771             ret = self.s.get(self.id)
772             self.failUnless(ret == self.val,
773                     "%s.get() returned %s not %s"
774                     % (vars(self.Class)['name'], ret, self.val))
775
776         def test_add_nonrooted_persistence(self):
777             """
778             Adding entries should increase the number of children after reconnect.
779             """
780             self.s.add('parent', directory=True)
781             ids = []
782             for i in range(10):
783                 ids.append(str(i))
784                 self.s.add(ids[-1], 'parent', directory=(i % 2 == 0))
785             self.s.disconnect()
786             self.s.connect()
787             s = sorted(self.s.children('parent'))
788             self.failUnless(s == ids, '\n  %s\n  !=\n  %s' % (s, ids))
789             s = self.s.children()
790             self.failUnless(s == ['parent'], s)
791
792     class VersionedStorageTestCase (StorageTestCase):
793         """Test cases for VersionedStorage methods."""
794
795         Class = VersionedStorage
796
797     class VersionedStorage_commit_TestCase (VersionedStorageTestCase):
798         """Test cases for VersionedStorage.commit and revision_ids methods."""
799
800         id = 'unlikely id'
801         val = 'Some value'
802         commit_msg = 'Committing something interesting'
803         commit_body = 'Some\nlonger\ndescription\n'
804
805         def _setup_for_empty_commit(self):
806             """
807             Initialization might add some files to version control, so
808             commit those first, before testing the empty commit
809             functionality.
810             """
811             try:
812                 self.s.commit('Added initialization files')
813             except EmptyCommit:
814                 pass
815                 
816         def test_revision_id_exception(self):
817             """
818             Invalid revision id should raise InvalidRevision.
819             """
820             try:
821                 rev = self.s.revision_id('highly unlikely revision id')
822                 self.fail(
823                     "%s.revision_id() didn't raise InvalidRevision, returned %s."
824                     % (vars(self.Class)['name'], rev))
825             except InvalidRevision:
826                 pass
827
828         def test_empty_commit_raises_exception(self):
829             """
830             Empty commit should raise exception.
831             """
832             self._setup_for_empty_commit()
833             try:
834                 self.s.commit(self.commit_msg, self.commit_body)
835                 self.fail(
836                     "Empty %(name)s.commit() didn't raise EmptyCommit."
837                     % vars(self.Class))
838             except EmptyCommit:
839                 pass
840
841         def test_empty_commit_allowed(self):
842             """
843             Empty commit should _not_ raise exception if allow_empty=True.
844             """
845             self._setup_for_empty_commit()
846             self.s.commit(self.commit_msg, self.commit_body,
847                           allow_empty=True)
848
849         def test_commit_revision_ids(self):
850             """
851             Commit / revision_id should agree on revision ids.
852             """
853             def val(i):
854                 return '%s:%d' % (self.val, i+1)
855             self.s.add(self.id, directory=False)
856             revs = []
857             for i in range(10):
858                 self.s.set(self.id, val(i))
859                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
860                                           self.commit_body))
861             for i in range(10):
862                 rev = self.s.revision_id(i+1)
863                 self.failUnless(rev == revs[i],
864                                 "%s.revision_id(%d) returned %s not %s"
865                                 % (vars(self.Class)['name'], i+1, rev, revs[i]))
866             for i in range(-1, -9, -1):
867                 rev = self.s.revision_id(i)
868                 self.failUnless(rev == revs[i],
869                                 "%s.revision_id(%d) returned %s not %s"
870                                 % (vars(self.Class)['name'], i, rev, revs[i]))
871
872         def test_get_previous_version(self):
873             """
874             Get should be able to return the previous version.
875             """
876             def val(i):
877                 return '%s:%d' % (self.val, i+1)
878             self.s.add(self.id, directory=False)
879             revs = []
880             for i in range(10):
881                 self.s.set(self.id, val(i))
882                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
883                                           self.commit_body))
884             for i in range(10):
885                 ret = self.s.get(self.id, revision=revs[i])
886                 self.failUnless(ret == val(i),
887                                 "%s.get() returned %s not %s for revision %s"
888                                 % (vars(self.Class)['name'], ret, val(i), revs[i]))
889
890         def test_get_previous_children(self):
891             """
892             Children list should be revision dependent.
893             """
894             self.s.add('parent', directory=True)
895             revs = []
896             cur_children = []
897             children = []
898             for i in range(10):
899                 new_child = str(i)
900                 self.s.add(new_child, 'parent', directory=(i % 2 == 0))
901                 self.s.set(new_child, self.val)
902                 revs.append(self.s.commit('%s: %d' % (self.commit_msg, i),
903                                           self.commit_body))
904                 cur_children.append(new_child)
905                 children.append(list(cur_children))
906             for i in range(10):
907                 ret = self.s.children('parent', revision=revs[i])
908                 self.failUnless(ret == children[i],
909                                 "%s.get() returned %s not %s for revision %s"
910                                 % (vars(self.Class)['name'], ret,
911                                    children[i], revs[i]))
912
913     class VersionedStorage_changed_TestCase (VersionedStorageTestCase):
914         """Test cases for VersionedStorage.changed() method."""
915
916         def test_changed(self):
917             self.s.add('dir', directory=True)
918             self.s.add('modified', parent='dir')
919             self.s.set('modified', 'some value to be modified')
920             self.s.add('moved', parent='dir')
921             self.s.set('moved', 'this entry will be moved')
922             self.s.add('removed', parent='dir')
923             self.s.set('removed', 'this entry will be deleted')
924             revA = self.s.commit('Initial state')
925             self.s.add('new', parent='dir')
926             self.s.set('new', 'this entry is new')
927             self.s.set('modified', 'a new value')
928             self.s.remove('moved')
929             self.s.add('moved2', parent='dir')
930             self.s.set('moved2', 'this entry will be moved')
931             self.s.remove('removed')
932             revB = self.s.commit('Final state')
933             new,mod,rem = self.s.changed(revA)
934             self.failUnless(sorted(new) == ['moved2', 'new'],
935                             'Unexpected new: %s' % new)
936             self.failUnless(mod == ['modified'],
937                             'Unexpected modified: %s' % mod)
938             self.failUnless(sorted(rem) == ['moved', 'removed'],
939                             'Unexpected removed: %s' % rem)
940
941     def make_storage_testcase_subclasses(storage_class, namespace):
942         """Make StorageTestCase subclasses for storage_class in namespace."""
943         storage_testcase_classes = [
944             c for c in (
945                 ob for ob in globals().values() if isinstance(ob, type))
946             if issubclass(c, StorageTestCase) \
947                 and c.Class == Storage]
948
949         for base_class in storage_testcase_classes:
950             testcase_class_name = storage_class.__name__ + base_class.__name__
951             testcase_class_bases = (base_class,)
952             testcase_class_dict = dict(base_class.__dict__)
953             testcase_class_dict['Class'] = storage_class
954             testcase_class = type(
955                 testcase_class_name, testcase_class_bases, testcase_class_dict)
956             setattr(namespace, testcase_class_name, testcase_class)
957
958     def make_versioned_storage_testcase_subclasses(storage_class, namespace):
959         """Make VersionedStorageTestCase subclasses for storage_class in namespace."""
960         storage_testcase_classes = [
961             c for c in (
962                 ob for ob in globals().values() if isinstance(ob, type))
963             if ((issubclass(c, StorageTestCase) \
964                      and c.Class == Storage)
965                 or
966                 (issubclass(c, VersionedStorageTestCase) \
967                      and c.Class == VersionedStorage))]
968
969         for base_class in storage_testcase_classes:
970             testcase_class_name = storage_class.__name__ + base_class.__name__
971             testcase_class_bases = (base_class,)
972             testcase_class_dict = dict(base_class.__dict__)
973             testcase_class_dict['Class'] = storage_class
974             testcase_class = type(
975                 testcase_class_name, testcase_class_bases, testcase_class_dict)
976             setattr(namespace, testcase_class_name, testcase_class)
977
978     make_storage_testcase_subclasses(VersionedStorage, sys.modules[__name__])
979
980     unitsuite =unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
981     suite = unittest.TestSuite([unitsuite, doctest.DocTestSuite()])