12c6d266b68a958b51c95039df5f3679ba613d55
[hooke.git] / hooke / playlist.py
1 import copy\r
2 import os\r
3 import os.path\r
4 import xml.dom.minidom\r
5 \r
6 from . import hooke as hooke\r
7 from . import libhookecurve as lhc\r
8 from . import libhooke as lh\r
9 \r
10 class Playlist(object):\r
11     def __init__(self, drivers):\r
12         self._saved = False\r
13         self.count = 0\r
14         self.curves = []\r
15         self.drivers = drivers\r
16         self.path = ''\r
17         self.genericsDict = {}\r
18         self.hiddenAttributes = ['curve', 'driver', 'name', 'plots']\r
19         self.index = -1\r
20         self.name = 'Untitled'\r
21         self.plotPanel = None\r
22         self.plotTab = None\r
23         self.xml = None\r
24 \r
25     def add_curve(self, path, attributes={}):\r
26         curve = lhc.HookeCurve(path)\r
27         for key,value in attribures.items():\r
28             setattr(curve, key, value)\r
29         curve.identify(self.drivers)\r
30         curve.plots = curve.driver.default_plots()\r
31         self.curves.append(curve)\r
32         self._saved = False\r
33         self.count = len(self.curves)\r
34         return curve\r
35 \r
36     def close_curve(self, index):\r
37         if index >= 0 and index < self.count:\r
38             self.curves.remove(index)\r
39 \r
40     def filter_curves(self, keeper_fn=labmda curve:True):\r
41         playlist = copy.deepcopy(self)\r
42         for curve in reversed(playlist.curves):\r
43             if not keeper_fn(curve):\r
44                 playlist.curves.remove(curve)\r
45         try: # attempt to maintain the same active curve\r
46             playlist.index = playlist.curves.index(self.get_active_curve())\r
47         except ValueError:\r
48             playlist.index = 0\r
49         playlist._saved = False\r
50         playlist.count = len(playlist.curves)\r
51         return playlist\r
52 \r
53     def get_active_curve(self):\r
54         return self.curves[self.index]\r
55 \r
56     #TODO: do we need this?\r
57     def get_active_plot(self):\r
58         return self.curves[self.index].plots[0]\r
59 \r
60     def get_status_string(self):\r
61         if self.has_curves()\r
62             return '%s (%s/%s)' % (self.name, self.index + 1, self.count)\r
63         return 'The file %s does not contain any valid force curve data.' \\r
64             % self.name\r
65 \r
66     def has_curves(self):\r
67         if self.count > 0:\r
68             return True\r
69         return False\r
70 \r
71     def is_saved(self):\r
72         return self._saved\r
73 \r
74     def load(self, path):\r
75         '''\r
76         loads a playlist file\r
77         '''\r
78         self.path = path\r
79         self.name = os.path.basename(path)\r
80         playlist = lh.delete_empty_lines_from_xmlfile(path)\r
81         self.xml = xml.dom.minidom.parse(path)\r
82         # Strip blank spaces:\r
83         self._removeWhitespaceNodes()\r
84 \r
85         generics_list = self.xml.getElementsByTagName('generics')\r
86         curve_list = self.xml.getElementsByTagName('curve')\r
87         self._loadGenerics(generics_list)\r
88         self._loadCurves(curve_list)\r
89         self._saved = True\r
90 \r
91     def _removeWhitespaceNodes(self, root_node=None):\r
92         if root_node == None:\r
93             root_node = self.xml\r
94         for node in root_node.childNodes:\r
95             if node.nodeType == node.TEXT_NODE and node.data.strip() == '':\r
96                 root_node.removeChild(node) # drop this whitespace node\r
97             else:\r
98                 _removeWhitespaceNodes(root_node=node) # recurse down a level\r
99 \r
100     def _loadGenerics(self, generics_list, clear=True):\r
101         if clear:\r
102             self.genericsDict = {}\r
103         #populate generics\r
104         generics_list = self.xml.getElementsByTagName('generics')\r
105         for generics in generics_list:\r
106             for attribute in generics.attributes.keys():\r
107                 self.genericsDict[attribute] = generics_list[0].getAttribute(attribute)\r
108         if self.genericsDict.has_key('pointer'):\r
109             index = int(self.genericsDict['pointer'])\r
110             if index >= 0 and index < len(self.curves):\r
111                 self.index = index\r
112             else:\r
113                 index = 0\r
114 \r
115     def _loadCurves(self, curve_list, clear=True):\r
116         if clear:\r
117             self.curves = []\r
118         #populate playlist with curves\r
119         for curve in curve_list:\r
120             #rebuild a data structure from the xml attributes\r
121             curve_path = lh.get_file_path(element.getAttribute('path'))\r
122             #extract attributes for the single curve\r
123             attributes = dict([(k,curve.getAttribute(k))\r
124                                for k in curve.attributes.keys()])\r
125             attributes.pop('path')\r
126             curve = self.add_curve(os.path.join(path, curve_path), attributes)\r
127             if curve is not None:\r
128                 for plot in curve.plots:\r
129                     curve.add_data('raw', plot.vectors[0][0], plot.vectors[0][1], color=plot.colors[0], style='plot')\r
130                     curve.add_data('raw', plot.vectors[1][0], plot.vectors[1][1], color=plot.colors[1], style='plot')\r
131 \r
132     def next(self):\r
133         self.index += 1\r
134         if self.index > self.count - 1:\r
135             self.index = 0\r
136 \r
137     def previous(self):\r
138         self.index -= 1\r
139         if self.index < 0:\r
140             self.index = self.count - 1\r
141 \r
142     def reset(self):\r
143         if self.has_curves():\r
144             self.index = 0\r
145         else:\r
146             self.index = None\r
147 \r
148     def save(self, path):\r
149         '''\r
150         saves the playlist in a XML file.\r
151         '''\r
152         try:\r
153             output_file = file(path, 'w')\r
154         except IOError, e:\r
155             #TODO: send message\r
156             print 'Cannot save playlist: %s' % e\r
157             return\r
158         self.xml.writexml(output_file, indent='\n')\r
159         output_file.close()\r
160         self._saved = True\r
161 \r
162     def set_XML(self):\r
163         '''\r
164         Creates an initial playlist from a list of files.\r
165         A playlist is an XML document with the following syntax:\r
166           <?xml version="1.0" encoding="utf-8"?>\r
167           <playlist>\r
168             <generics pointer="0"/>\r
169             <curve path="/my/file/path/"/ attribute="value" ...>\r
170             <curve path="...">\r
171           </playlist>\r
172         Relative paths are interpreted relative to the location of the\r
173         playlist file.\r
174         '''\r
175         #create the output playlist, a simple XML document\r
176         implementation = xml.dom.minidom.getDOMImplementation()\r
177         #create the document DOM object and the root element\r
178         self.xml = implementation.createDocument(None, 'playlist', None)\r
179         root = self.xml.documentElement\r
180 \r
181         #save generics variables\r
182         playlist_generics = self.xml.createElement('generics')\r
183         root.appendChild(playlist_generics)\r
184         self.genericsDict['pointer'] = self.index\r
185         for key in self.genericsDict.keys():\r
186             self.xml.createAttribute(key)\r
187             playlist_generics.setAttribute(key, str(self.genericsDict[key]))\r
188             \r
189         #save curves and their attributes\r
190         for item in self.curves:\r
191             playlist_curve = self.xml.createElement('curve')\r
192             root.appendChild(playlist_curve)\r
193             for key in item.__dict__:\r
194                 if not (key in self.hiddenAttributes):\r
195                     self.xml.createAttribute(key)\r
196                     playlist_curve.setAttribute(key, str(item.__dict__[key]))\r
197         self._saved = False\r