Add a new sqlite cache module (one I wrote) that has acceptable performance.
author    Zac Medico <zmedico@gentoo.org>
Mon, 1 May 2006 22:13:13 +0000 (22:13 -0000)
committer Zac Medico <zmedico@gentoo.org>
Mon, 1 May 2006 22:13:13 +0000 (22:13 -0000)
svn path=/main/trunk/; revision=3299

pym/cache/sqlite.py [new file with mode: 0644]

diff --git a/pym/cache/sqlite.py b/pym/cache/sqlite.py
new file mode 100644 (file)
index 0000000..4863b59
--- /dev/null
@@ -0,0 +1,237 @@
+# Copyright 1999-2006 Gentoo Foundation
+# Distributed under the terms of the GNU General Public License v2
+# $Header: $
+
+import fs_template
+import cache_errors
+import errno, os, stat
+from mappings import LazyLoad, ProtectedDict
+from template import reconstruct_eclasses
+from portage_util import writemsg, apply_secpass_permissions
+from portage_data import portage_gid
+from pysqlite2 import dbapi2 as db_module
+DBError = db_module.Error
+
+class database(fs_template.FsBased):
+
+       autocommits = False
+       synchronous = False
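+       # synchronous=False yields "PRAGMA synchronous = 0" (OFF), trading
+       # durability against power loss for faster writes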
+       # cache_bytes is used together with page_size (set at sqlite build time)
+       # to calculate the number of pages requested, according to the following
+       # equation: cache_bytes = page_size * page_count
+       cache_bytes = 1024 * 1024 * 10
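+       # (e.g. with the common default page_size of 1024 bytes this requests
+       # a cache_size of 10240 pages)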
+       _db_module = db_module
+       _db_error = DBError
+       _db_table = None
+
+       def __init__(self, *args, **config):
+               super(database, self).__init__(*args, **config)
+               self._allowed_keys = ["_mtime_", "_eclasses_"] + self._known_keys
+               self.location = os.path.join(self.location, 
+                       self.label.lstrip(os.path.sep).rstrip(os.path.sep))
+
+               if not os.path.exists(self.location):
+                       self._ensure_dirs()
+
+               config.setdefault("autocommit", self.autocommits)
+               config.setdefault("cache_bytes", self.cache_bytes)
+               config.setdefault("synchronous", self.synchronous)
+               self._db_init_connection(config)
+               self._db_init_structures()
+
+       def _db_escape_string(self, s):
+               """meta escaping, returns quoted string for use in sql statements"""
+               # sqlite string literals only require embedded single quotes to be
+               # doubled; backslashes are not escape characters and are left as-is
+               return "'%s'" % str(s).replace("'","''")
+
+       def _db_init_connection(self, config):
+               self._dbpath = self.location + ".sqlite"
+               #if os.path.exists(self._dbpath):
+               #       os.unlink(self._dbpath)
+               try:
+                       self._ensure_dirs()
+                       self._db_connection = self._db_module.connect(database=self._dbpath)
+                       self._db_cursor = self._db_connection.cursor()
+                       self._db_cursor.execute("PRAGMA encoding = %s" % self._db_escape_string("UTF-8"))
+                       if not apply_secpass_permissions(self._dbpath, gid=portage_gid, mode=070, mask=02):
+                               raise cache_errors.InitializationError(self.__class__, "can't ensure perms on %s" % self._dbpath)
+                       self._db_init_cache_size(config["cache_bytes"])
+                       self._db_init_synchronous(config["synchronous"])
+               except self._db_error, e:
+                       raise cache_errors.InitializationError(self.__class__, e)
+
+       def _db_init_structures(self):
+               self._db_table = {}
+               self._db_table["packages"] = {}
+               mytable = "portage_packages"
+               self._db_table["packages"]["table_name"] = mytable
+               self._db_table["packages"]["package_id"] = "internal_db_package_id"
+               self._db_table["packages"]["package_key"] = "portage_package_key"
+               self._db_table["packages"]["internal_columns"] = \
+                       [self._db_table["packages"]["package_id"],
+                       self._db_table["packages"]["package_key"]]
+               create_statement = []
+               create_statement.append("CREATE TABLE")
+               create_statement.append(mytable)
+               create_statement.append("(")
+               table_parameters = []
+               table_parameters.append("%s INTEGER PRIMARY KEY AUTOINCREMENT" % self._db_table["packages"]["package_id"])
+               table_parameters.append("%s TEXT" % self._db_table["packages"]["package_key"])
+               for k in self._allowed_keys:
+                       table_parameters.append("%s TEXT" % k)
+               table_parameters.append("UNIQUE(%s)" % self._db_table["packages"]["package_key"])
+               create_statement.append(",".join(table_parameters))
+               create_statement.append(")")
+
+               self._db_table["packages"]["create"] = " ".join(create_statement)
+               self._db_table["packages"]["columns"] = \
+                       self._db_table["packages"]["internal_columns"] + \
+                       self._allowed_keys
+
+               cursor = self._db_cursor
+               for k, v in self._db_table.iteritems():
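+                       # if a table exists but its recorded CREATE statement no
+                       # longer matches the expected schema, drop and recreate it
+                       # (discarding the cached data)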
+                       if self._db_table_exists(v["table_name"]):
+                               create_statement = self._db_table_get_create(v["table_name"])
+                               if create_statement != v["create"]:
+                                       writemsg("sqlite: dropping old table: %s\n" % v["table_name"])
+                                       cursor.execute("DROP TABLE %s" % v["table_name"])
+                                       cursor.execute(v["create"])
+                       else:
+                               cursor.execute(v["create"])
+
+       def _db_table_exists(self, table_name):
+               """return true/false dependant on a tbl existing"""
+               cursor = self._db_cursor
+               cursor.execute("SELECT name FROM sqlite_master WHERE type=\"table\" AND name=%s" % \
+                       self._db_escape_string(table_name))
+               return len(cursor.fetchall()) == 1
+
+       def _db_table_get_create(self, table_name):
+               """return true/false dependant on a tbl existing"""
+               cursor = self._db_cursor
+               cursor.execute("SELECT sql FROM sqlite_master WHERE name=%s" % \
+                       self._db_escape_string(table_name))
+               return cursor.fetchall()[0][0]
+
+       def _db_init_cache_size(self, cache_bytes):
+               cursor = self._db_cursor
+               cursor.execute("PRAGMA page_size")
+               page_size = int(cursor.fetchone()[0])
+               # number of pages, sqlite default is 2000
+               cache_size = cache_bytes / page_size
+               cursor.execute("PRAGMA cache_size = %d" % cache_size)
+               cursor.execute("PRAGMA cache_size")
+               actual_cache_size = int(cursor.fetchone()[0])
+               del cursor
+               if actual_cache_size != cache_size:
+                       raise cache_errors.InitializationError(self.__class__,"actual cache_size = "+actual_cache_size+" does does not match requested size of "+cache_size)
+
+       def _db_init_synchronous(self, synchronous):
+               cursor = self._db_cursor
+               cursor.execute("PRAGMA synchronous = %d" % synchronous)
+               cursor.execute("PRAGMA synchronous")
+               actual_synchronous = int(cursor.fetchone()[0])
+               del cursor
+               if actual_synchronous != synchronous:
+                       raise cache_errors.InitializationError(self.__class__, "actual synchronous = %d does not match requested value of %d" % (actual_synchronous, synchronous))
+
+       def __getitem__(self, cpv):
+               if not self.has_key(cpv):
+                       raise KeyError(cpv)
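+               # defer the actual SELECT until a value is accessed; LazyLoad
+               # calls self._pull(cpv) on first access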
+               def curry(*args):
+                       def callit(*args2):
+                               return args[0](*args[1:]+args2)
+                       return callit
+               return ProtectedDict(LazyLoad(curry(self._pull, cpv)))
+
+       def _pull(self, cpv):
+               cursor = self._db_cursor
+               cursor.execute("select * from %s where %s=%s" % \
+                       (self._db_table["packages"]["table_name"],
+                       self._db_table["packages"]["package_key"],
+                       self._db_escape_string(cpv)))
+               result = cursor.fetchall()
+               if len(result) == 1:
+                       pass
+               elif len(result) == 0:
+                       raise KeyError(cpv)
+               else:
+                       raise cache_errors.CacheCorruption(cpv, "key is not unique")
+               d = {}
+               internal_columns = self._db_table["packages"]["internal_columns"]
+               column_index = -1
+               for k in self._db_table["packages"]["columns"]:
+                       column_index +=1
+                       if k not in internal_columns:
+                               d[k] = result[0][column_index]
+               # XXX: The resolver chokes on unicode strings so we convert them here.
+               for k in d.keys():
+                       try:
+                               d[k]=str(d[k]) # convert unicode strings to normal
+                       except UnicodeEncodeError, e:
+                               pass #writemsg("%s: %s\n" % (cpv, str(e)))
+               if "_eclasses_" in d:
+                       d["_eclasses_"] = reconstruct_eclasses(cpv, d["_eclasses_"])
+               for x in self._known_keys:
+                       d.setdefault(x,'')
+               return d
+
+       def _setitem(self, cpv, values):
+               update_statement = []
+               update_statement.append("REPLACE INTO %s" % self._db_table["packages"]["table_name"])
+               update_statement.append("(")
+               update_statement.append(','.join([self._db_table["packages"]["package_key"]] + self._allowed_keys))
+               update_statement.append(")")
+               update_statement.append("VALUES")
+               update_statement.append("(")
+               values_parameters = []
+               values_parameters.append(self._db_escape_string(cpv))
+               for k in self._allowed_keys:
+                       values_parameters.append(self._db_escape_string(values.get(k, '')))
+               update_statement.append(",".join(values_parameters))
+               update_statement.append(")")
+               cursor = self._db_cursor
+               try:
+                       s = " ".join(update_statement)
+                       cursor.execute(s)
+               except self._db_error, e:
+                       writemsg("%s: %s\n" % (cpv, str(e)))
+                       raise
+
+       def commit(self):
+               self._db_connection.commit()
+
+       def _delitem(self, cpv):
+               cursor = self._db_cursor
+               cursor.execute("DELETE FROM %s WHERE %s=%s" % \
+                       (self._db_table["packages"]["table_name"],
+                       self._db_table["packages"]["package_key"],
+                       self._db_escape_string(cpv)))
+
+       def has_key(self, cpv):
+               cursor = self._db_cursor
+               cursor.execute(" ".join(
+                       ["SELECT %s FROM %s" %
+                       (self._db_table["packages"]["package_id"],
+                       self._db_table["packages"]["table_name"]),
+                       "WHERE %s=%s" % (
+                       self._db_table["packages"]["package_key"],
+                       self._db_escape_string(cpv))]))
+               result = cursor.fetchall()
+               if len(result) == 0:
+                       return False
+               elif len(result) == 1:
+                       return True
+               else:
+                       raise cache_errors.CacheCorruption(cpv, "key is not unique")
+
+       def iterkeys(self):
+               """generator for walking the dir struct"""
+               cursor = self._db_cursor
+               cursor.execute("SELECT %s FROM %s" % \
+                       (self._db_table["packages"]["package_key"],
+                       self._db_table["packages"]["table_name"]))
+               result = cursor.fetchall()
+               key_list = [x[0] for x in result]
+               del result
+               while key_list:
+                       yield key_list.pop()
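
For reference, a minimal usage sketch (not part of the commit). It assumes pym/ is on
sys.path and that the cache classes derived from fs_template.FsBased take
(location, label, auxdbkeys) constructor arguments; the paths and key list below are
only illustrative:

    from cache import sqlite as sqlite_cache

    # hypothetical cache root, repository path and metadata keys
    auxdbkeys = ["DEPEND", "RDEPEND", "SLOT", "KEYWORDS"]
    auxdb = sqlite_cache.database("/var/cache/edb/dep", "/usr/portage", auxdbkeys)

    if auxdb.has_key("sys-apps/portage-2.1"):
        print auxdb["sys-apps/portage-2.1"]["SLOT"]
    auxdb.commit()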