From d72430fee347e21a9b9e7912417615bbdb22e6d4 Mon Sep 17 00:00:00 2001 From: "W. Trevor King" Date: Fri, 22 Jan 2010 13:28:01 -0500 Subject: [PATCH] Added _EMPTY and Storage.exists() to libbe.storage.base. There seem to be problems distinguishing between "added but unset" IDs and "added and set to ''" IDs. Now _EMPTY lets us mark "added but unset", and Storage.exists() handles "already added?" more clearly than the old hack "does .get() succeed?". --- libbe/storage/base.py | 55 +++++++++++++++++++++++++++++++++---------- 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/libbe/storage/base.py b/libbe/storage/base.py index 202305b..64ae3e7 100644 --- a/libbe/storage/base.py +++ b/libbe/storage/base.py @@ -84,9 +84,12 @@ class EmptyCommit(Exception): def __init__(self): Exception.__init__(self, 'No changes to commit') +class _EMPTY (object): + """Entry has been added but has no user-set value.""" + pass class Entry (Tree): - def __init__(self, id, value=None, parent=None, directory=False, + def __init__(self, id, value=_EMPTY, parent=None, directory=False, children=None): if children == None: Tree.__init__(self) @@ -241,10 +244,7 @@ class Storage (object): """Add an entry""" if self.is_writeable() == False: raise NotWriteable('Cannot add entry to unwriteable storage.') - try: # Maybe we've already added that id? - self.get(id) - pass # yup, no need to add another - except InvalidID: + if not self.exists(id): self._add(id, *args, **kwargs) def _add(self, id, parent=None, directory=False): @@ -253,6 +253,15 @@ class Storage (object): p = self._data[parent] self._data[id] = Entry(id, parent=p, directory=directory) + def exists(self, *args, **kwargs): + """Check an entry's existence""" + if self.is_readable() == False: + raise NotReadable('Cannot check entry existence in unreadable storage.') + return self._exists(*args, **kwargs) + + def _exists(self, id, revision=None): + return id in self._data + def remove(self, *args, **kwargs): """Remove an entry.""" if self.is_writeable() == False: @@ -332,7 +341,7 @@ class Storage (object): return value def _get(self, id, default=InvalidObject, revision=None): - if id in self._data: + if id in self._data and self._data[id].value != _EMPTY: return self._data[id].value elif default == InvalidObject: raise InvalidID(id) @@ -402,6 +411,13 @@ class VersionedStorage (Storage): p = self._data[-1][parent] self._data[-1][id] = Entry(id, parent=p, directory=directory) + def _exists(self, id, revision=None): + if revision == None: + revision = -1 + else: + revision = int(revision) + return id in self._data[revision] + def _remove(self, id): if self._data[-1][id].directory == True \ and len(self.children(id)) > 0: @@ -446,7 +462,8 @@ class VersionedStorage (Storage): revision = -1 else: revision = int(revision) - if id in self._data[revision]: + if id in self._data[revision] \ + and self._data[revision][id].value != _EMPTY: return self._data[revision][id].value elif default == InvalidObject: raise InvalidID(id) @@ -760,13 +777,14 @@ if TESTING == True: pass def test_get_initial_value(self): - """Data value should be None before any value has been set. + """Data value should be default before any value has been set. """ self.s.add(self.id, directory=False) - ret = self.s.get(self.id) - self.failUnless(ret == None, - "%s.get() returned %s not None" - % (vars(self.Class)['name'], ret)) + val = 'UNLIKELY DEFAULT' + ret = self.s.get(self.id, default=val) + self.failUnless(ret == val, + "%s.get() returned %s not %s" + % (vars(self.Class)['name'], ret, val)) def test_set_exception(self): """Set should raise exception if id not in Storage. @@ -830,6 +848,19 @@ if TESTING == True: "%s.get() returned %s not %s" % (vars(self.Class)['name'], ret, self.val)) + def test_empty_get_set_persistence(self): + """After empty set, get may return either an empty string or default. + """ + self.s.add(self.id, directory=False) + self.s.set(self.id, '') + self.s.disconnect() + self.s.connect() + default = 'UNLIKELY DEFAULT' + ret = self.s.get(self.id, default=default) + self.failUnless(ret in ['', default], + "%s.get() returned %s not in %s" + % (vars(self.Class)['name'], ret, ['', default])) + def test_add_nonrooted_persistence(self): """Adding entries should increase the number of children after reconnect. """ -- 2.26.2