Added hooke.util.yaml fixing YAML/NumPy type issues (by dropping data).
[hooke.git] / hooke / playlist.py
index 26ab8dc7e80afb13cf0bfc84c3c7683331022812..3380bfcacfc4ca24a02f4ee49e11c89370fa9cd0 100644 (file)
+# Copyright (C) 2010 W. Trevor King <wking@drexel.edu>
+#
+# This file is part of Hooke.
+#
+# Hooke is free software: you can redistribute it and/or modify it
+# under the terms of the GNU Lesser General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# Hooke is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+# or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General
+# Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with Hooke.  If not, see
+# <http://www.gnu.org/licenses/>.
+
+"""The `playlist` module provides a :class:`Playlist` and its subclass
+:class:`FilePlaylist` for manipulating lists of
+:class:`hooke.curve.Curve`\s.
+"""
+
 import copy
+import hashlib
 import os
 import os.path
-import xml.dom.minidom
+import types
 
-from . import hooke as hooke
-from . import curve as lhc
-from . import libhooke as lh
+import yaml
+from yaml.representer import RepresenterError
 
-class Playlist(object):
-    def __init__(self, drivers):
-        self._saved = False
-        self.count = 0
-        self.curves = []
-        self.drivers = drivers
-        self.path = ''
-        self.genericsDict = {}
-        self.hiddenAttributes = ['curve', 'driver', 'name', 'plots']
-        self.index = -1
-        self.name = 'Untitled'
-        self.plotPanel = None
-        self.plotTab = None
-        self.xml = None
-
-    def add_curve(self, path, attributes={}):
-        curve = lhc.HookeCurve(path)
-        for key,value in attribures.items():
-            setattr(curve, key, value)
-        curve.identify(self.drivers)
-        curve.plots = curve.driver.default_plots()
-        self.curves.append(curve)
-        self._saved = False
-        self.count = len(self.curves)
-        return curve
-
-    def close_curve(self, index):
-        if index >= 0 and index < self.count:
-            self.curves.remove(index)
-
-    def filter_curves(self, keeper_fn=labmda curve:True):
-        playlist = copy.deepcopy(self)
-        for curve in reversed(playlist.curves):
-            if not keeper_fn(curve):
-                playlist.curves.remove(curve)
-        try: # attempt to maintain the same active curve
-            playlist.index = playlist.curves.index(self.get_active_curve())
-        except ValueError:
-            playlist.index = 0
-        playlist._saved = False
-        playlist.count = len(playlist.curves)
-        return playlist
+from . import curve as curve
+from .util.itertools import reverse_enumerate
 
-    def get_active_curve(self):
-        return self.curves[self.index]
 
-    #TODO: do we need this?
-    def get_active_plot(self):
-        return self.curves[self.index].plots[0]
+class NoteIndexList (list):
+    """A list that keeps track of a "current" item and additional notes.
 
-    def get_status_string(self):
-        if self.has_curves()
-            return '%s (%s/%s)' % (self.name, self.index + 1, self.count)
-        return 'The file %s does not contain any valid force curve data.' \
-            % self.name
+    :attr:`_index` (i.e. the "bookmark") is the index of the current
+    item.  Also keeps a :class:`dict` of additional information
+    (:attr:`info`).
+    """
+    def __init__(self, name=None):
+        super(NoteIndexList, self).__init__()
+        self.name = name
+        self.info = {}
+        self._index = 0
+        self._set_ignored_attrs()
 
-    def has_curves(self):
-        if self.count > 0:
-            return True
-        return False
+    def __str__(self):
+        return str(self.__unicode__())
 
