Track connection status to allow multiple Storage.disconnect() calls.
authorW. Trevor King <wking@drexel.edu>
Thu, 31 Dec 2009 16:47:33 +0000 (11:47 -0500)
committerW. Trevor King <wking@drexel.edu>
Thu, 31 Dec 2009 16:47:33 +0000 (11:47 -0500)
This makes cleaning up UIs easier: just call disconnect() :p.

libbe/storage/base.py
libbe/storage/vcs/base.py

index ffde475fb6b8099bd026c579363b467b1d9f4915..1c711fa8143083323149f45a73f54e8e7076d712 100644 (file)
@@ -139,6 +139,7 @@ class Storage (object):
         self._writeable = True # hard limit (backend choice)
         self.versioned = False
         self.can_init = True
+        self.connected = False
 
     def __str__(self):
         return '<%s %s %s>' % (self.__class__.__name__, id(self), self.repo)
@@ -190,6 +191,7 @@ class Storage (object):
         if self.is_readable() == False:
             raise NotReadable('Cannot connect to unreadable storage.')
         self._connect()
+        self.connected = True
 
     def _connect(self):
         try:
@@ -204,6 +206,12 @@ class Storage (object):
         """Close the connection to the repository."""
         if self.is_writeable() == False:
             return
+        if self.connected == False:
+            return
+        self._disconnect()
+        self.connected = False
+
+    def _disconnect(self):
         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
         pickle.dump(dict((k,v._objects_to_ids())
                          for k,v in self._data.items()), f, -1)
@@ -342,10 +350,7 @@ class VersionedStorage (Storage):
                       for t in d]
         f.close()
 
-    def disconnect(self):
-        """Close the connection to the repository."""
-        if self.is_writeable() == False:
-            return
+    def _disconnect(self):
         f = open(os.path.join(self.repo, 'repo.pkl'), 'wb')
         pickle.dump([dict((k,v._objects_to_ids())
                           for k,v in t.items()) for t in self._data], f, -1)
@@ -478,6 +483,14 @@ if TESTING == True:
             """Should connect after initialization."""
             self.s.connect()
 
+    class Storage_connect_disconnect_TestCase (StorageTestCase):
+        """Test cases for Storage.connect and .disconnect methods."""
+
+        def test_multiple_disconnects(self):
+            """Should be able to call .disconnect multiple times."""
+            self.s.disconnect()
+            self.s.disconnect()
+
     class Storage_add_remove_TestCase (StorageTestCase):
         """Test cases for Storage.add, .remove, and .recursive_remove methods."""
 
index b47ed2fa2c34c3797b9cdf68af765184106c492f..99f43f3a40230bded839a55e054b6daacea24b5b 100644 (file)
@@ -683,7 +683,7 @@ os.listdir(self.get_path("bugs")):
         self._cached_path_id.connect()
         self.check_storage_version()
 
-    def disconnect(self):
+    def _disconnect(self):
         self._cached_path_id.disconnect()
 
     def _add_path(self, path, directory=False):