Cleaned up Builtin handling and reduced driver/plugin duplication.
authorW. Trevor King <wking@drexel.edu>
Sun, 9 May 2010 13:57:07 +0000 (09:57 -0400)
committerW. Trevor King <wking@drexel.edu>
Sun, 9 May 2010 13:57:07 +0000 (09:57 -0400)
hooke/driver/__init__.py
hooke/hooke.py
hooke/plugin/__init__.py

index ff90910237c2854ddd6f9063fc69935f3d99eea1..96bbc612f6f4ae2637347bcf879f573073b03b5c 100644 (file)
@@ -7,8 +7,7 @@ to write your own to handle your lab's specific format.
 """
 
 from ..config import Setting
-from ..util.graph import Node, Graph
-
+from ..plugin import construct_graph, IsSubclass
 
 DRIVER_MODULES = [
 #    ('csvdriver', True),
@@ -69,34 +68,13 @@ class Driver(object):
 
 # Construct driver dependency graph and load default drivers.
 
-DRIVERS = {}
-"""(name, instance) :class:`dict` of all possible :class:`Driver`\s.
+DRIVER_GRAPH = construct_graph(
+    this_modname=__name__,
+    submodnames=[name for name,include in DRIVER_MODULES],
+    class_selector=IsSubclass(Driver, blacklist=[Driver]))
+"""Topologically sorted list of all possible :class:`Driver`\s.
 """
 
-for driver_modname,default_include in DRIVER_MODULES:
-    assert len([mod_name for mod_name,di in DRIVER_MODULES]) == 1, \
-        'Multiple %s entries in DRIVER_MODULES' % mod_name
-    this_mod = __import__(__name__, fromlist=[driver_modname])
-    driver_mod = getattr(this_mod, driver_modname)
-    for objname in dir(driver_mod):
-        obj = getattr(driver_mod, objname)
-        try:
-            subclass = issubclass(obj, Driver)
-        except TypeError:
-            continue
-        if subclass == True and obj != Driver:
-            d = obj()
-            if d.name != driver_modname:
-                raise Exception('Driver name %s does not match module name %s'
-                                % (d.name, driver_modname))
-            DRIVERS[d.name] = d
-
-DRIVER_GRAPH = Graph([Node([DRIVERS[name] for name in d.dependencies()],
-                           data=d)
-                      for d in DRIVERS.values()])
-DRIVER_GRAPH.topological_sort()
-
-
 def default_settings():
     settings = [Setting(
             'drivers', help='Enable/disable default drivers.')]
index 923f6cba0e756b12415ca098735a823f3d32ad77..74271ccd39e36b8dd72016f1bcd9df8f9c0ba7d8 100644 (file)
@@ -6,7 +6,6 @@ Hooke - A force spectroscopy review & analysis tool.
 COPYRIGHT
 '''
 
-import ConfigParser as configparser
 import Queue as queue
 import multiprocessing
 
@@ -43,57 +42,26 @@ class Hooke (object):
                 default_settings=default_settings)
             config.read()
         self.config = config
-        self.load_builtins()
         self.load_plugins()
         self.load_drivers()
-        self.setup_commands()
-
-    def load_builtins(self):
-        self.builtins = []
-        for builtin in plugin_mod.BUILTINS.values():
-            builtin = plugin_mod.BUILTINS[builtin_name]
-            try:
-                builtin.config = dict(
-                    self.config.items(builtin.setting_section))
-            except configparser.NoSectionError:
-                pass
-            self.builtins.append(plugin_mod.BUILTINS[builtin_name])
 
     def load_plugins(self):
-        self.plugins = []
-        for plugin_name,include in self.config.items('plugins'):
-            if include == 'True':
-                plugin = plugin_mod.PLUGINS[plugin_name]
-                try:
-                    plugin.config = dict(
-                        self.config.items(plugin.setting_section))
-                except configparser.NoSectionError:
-                    pass
-                self.plugins.append(plugin_mod.PLUGINS[plugin_name])
-
-    def load_drivers(self):
-        self.drivers = []
-        for driver_name,include in self.config.items('drivers'):
-            if include == 'True':
-                driver = driver_mod.DRIVERS[driver_name]
-                try:
-                    driver.config = dict(
-                        self.config.items(driver.setting_section))
-                except configparser.NoSectionError:
-                    pass
-                self.drivers.append(driver_mod.DRIVERS[driver_name])
-
-    def setup_commands(self):
+        self.plugins = plugin_mod.load_graph(
+            plugin_mod.PLUGIN_GRAPH, self.config, include_section='plugins')
         self.commands = []
-        for plugin in self.builtins + self.plugins:
+        for plugin in self.plugins:
             self.commands.extend(plugin.commands())
 
+    def load_drivers(self):
+        self.drivers = plugin_mod.load_graph(
+            driver_mod.DRIVER_GRAPH, self.config, include_section='drivers')
+
     def close(self):
         if self.config.changed:
             self.config.write() # Does not preserve original comments
 
     def playlist_status(self, playlist):
-        if playlist.has_curves()
+        if playlist.has_curves():
             return '%s (%s/%s)' % (playlist.name, playlist._index + 1,
                                    len(playlist))
         return 'The playlist %s does not contain any valid force curve data.' \
index f56644b000a9485a780b32016fd99a0e586af784..544200ae250ecc9ed250d019332dee7815f300fb 100644 (file)
@@ -4,6 +4,7 @@ commands.
 All of the science happens in here.
 """
 
+import ConfigParser as configparser
 import Queue as queue
 
 from ..config import Setting
@@ -355,12 +356,12 @@ def default_settings():
     settings = [Setting(
             'plugins', help='Enable/disable default plugins.')]
     for pnode in PLUGIN_GRAPH:
-        if pnode.name in BUILTIN_MODULES:
+        if pnode.data.name in BUILTIN_MODULES:
             continue # builtin inclusion is not optional
         plugin = pnode.data
         default_include = [di for mod_name,di in PLUGIN_MODULES
                            if mod_name == plugin.name][0]
-        help = 'Commands: ' + ', '.join([c.name for c in p.commands()])
+        help = 'Commands: ' + ', '.join([c.name for c in plugin.commands()])
         settings.append(Setting(
                 section='plugins',
                 option=plugin.name,
@@ -371,3 +372,20 @@ def default_settings():
         plugin = pnode.data
         settings.extend(plugin.default_settings())
     return settings
+
+def load_graph(graph, config, include_section):
+    items = []
+    for node in graph:
+        item = node.data
+        try:
+            include = config.getboolean(include_section, item.name)
+        except configparser.NoOptionError:
+            include = True # non-optional include (e.g. a Builtin)
+        if include == True:
+            try:
+                item.config = dict(
+                    config.items(item.setting_section))
+            except configparser.NoSectionError:
+                pass
+            items.append(item)
+    return items