-    def is_saved(self):
-        return self._saved
-
-    def load(self, path):
-        '''
-        loads a playlist file
-        '''
-        self.path = path
-        self.name = os.path.basename(path)
-        playlist = lh.delete_empty_lines_from_xmlfile(path)
-        self.xml = xml.dom.minidom.parse(path)
-        # Strip blank spaces:
-        self._removeWhitespaceNodes()
-
-        generics_list = self.xml.getElementsByTagName('generics')
-        curve_list = self.xml.getElementsByTagName('curve')
-        self._loadGenerics(generics_list)
-        self._loadCurves(curve_list)
-        self._saved = True
-
-    def _removeWhitespaceNodes(self, root_node=None):
-        if root_node == None:
-            root_node = self.xml
-        for node in root_node.childNodes:
-            if node.nodeType == node.TEXT_NODE and node.data.strip() == '':
-                root_node.removeChild(node) # drop this whitespace node
-            else:
-                _removeWhitespaceNodes(root_node=node) # recurse down a level
-
-    def _loadGenerics(self, generics_list, clear=True):
-        if clear:
-            self.genericsDict = {}
-        #populate generics
-        generics_list = self.xml.getElementsByTagName('generics')
-        for generics in generics_list:
-            for attribute in generics.attributes.keys():
-                self.genericsDict[attribute] = generics_list[0].getAttribute(attribute)
-        if self.genericsDict.has_key('pointer'):
-            index = int(self.genericsDict['pointer'])
-            if index >= 0 and index < len(self.curves):
-                self.index = index
-            else:
-                index = 0
-
-    def _loadCurves(self, curve_list, clear=True):
-        if clear:
-            self.curves = []
-        #populate playlist with curves
-        for curve in curve_list:
-            #rebuild a data structure from the xml attributes
-            curve_path = lh.get_file_path(element.getAttribute('path'))
-            #extract attributes for the single curve
-            attributes = dict([(k,curve.getAttribute(k))
-                               for k in curve.attributes.keys()])
-            attributes.pop('path')
-            curve = self.add_curve(os.path.join(path, curve_path), attributes)
-            if curve is not None:
-                for plot in curve.plots:
-                    curve.add_data('raw', plot.vectors[0][0], plot.vectors[0][1], color=plot.colors[0], style='plot')
-                    curve.add_data('raw', plot.vectors[1][0], plot.vectors[1][1], color=plot.colors[1], style='plot')
+    def __unicode__(self):
+        return u'<%s %s>' % (self.__class__.__name__, self.name)
+
+    def __repr__(self):
+        return self.__str__()
+
+    def _set_ignored_attrs(self):
+        self._ignored_attrs = ['_ignored_attrs', '_default_attrs']
+        self._default_attrs = {
+            'info': {},
+            }
+
+    def __getstate__(self):
+        """Return a dict describing the list and its items.
+
+        Starts from a copy of the instance `__dict__`, drops the keys
+        named in :attr:`_ignored_attrs` and any attribute still equal
+        to its entry in :attr:`_default_attrs`, then stores the
+        per-item states (see :meth:`_item_getstate`) under ``'items'``.
+        """
+        state = dict(self.__dict__)
+        for ignored in self._ignored_attrs:
+            state.pop(ignored, None)
+        for name,default in self._default_attrs.items():
+            if name in state and state[name] == default:
+                state.pop(name)
+        assert 'items' not in state
+        item_states = state['items'] = []
+        # The container's own attributes are checked before any item
+        # state has been added.
+        self._assert_clean_state(self, state)
+        for item in self:  # save curves and their attributes
+            item_state = self._item_getstate(item)
+            self._assert_clean_state(item, item_state)
+            item_states.append(item_state)
+        return state
+
+    def __setstate__(self, state):
+        """Restore the list from a dict produced by :meth:`__getstate__`.
+
+        Attributes listed in :attr:`_default_attrs` are first reset to
+        their defaults (they are omitted from `state` when unchanged),
+        then every key of `state` except ``'items'`` is set as an
+        attribute, and finally each entry of ``state['items']`` is
+        converted with :meth:`_item_setstate` and appended.  A missing
+        ``'items'`` key is treated as an empty list.
+        """
+        self._set_ignored_attrs()
+        for key,value in self._default_attrs.items():
+            # Copy the default.  Assigning it directly would make
+            # e.g. `self.info` the very same dict as
+            # `_default_attrs['info']`, so later in-place edits would
+            # still compare equal to the "default" and be dropped by
+            # :meth:`__getstate__`.
+            setattr(self, key, copy.deepcopy(value))
+        for key,value in state.items():
+            if key == 'items':
+                continue
+            setattr(self, key, value)
+        for item_state in state.get('items', []):
+            self.append(self._item_setstate(item_state))
+
+    def _item_getstate(self, item):
+        return item
+
+    def _item_setstate(self, state):
+        return state
+
+    def _assert_clean_state(self, owner, state):
+        """Check that every entry of `state` can be dumped to YAML.
+
+        `owner` is only used to name the offending class in the error
+        message.  The ``'drivers'`` key is skipped.  Raises
+        :class:`NotImplementedError` if :func:`yaml.dump` fails with a
+        :class:`RepresenterError` for any ``(key, value)`` pair.
+        """
+        for k,v in state.items():
+            if k == 'drivers':  # HACK.  Need better driver serialization.
+                continue
+            try:
+                yaml.dump((k,v))
+            except RepresenterError as e:
+                raise NotImplementedError(
+                    'cannot convert %s.%s = %s (%s) to YAML\n%s'
+                    % (owner.__class__.__name__, k, v, type(v), e))
+
+    def _setup_item(self, item):
+        """Perform any required initialization before returning an item.
+        """
+        pass
+
+    def index(self, value=None, *args, **kwargs):
+        """Extend `list.index`, returning the current index if `value`
+        is `None`.
+
+        Any other `value` (and the optional extra arguments) is passed
+        straight to `list.index`, which raises :class:`ValueError` when
+        the value is not present.
+        """
+        # Identity test: `== None` would call a custom `__eq__` on
+        # `value`, which need not return a plain boolean.
+        if value is None:
+            return self._index
+        return super(NoteIndexList, self).index(value, *args, **kwargs)
+
+    def current(self):
+        if len(self) == 0:
+            return None
+        item = self[self._index]
+        self._setup_item(item)
+        return item
+
+    def jump(self, index):
+        if len(self) == 0:
+            self._index = 0
+        else:
+            self._index = index % len(self)
 
     def next(self):
