Added tip-to-root/root-to-tip flexibility to hooke.util.graph
[hooke.git] / hooke / util / graph.py
1 # COPYRIGHT
2
3 """Define :class:`Graph`, a directed, acyclic graph structure.
4 :class:`Graph`\s are composed of :class:`Node`\s, also defined by this
5 module.
6 """
7
8 import bisect
9
10 class CyclicGraphError (ValueError):
11     pass
12
13 class Node (list):
14     """A node/element in a graph.
15
16     Contains a list of the node's parents, and stores the node's
17     `data`.
18
19     Examples
20     --------
21
22     >>> a = Node(data='a')
23     >>> b = Node(parents=[a], data='b')
24     >>> c = Node(parents=[a], data='c')
25     >>> d = Node(parents=[b, c], data='d')
26     >>> str(d)
27     'd'
28
29     We can list all of a node's ancestors.
30
31     >>> print [node for node in d.ancestors()]
32     [b, c, a]
33     >>> print [node for node in d.ancestors(depth_first=True)]
34     [b, a, c]
35
36     Ancestors works with cycles.
37
38     >>> a.append(d)
39     >>> print [node for node in d.ancestors()]
40     [b, c, a, d]
41
42     We can find the cycle path.
43
44     >>> print d.parent_path(d)
45     [b, a, d]
46
47     After a run through :meth:`Graph.set_children`, we can also
48     list children
49
50     >>> g = Graph([a, b, c, d])
51     >>> g.set_children()
52     >>> print a.children
53     [b, c]
54
55     And descendents.
56
57     >>> print [node for node in a.descendents(depth_first=True)]
58     [b, d, a, c]
59     """
60     def __init__(self, parents=[], data=None):
61         list.__init__(self, parents)
62         self.data = data
63         self.children = []
64
65     def __cmp__(self, other):
66         return -cmp(self.data, other.data)
67     def __eq__(self, other):
68         return self.__cmp__(other) == 0
69     def __ne__(self, other):
70         return self.__cmp__(other) != 0
71     def __lt__(self, other):
72         return self.__cmp__(other) < 0
73     def __gt__(self, other):
74         return self.__cmp__(other) > 0
75     def __le__(self, other):
76         return self.__cmp__(other) <= 0
77     def __ge__(self, other):
78         return self.__cmp__(other) >= 0
79
80     def __str__(self):
81         return str(self.data)
82     def __repr__(self):
83         return self.__str__()
84
85     def traverse(self, next, depth_first=False):
86         """Iterate through all nodes returned by `next(node)`.
87
88         Will only yield each traversed node once, even in the case of
89         diamond inheritance, etc.
90
91         Breadth first by default.  Set `depth_first==True` for a
92         depth first search.
93         """
94         stack = list(next(self))
95         popped = []
96         while len(stack) > 0:
97             node = stack.pop(0)
98             if node in popped:
99                 continue
100             popped.append(node)
101             yield node
102             if depth_first == True:
103                 for target in reversed(next(node)):
104                     stack.insert(0, target)
105             else:
106                 stack.extend(next(node))
107
108     def ancestors(self, depth_first=False):
109         """Generate all ancestors.
110
111         This is a small wrapper around :meth:`traverse`.
112         """
113         next = lambda node : node  # list node's parents
114         for node in self.traverse(next=next, depth_first=depth_first):
115             yield node
116
117     def descendents(self, depth_first=False):
118         """Generate all decendents.
119
120         This is a small wrapper around :meth:`traverse`.
121         """
122         next = lambda node : node.children
123         for node in self.traverse(next=next, depth_first=depth_first):
124             yield node
125
126     def path(self, next, node):
127         """Return the shortest list of nodes connecting `self` to
128         `node` via `next(node)`.
129         """
130         if node in self:
131             return [node]
132         stack = list(next(self))
133         paths = dict((id(n), [n]) for n in stack)
134         while len(stack) > 0:
135             n = stack.pop(0)
136             n_path = paths[id(n)]
137             for target in next(n):
138                 if id(target) in paths:
139                     continue
140                 t_path = list(n_path)
141                 t_path.append(target)
142                 if id(target) == id(node):
143                     return t_path
144                 stack.append(target)
145                 paths[id(target)] = t_path
146
147     def parent_path(self, node):
148         """Return the shortest list of nodes connecting `self` to
149         its parent `node`.
150
151         This is a small wrapper around :meth:`path`.
152         """
153         next = lambda node : node  # list node's parents
154         return self.path(next, node)
155
156     def child_path(self, node):
157         """Return the shortest list of nodes connecting `self` to
158         its child `node`.
159
160         This is a small wrapper around :meth:`path`.
161         """
162         next = lambda node : node.children
163         return self.path(next, node)
164
165
166 class GraphRow (object):
167     """Represent the state of a single row in a graph.
168
169     Generated by :class:`GraphRowGenerator`, printed with
170     :class:`GraphRowPrinter`.
171
172     :attr:`node` is the active node and :attr:`active` is its branch
173     column index.  :attr:`width` is the number of current branch
174     columns.
175
176     :attr:`born`, :attr:`dead`, and :attr:`inherited` are lists of
177     `(branch_column_index, target_node)` pairs.  `dead` lists nodes
178     from previous rows whose branches complete on this row,
179     `inherited` lists nodes from previous rows whose branches continue
180     through this row, and `born` list nodes whose branches start on
181     this row.
182     """
183     def __init__(self, node, active=-1, dead=None, inherited=None, born=None,
184                  tip_to_root=False):
185         self.node = node
186         self.active = active
187         if dead == None:
188             dead = []
189         self.dead = dead
190         if inherited == None:
191             inherited = []
192         self.inherited = inherited
193         if born == None:
194             born = []
195         self.born = born
196         self.tip_to_root = tip_to_root
197
198 class GraphRowPrinter (object):
199     """Customizable printing for :class:`GraphRow`.
200
201     The string rendering can be customized by changing :attr:`chars`.
202     Control over the branch columns:
203
204     ================= ===========================================
205     `node: ...`       the active (most recently inserted) node
206     `split/join: ...` branching/merging runs from the active node
207     `run: connected`  connect a branch to its eventual node
208     `run: blank`      place-holder for extinct runs
209     ================= ===========================================
210
211     Branch columns are seperated by separator columns:
212
213     ================= =======================================================
214     `sep: split/join` separate split/join branch columns from the active node
215     `sep: default`    separate all remaining branch columns
216     ================= =======================================================
217     """
218     def __init__(self, chars=None):
219         if chars == None:
220             chars = {
221                 'node: both tip and root': 'b',
222                 'node: root': 'r',
223                 'node: tip': 't',
224                 'node: regular': '*',
225                 'split/join: born and died left of active': '>',
226                 'split/join: born and died right of active': '<',
227                 'split/join: born left of active': '/',
228                 'split/join: born right of active': '\\',
229                 'split/join: died left of active': '\\',
230                 'split/join: died right of active': '/',
231                 'run: blank': ' ',
232                 'run: connected': '|',
233                 'sep: split/join': '-',
234                 'sep: default': ' ',               
235                 }
236         self.chars = chars
237     def __call__(self, graph_row):
238         """Render the :class:`GraphRow` instance `graph_row` as a
239         string.
240         """
241         dead = [i for i,node in graph_row.dead]
242         inherited = [i for i,node in graph_row.inherited]
243         born = [i for i,node in graph_row.born]
244         right_connect = max(graph_row.active,
245                             max(born+[-1]), # +[-1] protects against empty born
246                             max(dead+[-1]))
247         left_connect = min(graph_row.active,
248                            min(born+[right_connect]),
249                            min(dead+[right_connect]))
250         max_col = max(right_connect, max(inherited+[-1]))
251         string = []
252         for i in range(max_col + 1):
253             # Get char, the node or branch column character.
254             if i == graph_row.active:
255                 if len(born) == 0:
256                     if len(dead) == 0:
257                         char = self.chars['node: both tip and root']
258                     elif graph_row.tip_to_root == True:
259                         # The dead are children
260                         char = self.chars['node: root']
261                     else: # The dead are parents
262                         char = self.chars['node: tip']
263                 elif len(dead) == 0:
264                     if graph_row.tip_to_root == True:
265                         # The born are parents
266                         char = self.chars['node: tip']
267                     else: # The born are children
268                         char = self.chars['node: root']
269                 else:
270                     char = self.chars['node: regular']
271             elif i in born:
272                 if i in dead: # born and died
273                     if i < graph_row.active:
274                         char = self.chars[
275                             'split/join: born and died left of active']
276                     else:
277                         char = self.chars[
278                             'split/join: born and died right of active']
279                 else: # just born
280                     if i < graph_row.active:
281                         char = self.chars['split/join: born left of active']
282                     else:
283                         char = self.chars['split/join: born right of active']
284             elif i in dead: # just died
285                 if i < graph_row.active:
286                     char = self.chars['split/join: died left of active']
287                 else:
288                     char = self.chars['split/join: died right of active']
289             elif i in inherited:
290                 char = self.chars['run: connected']
291             else:
292                 char = self.chars['run: blank']
293             # Get sep, the separation character.
294             if i < left_connect or i >= right_connect:
295                 sep = self.chars['sep: default']
296             else:
297                 sep = self.chars['sep: split/join']
298                 if char == self.chars['run: blank']:
299                     char = self.chars['sep: split/join']
300             string.extend([char, sep])
301         return ''.join(string)[:-1] # [-1] strips final sep
302
303 class GraphRowGenerator (list):
304     """A :class:`GraphRow` generator.
305
306     Contains a list of :class:`GraphRow`\s (added with
307     :meth:`insert`(:class:`hooke.util.graph.Node`)).  You should
308     generate a graph with repeated calls::
309
310         tip_to_root = True
311         g = GraphRowGenerator(tip_to_root=tip_to_root)
312         p = GraphRowPrinter(tip_to_root=tip_to_root)
313         for node in nodes:
314             g.insert(node)
315             print p(g[-1])
316
317     For the split/join branch columns, "born" and "dead" are defined
318     from the point of view of `GraphRow`.  For root-to-tip ordering
319     (`tip_to_root==False`, the default), "born" runs are determined
320     by the active node's children (which have yet to be printed) and
321     "dead" runs by its parents (which have been printed).  If
322     `tip_to_root==True`, swap "children" and "parents" in the above
323     sentence.
324     """
325     def __init__(self, tip_to_root=False):
326         list.__init__(self)
327         self.tip_to_root = tip_to_root
328     def insert(self, node):
329         """Add a new node to the graph.
330
331         If `tip_to_root==True`, nodes should be inserted in
332         tip-to-root topological order (i.e. node must be inserted
333         before any of its parents).
334
335         If `tip_to_root==False`, nodes must be inserted before any
336         of their children.
337         """
338         if len(self) == 0:
339             previous = GraphRow(node=None, active=-1)
340         else:
341             previous = self[-1]
342         current = GraphRow(node=node, active=-1, tip_to_root=self.tip_to_root)
343         if self.tip_to_root == True: # children die, parents born
344             dead_nodes = list(current.node.children)
345             born_nodes = list(current.node)
346         else: # root-to-tip: parents die, children born
347             dead_nodes = list(current.node)
348             born_nodes = list(current.node.children)
349         # Mark the dead and inherited branch columns
350         for i,node in previous.inherited + previous.born:
351             if node in dead_nodes or node == current.node:
352                 current.dead.append((i, node))
353             else:
354                 current.inherited.append((i, node))
355         # Place born and active branch columns
356         num_born = max(len(born_nodes), 1) # 1 to ensure slot for active node
357         remaining = num_born # number of nodes left to place
358         used_slots = [i for i,n in current.inherited]
359         old_max = max(used_slots+[-1]) # +[-1] in case used_slots is empty
360         slots = sorted([i for i in range(old_max+1) if i not in used_slots])
361         remaining -= len(slots)
362         slots.extend(range(old_max+1, old_max+1+remaining))
363         current.active = slots[0]
364         current.born = zip(slots, born_nodes)
365         # TODO: sharing branches vs. current 1 per child
366         self.append(current)
367
368
369 class Graph (list):
370     """A directed, acyclic graph structure.
371
372     Contains methods for sorting and printing graphs.
373
374     Examples
375     --------
376
377     >>> class Nodes (object): pass
378     >>> n = Nodes()
379     >>> for char in ['a','b','c','d','e','f','g','h','i']:
380     ...     setattr(n, char, Node(data=char))
381     >>> n.b.append(n.a)
382     >>> n.c.append(n.a)
383     >>> n.d.append(n.a)
384     >>> n.e.extend([n.b, n.c, n.d])
385     >>> n.f.append(n.e)
386     >>> n.g.append(n.e)
387     >>> n.h.append(n.e)
388     >>> n.i.extend([n.f, n.g, n.h])
389     >>> g = Graph([n.a,n.b,n.c,n.d,n.e,n.f,n.g,n.h,n.i])
390     >>> g.topological_sort(tip_to_root=True)
391     >>> print [node for node in g]
392     [i, h, g, f, e, d, c, b, a]
393     >>> print g.ascii_graph()
394     r-\-\ a
395     | | * b
396     | * | c
397     * | | d
398     *-<-< e
399     | | * f
400     | * | g
401     * | | h
402     t-/-/ i
403     >>> print g.ascii_graph(tip_to_root=True)
404     t-\-\ i
405     | | * h
406     | * | g
407     * | | f
408     *-<-< e
409     | | * d
410     | * | c
411     * | | b
412     r-/-/ a
413
414     >>> for char in ['a','b','c','d','e','f','g','h']:
415     ...     setattr(n, char, Node(data=char))
416     >>> n.b.append(n.a)
417     >>> n.c.append(n.b)
418     >>> n.d.append(n.a)
419     >>> n.e.append(n.d)
420     >>> n.f.extend([n.b, n.d])
421     >>> n.g.extend([n.e, n.f])
422     >>> n.h.extend([n.c, n.g])
423     >>> g = Graph([n.a,n.b,n.c,n.d,n.e,n.f,n.g,n.h])
424     >>> print g.ascii_graph(tip_to_root=True)
425     t-\ h
426     | *-\ g
427     | | *-\ f
428     | * | | e
429     | *-|-/ d
430     * | | c
431     *-|-/ b
432     r-/ a
433
434     >>> for char in ['a','b','c','d','e','f','g','h','i']:
435     ...     setattr(n, char, Node(data=char))
436     >>> for char in ['a', 'b','c','d','e','f','g','h']:
437     ...     nx = getattr(n, char)
438     ...     n.i.append(nx)
439     >>> g = Graph([n.a,n.b,n.c,n.d,n.e,n.f,n.g,n.h,n.i])
440     >>> print g.ascii_graph(tip_to_root=True)
441     t-\-\-\-\-\-\-\ i
442     | | | | | | | r h
443     | | | | | | r g
444     | | | | | r f
445     | | | | r e
446     | | | r d
447     | | r c
448     | r b
449     r a
450
451     >>> for char in ['a','b','c','d','e','f','g','h','i']:
452     ...     setattr(n, char, Node(data=char))
453     >>> for char in ['b','c','d','e','f','g','h','i']:
454     ...     nx = getattr(n, char)
455     ...     nx.append(n.a)
456     >>> g = Graph([n.a,n.b,n.c,n.d,n.e,n.f,n.g,n.h,n.i])
457     >>> print g.ascii_graph(tip_to_root=True)
458     t i
459     | t h
460     | | t g
461     | | | t f
462     | | | | t e
463     | | | | | t d
464     | | | | | | t c
465     | | | | | | | t b
466     r-/-/-/-/-/-/-/ a
467
468     >>> for char in ['a','b','c','d','e','f','g','h','i']:
469     ...     setattr(n, char, Node(data=char))
470     >>> n.d.append(n.a)
471     >>> n.e.extend([n.a, n.c])
472     >>> n.f.extend([n.c, n.d, n.e])
473     >>> n.g.extend([n.b, n.e, n.f])
474     >>> n.h.extend([n.a, n.c, n.d, n.g])
475     >>> n.i.extend([n.a, n.b, n.c, n.g])
476     >>> g = Graph([n.a,n.b,n.c,n.d,n.e,n.f,n.g,n.h,n.i])
477     >>> print g.ascii_graph(tip_to_root=True)
478     t-\-\-\ i
479     | | | | t-\-\-\ h
480     | | | *-|-|-|-<-\ g
481     | | | | | | | | *-\-\ f
482     | | | | | | | *-|-|-< e
483     | | | | | | *-|-|-/ | d
484     | | r-|-|-/-|-|-/---/ c
485     | r---/ |   | | b
486     r-------/---/-/ a
487
488     Ok, enough pretty graphs ;).  Here's an example of cycle
489     detection.
490
491     >>> for char in ['a','b','c','d']:
492     ...     setattr(n, char, Node(data=char))
493     >>> n.b.append(n.a)
494     >>> n.c.append(n.a)
495     >>> n.d.extend([n.b, n.c])
496     >>> n.a.append(n.d)
497     >>> g = Graph([n.a,n.b,n.c,n.d])
498     >>> g.check_for_cycles()
499     Traceback (most recent call last):
500       ...
501     CyclicGraphError: cycle detected:
502       a
503       d
504       b
505       a
506     """
507     def set_children(self):
508         """Fill out each node's :attr:`Node.children` list.
509         """
510         for node in self:
511             for parent in node:
512                 if node not in parent.children:
513                     parent.children.append(node)
514
515     def topological_sort(self, tip_to_root=False):
516         """Algorithm from git's commit.c `sort_in_topological_order`_.
517
518         Default ordering is root-to-tip.  Set `tip_to_root=True` for
519         tip-to-root.
520
521         In situations where topological sorting is ambiguous, the
522         nodes are sorted using the node comparison functions (__cmp__,
523         __lt__, ...).  If `tip_to_root==True`, the inverse
524         comparison functions are used.
525
526         .. _sort_in_topological_order:
527           http://git.kernel.org/?p=git/git.git;a=blob;f=commit.c;h=731191e63bd39a89a8ea4ed0390c49d5605cdbed;hb=HEAD#l425
528         """
529         # sort tip-to-root first, then reverse if neccessary
530         for node in self:
531             node._outcount = 0
532         for node in self:
533             for parent in node:
534                 parent._outcount += 1
535         tips = sorted([node for node in self if node._outcount == 0])
536         orig_len = len(self)
537         del self[:]
538         while len(tips) > 0:
539             node = tips.pop(0)
540             for parent in node:
541                 parent._outcount -= 1
542                 if parent._outcount == 0:
543                     bisect.insort(tips, parent)
544             node._outcount = -1
545             self.append(node)
546         final_len = len(self)
547         if final_len != orig_len:
548             raise CyclicGraphError(
549                 '%d of %d elements not reachable from tips'
550                 % (orig_len - final_len, orig_len))
551         if tip_to_root == False:
552             self.reverse()
553
554     def check_for_cycles(self):
555         """Check for cyclic references.
556         """
557         for node in self:
558             if node in node.ancestors():
559                 path = node.parent_path(node)
560                 raise CyclicGraphError(
561                     'cycle detected:\n  %s'
562                     % '\n  '.join([repr(node)]+[repr(node) for node in path]))
563
564     def graph_rows(self, tip_to_root=False):
565         """Generate a sequence of (`graph_row`, `node`) tuples.
566
567         Preforms :meth:`set_children` and :meth:`topological_sort`
568         internally.
569         """
570         graph_row_generator = GraphRowGenerator(tip_to_root=tip_to_root)
571         self.set_children()
572         self.topological_sort(tip_to_root=tip_to_root)
573         for node in self:
574             graph_row_generator.insert(node)
575             yield (graph_row_generator[-1], node)
576
577     def ascii_graph(self, graph_row_printer=None, string_fn=str,
578                     tip_to_root=False):
579         """Print an ascii graph on the left with `string_fn(node)` on
580         the right.  If `graph_row_printer` is `None`, a default
581         instance of :class:`GraphRowPrinter` will be used.
582
583         See the class docstring for example output.
584         """
585         if graph_row_printer == None:
586             graph_row_printer = GraphRowPrinter()
587         graph = []
588         for row,node in self.graph_rows(tip_to_root=tip_to_root):
589             graph.append('%s %s' % (graph_row_printer(row), string_fn(node)))
590         return '\n'.join(graph)