-        self.index += 1
-        if self.index > self.count - 1:
-            self.index = 0
+        self.jump(self._index + 1)
 
     def previous(self):
-        self.index -= 1
-        if self.index < 0:
-            self.index = self.count - 1
+        self.jump(self._index - 1)
 
-    def reset(self):
-        if self.has_curves():
-            self.index = 0
+    def items(self, reverse=False):
+        """Iterate through `self` calling `_setup_item` on each item
+        before yielding.
+
+        With `reverse` set to `True` the items are visited from last
+        to first.
+
+        Notes
+        -----
+        Updates :attr:`_index` during the iteration so
+        :func:`~hooke.plugin.curve.current_curve_callback` works as
+        expected in :class:`~hooke.command.Command`\s called from
+        :class:`~hooke.plugin.playlist.ApplyCommand`.  After the
+        iteration completes, :attr:`_index` is restored to its
+        original value.  The restore only happens once the generator
+        is exhausted; if the caller stops iterating early,
+        :attr:`_index` is left at the last item yielded.
+        """
+        index = self._index
+        if reverse == True:
+            items = reverse_enumerate(self)
+        else:
+            items = enumerate(self)
+        for i,item in items:
+            self._index = i
+            self._setup_item(item)
+            yield item
+        self._index = index
+
+    def filter(self, keeper_fn=lambda item:True, load_curves=True,
+               *args, **kwargs):
+        c = copy.deepcopy(self)
+        if load_curves == True:
+            items = c.items(reverse=True)
         else:
-            self.index = None
-
-    def save(self, path):
-        '''
-        saves the playlist in a XML file.
-        '''
-        try:
-            output_file = file(path, 'w')
-        except IOError, e:
-            #TODO: send message
-            print 'Cannot save playlist: %s' % e
-            return
-        self.xml.writexml(output_file, indent='\n')
-        output_file.close()
-        self._saved = True
-
-    def set_XML(self):
-        '''
-        Creates an initial playlist from a list of files.
-        A playlist is an XML document with the following syntax:
-          <?xml version="1.0" encoding="utf-8"?>
-          <playlist>
-            <generics pointer="0"/>
-            <curve path="/my/file/path/"/ attribute="value" ...>
-            <curve path="...">
-          </playlist>
+            items = reversed(c)
+        for item in items: 
+            if keeper_fn(item, *args, **kwargs) != True:
+                c.remove(item)
+        try: # attempt to maintain the same current item
+            c._index = c.index(self.current())
+        except ValueError:
+            c._index = 0
+        return c
+
+
+class Playlist (NoteIndexList):
+    """A :class:`NoteIndexList` of :class:`hooke.curve.Curve`\s.
+
+    Keeps a list of :attr:`drivers` for loading curves.
+    """
+    def __init__(self, drivers, name=None):
+        super(Playlist, self).__init__(name=name)
+        self.drivers = drivers
+        self._max_loaded = 100 # curves to hold in memory simultaneously.
+
+    def _set_ignored_attrs(self):
+        """Extend the base bookkeeping with per-curve ignore/default
+        tables and reset the list of loaded curves.
+        """
+        super(Playlist, self)._set_ignored_attrs()
+        self._ignored_attrs.extend([
+                '_item_ignored_attrs', '_item_default_attrs',
+                '_loaded'])
+        self._item_ignored_attrs = ['data']
+        self._item_default_attrs = {
+            'command_stack': [],
+            'driver': None,
+            'info': {},
+            'name': None,
+            }
+        self._loaded = [] # List of loaded curves, see :meth:`._setup_item`.
+
+    def _item_getstate(self, item):
+        """Return the state of curve `item`, without the ignored
+        attributes and without attributes still at their default.
+        """
+        assert isinstance(item, curve.Curve), type(item)
+        state = item.__getstate__()
+        for key in self._item_ignored_attrs:
+            if key in state:
+                del(state[key])
+        for key,value in self._item_default_attrs.items():
+            if key in state and state[key] == value:
+                del(state[key])
+        return state
+
+    def _item_setstate(self, state):
+        """Build a :class:`hooke.curve.Curve` from `state`, filling in
+        the defaults that :meth:`_item_getstate` stripped.
+        """
+        for key,value in self._item_default_attrs.items():
+            if key not in state:
+                # Copy the default so curves restored from the same
+                # playlist never receive the very same list/dict
+                # object (NOTE(review): matters if
+                # `Curve.__setstate__` stores the value as-is).
+                state[key] = copy.deepcopy(value)
+        item = curve.Curve(path=None)
+        item.__setstate__(state)
+        return item
+
+    def append_curve_by_path(self, path, info=None, identify=True, hooke=None):
+        """Create a curve for the normalized `path`, append it, and
+        return it.
+
+        If `identify` is `True` the curve is identified against
+        :attr:`drivers` before being appended.
+        """
+        path = os.path.normpath(path)
+        c = curve.Curve(path, info=info)
+        c.set_hooke(hooke)
+        if identify == True:
+            c.identify(self.drivers)
+        self.append(c)
+        return c
+
+    def _setup_item(self, curve):
+        """Make sure `curve` is in the playlist, identified and loaded.
+
+        At most :attr:`_max_loaded` curves are kept loaded; when the
+        limit is exceeded the curve that was loaded first is unloaded.
+        """
+        # NOTE: the `curve` parameter shadows the `curve` module inside
+        # this method.
+        if curve is not None and curve not in self._loaded:
+            if curve not in self:
+                self.append(curve)
+            if curve.driver is None:
+                # Was `c.identify(...)`: `c` is not defined in this
+                # scope, so unidentified curves raised NameError.
+                curve.identify(self.drivers)
+            if curve.data is None:
+                curve.load()
+            self._loaded.append(curve)
+            if len(self._loaded) > self._max_loaded:
+                oldest = self._loaded.pop(0)
+                oldest.unload()
+
+
+class FilePlaylist (Playlist):
+    """A file-backed :class:`Playlist`.
+    """
+    version = '0.2'
+
+    def __init__(self, drivers, name=None, path=None):
+        super(FilePlaylist, self).__init__(drivers, name)
+        self.path = self._base_path = None
+        self.set_path(path)
+        self._relative_curve_paths = True
+
+    def _set_ignored_attrs(self):
+        super(FilePlaylist, self)._set_ignored_attrs()
+        self._ignored_attrs.append('_digest')
+        self._digest = None
+
+    def __getstate__(self):
+        state = super(FilePlaylist, self).__getstate__()
+        assert 'version' not in state, state
+        state['version'] = self.version
+        return state
+
+    def __setstate__(self, state):
+        assert('version') in state, state
+        version = state.pop('version')
+        assert version == FilePlaylist.version, (
+            'invalid version %s (%s) != %s (%s)'
+            % (version, type(version),
+               FilePlaylist.version, type(FilePlaylist.version)))
+        super(FilePlaylist, self).__setstate__(state)
+
+    def _item_getstate(self, item):
+        state = super(FilePlaylist, self)._item_getstate(item)
+        if state.get('path', None) != None:
+            path = os.path.abspath(os.path.expanduser(state['path']))
+            if self._relative_curve_paths == True:
+                path = os.path.relpath(path, self._base_path)
+            state['path'] = path
+        return state
+
+    def _item_setstate(self, state):
+        item = super(FilePlaylist, self)._item_setstate(state)
+        if 'path' in state:
+            item.set_path(os.path.join(self._base_path, state['path']))
+        return item
+
+    def set_path(self, path):
+        """Set the playlist file path and the directory that relative
+        curve paths are resolved against.
+
+        With `path` equal to `None` only :attr:`_base_path` is touched:
+        it falls back to the current working directory if it has not
+        been set yet.  Otherwise the ``.hkp`` extension is appended
+        when missing, :attr:`path` is stored, :attr:`_base_path`
+        becomes the absolute, user-expanded directory of the file, and
+        an unset :attr:`name` defaults to the file's basename.
+        """
+        if path == None:
+            if self._base_path == None:
+                self._base_path = os.getcwd()
+            return
+        if not path.endswith('.hkp'):
+            path = path + '.hkp'
+        self.path = path
+        expanded = os.path.expanduser(self.path)
+        self._base_path = os.path.dirname(os.path.abspath(expanded))
+        if self.name == None:
+            self.name = os.path.basename(path)
+
+    def append_curve_by_path(self, path, *args, **kwargs):
+        """Append the curve at `path`, joining `path` onto the
+        playlist's base directory first (when one is set).
+
+        The remaining arguments are forwarded to
+        :meth:`Playlist.append_curve_by_path`, and the curve it
+        creates is returned (previously the result was discarded and
+        `None` was returned).
+        """
+        if self._base_path != None:
+            path = os.path.join(self._base_path, path)
+        return super(FilePlaylist, self).append_curve_by_path(
+            path, *args, **kwargs)
+
+    def is_saved(self):
+        return self.digest() == self._digest
+
+    def digest(self):
+        r"""Compute the sha1 digest of the flattened playlist
+        representation.
+
+        Examples
+        --------
+
+        >>> root_path = os.path.sep + 'path'
+        >>> p = FilePlaylist(drivers=[],
+        ...                  path=os.path.join(root_path, 'to','playlist'))
+        >>> p.info['note'] = 'An example playlist'
+        >>> c = curve.Curve(os.path.join(root_path, 'to', 'curve', 'one'))
+        >>> c.info['note'] = 'The first curve'
+        >>> p.append(c)
+        >>> c = curve.Curve(os.path.join(root_path, 'to', 'curve', 'two'))
+        >>> c.info['note'] = 'The second curve'
+        >>> p.append(c)
+        >>> p.digest()
+        '\xa1\x1ax\xb1|\x84uA\xe4\x1d\xbf`\x004|\x82\xc2\xdd\xc1\x9e'
+        """
+        string = self.flatten()
+        return hashlib.sha1(string).digest()
+
+    def flatten(self):
+        """Create a string representation of the playlist.
+
+        A playlist is a YAML document with the following minimal syntax::
+
+            version: '0.2'
+            items:
+            - path: picoforce.000
+            - path: picoforce.001
+
         Relative paths are interpreted relative to the location of the
         playlist file.
-        '''
-        #create the output playlist, a simple XML document
-        implementation = xml.dom.minidom.getDOMImplementation()
-        #create the document DOM object and the root element
-        self.xml = implementation.createDocument(None, 'playlist', None)
-        root = self.xml.documentElement
-
-        #save generics variables
-        playlist_generics = self.xml.createElement('generics')
-        root.appendChild(playlist_generics)
-        self.genericsDict['pointer'] = self.index
-        for key in self.genericsDict.keys():
-            self.xml.createAttribute(key)
-            playlist_generics.setAttribute(key, str(self.genericsDict[key]))
-            
-        #save curves and their attributes
-        for item in self.curves:
-            playlist_curve = self.xml.createElement('curve')
-            root.appendChild(playlist_curve)
-            for key in item.__dict__:
-                if not (key in self.hiddenAttributes):
-                    self.xml.createAttribute(key)
-                    playlist_curve.setAttribute(key, str(item.__dict__[key]))
-        self._saved = False
+
+        Examples
+        --------
+
+        >>> from .engine import CommandMessage
+
+        >>> root_path = os.path.sep + 'path'
+        >>> p = FilePlaylist(drivers=[],
+        ...                  path=os.path.join(root_path, 'to','playlist'))
+        >>> p.info['note'] = 'An example playlist'
+        >>> c = curve.Curve(os.path.join(root_path, 'to', 'curve', 'one'))
+        >>> c.info['note'] = 'The first curve'
+        >>> p.append(c)
+        >>> c = curve.Curve(os.path.join(root_path, 'to', 'curve', 'two'))
+        >>> c.info['attr with spaces'] = 'The second curve\\nwith endlines'
+        >>> c.command_stack.extend([
+        ...         CommandMessage('command A', {'arg 0':0, 'arg 1':'X'}),
+        ...         CommandMessage('command B', {'arg 0':1, 'arg 1':'Y'}),
+        ...         ])
+        >>> p.append(c)
+        >>> print p.flatten()  # doctest: +REPORT_UDIFF
+        # Hooke playlist version 0.2
+        _base_path: /path/to
+        _index: 0
+        _max_loaded: 100
+        _relative_curve_paths: true
+        drivers: []
+        info: {note: An example playlist}
+        items:
+        - info: {note: The first curve}
+          name: one
+          path: curve/one
+        - command_stack: !!python/object/new:hooke.command_stack.CommandStack
+            listitems:
+            - !!python/object:hooke.engine.CommandMessage
+              arguments: {arg 0: 0, arg 1: X}
+              command: command A
+            - !!python/object:hooke.engine.CommandMessage
+              arguments: {arg 0: 1, arg 1: Y}
+              command: command B
+          info: {attr with spaces: 'The second curve
+        <BLANKLINE>
+              with endlines'}
+          name: two
+          path: curve/two
+        name: playlist.hkp
+        path: /path/to/playlist.hkp
+        version: '0.2'
+        <BLANKLINE>
+        >>> p._relative_curve_paths = False
+        >>> print p.flatten()  # doctest: +REPORT_UDIFF
+        # Hooke playlist version 0.2
+        _base_path: /path/to
+        _index: 0
+        _max_loaded: 100
+        _relative_curve_paths: false
+        drivers: []
+        info: {note: An example playlist}
+        items:
+        - info: {note: The first curve}
+          name: one
+          path: /path/to/curve/one
+        - command_stack: !!python/object/new:hooke.command_stack.CommandStack
+            listitems:
+            - !!python/object:hooke.engine.CommandMessage
+              arguments: {arg 0: 0, arg 1: X}
+              command: command A
+            - !!python/object:hooke.engine.CommandMessage
+              arguments: {arg 0: 1, arg 1: Y}
+              command: command B
+          info: {attr with spaces: 'The second curve
+        <BLANKLINE>
+              with endlines'}
+          name: two
+          path: /path/to/curve/two
+        name: playlist.hkp
+        path: /path/to/playlist.hkp
+        version: '0.2'
+        <BLANKLINE>
+        """
+        yaml_string = yaml.dump(self.__getstate__(), allow_unicode=True)
+        return ('# Hooke playlist version %s\n' % self.version) + yaml_string
+
+    def from_string(self, string):
+        u"""Load a playlist from a string.
+
+        Examples
+        --------
+
+        Minimal example.
+
+        >>> string = '''# Hooke playlist version 0.2
+        ... version: '0.2'
+        ... items:
+        ... - path: picoforce.000
+        ... - path: picoforce.001
+        ... '''
+        >>> p = FilePlaylist(drivers=[],
+        ...                 path=os.path.join('/path', 'to', 'my', 'playlist'))
+        >>> p.from_string(string)
+        >>> for curve in p:
+        ...     print curve.path
+        /path/to/my/picoforce.000
+        /path/to/my/picoforce.001
+
+        More complicated example.
+
+        >>> string = '''# Hooke playlist version 0.2
+        ... _base_path: /path/to
+        ... _digest: null
+        ... _index: 1
+        ... _max_loaded: 100
+        ... _relative_curve_paths: true
+        ... info: {note: An example playlist}
+        ... items:
+        ... - info: {note: The first curve}
+        ...   path: curve/one
+        ... - command_stack: !!python/object/new:hooke.command_stack.CommandStack
+        ...      listitems:
+        ...      - !!python/object:hooke.engine.CommandMessage
+        ...        arguments: {arg 0: 0, arg 1: X}
+        ...        command: command A
+        ...      - !!python/object:hooke.engine.CommandMessage
+        ...        arguments: {arg 0: 1, arg 1: Y}
+        ...        command: command B
+        ...   info: {attr with spaces: 'The second curve
+        ... 
+        ...       with endlines'}
+        ...   name: two
+        ...   path: curve/two
+        ... name: playlist.hkp
+        ... path: /path/to/playlist.hkp
+        ... version: '0.2'
+        ... '''
+        >>> p = FilePlaylist(drivers=[],
+        ...                  path=os.path.join('path', 'to', 'my', 'playlist'))
+        >>> p.from_string(string)
+        >>> p._index
+        1
+        >>> p.info
+        {'note': 'An example playlist'}
+        >>> for curve in p:
+        ...     print curve.name, curve.path
+        one /path/to/curve/one
+        two /path/to/curve/two
+        >>> p[-1].info['attr with spaces']
+        'The second curve\\nwith endlines'
+        >>> type(p[-1].command_stack)
+        <class 'hooke.command_stack.CommandStack'>
+        >>> p[-1].command_stack  # doctest: +NORMALIZE_WHITESPACE
+        [<CommandMessage command A {arg 0: 0, arg 1: X}>,
+         <CommandMessage command B {arg 0: 1, arg 1: Y}>]
+        """
+        # NOTE(review): yaml.load() will construct the Python objects
+        # named by `!!python/object` tags (the playlist format relies
+        # on them, see the examples above), so a playlist string can
+        # instantiate arbitrary classes.  Only load playlists from
+        # trusted sources.
+        state = yaml.load(string)
+        # The parsed mapping is handed to __setstate__ unchanged.
+        self.__setstate__(state)
+
+    def save(self, path=None, makedirs=True):
+        """Saves the playlist to a YAML file.
+
+        `path`, when given, becomes the playlist's new path (see
+        :meth:`set_path`).  If `makedirs` is `True`, a missing target
+        directory is created.  On success :attr:`_digest` is updated
+        so that :meth:`is_saved` returns `True`.
+        """
+        self.set_path(path)
+        # Serialize before opening the file: open(..., 'w') truncates,
+        # so a failure inside flatten() (e.g. an attribute that cannot
+        # be converted to YAML) must not wipe an existing playlist.
+        text = self.flatten()
+        dirname = os.path.dirname(self.path) or '.'
+        if makedirs == True and not os.path.isdir(dirname):
+            os.makedirs(dirname)
+        with open(self.path, 'w') as f:
+            f.write(text)
+        self._digest = self.digest()
+
+    def load(self, path=None, identify=True, hooke=None):
+        """Load a playlist from a file.
+        """
+        self.set_path(path)
+        with open(self.path, 'r') as f:
+            text = f.read()
+        self.from_string(text)
+        self._digest = self.digest()
+        for curve in self:
+            curve.set_hooke(hooke)
+            if identify == True:
+                curve.identify(self.drivers)
+
+
+class Playlists (NoteIndexList):
+    """A :class:`NoteIndexList` of :class:`FilePlaylist`\s.
+    """
+    def __init__(self, *arg, **kwargs):
+        """Forward all arguments to :class:`NoteIndexList`."""
+        super(Playlists, self).__init__(*arg, **kwargs)
+
+    def _item_getstate(self, item):
+        """Return the state of the contained playlist `item`."""
+        assert isinstance(item, FilePlaylist), type(item)
+        playlist_state = item.__getstate__()
+        return playlist_state
+
+    def _item_setstate(self, state):
+        """Rebuild a :class:`FilePlaylist` (created with an empty
+        driver list) from `state`.
+        """
+        playlist = FilePlaylist(drivers=[])
+        playlist.__setstate__(state)
+        return playlist