Add Sudoku._slice_completion_solver() to sudoku.py.
[parallel_computing.git] / assignments / archive / sudoku / soln / sudoku.py
1 #!/usr/bin/env python
2
3 import itertools
4
5 import numpy
6
7
8 TEST_PUZZLE_STRING = '\n'.join([
9         '5 3 -   - 7 -   - - -',
10         '6 - -   1 9 5   - - -',
11         '- 9 8   - - -   - 6 -',
12         '',
13         '8 - -   - 6 -   - - 3',
14         '4 - -   8 - 3   - - 1',
15         '7 - -   - 2 -   - - 6',
16         '',
17         '- 6 -   - - -   2 8 -',
18         '- - -   4 1 9   - - 5',
19         '- - -   - 8 -   - 7 9',
20         ])
21
22
23 def power_set(iterable):
24     """Return the power set of `iterable`.
25
26     >>> for i in power_set([1,2,3]):
27     ...     print i
28     ()
29     (1,)
30     (2,)
31     (3,)
32     (1, 2)
33     (1, 3)
34     (2, 3)
35     (1, 2, 3)
36     """
37     s = list(iterable)
38     return itertools.chain.from_iterable(
39         itertools.combinations(s, r) for r in range(len(s)+1))
40
41
42 class Sudoku (object):
43     """
44     >>> s = Sudoku()
45     >>> s.load(TEST_PUZZLE_STRING)
46     >>> s._num_solved()
47     30
48     >>> s.solve()
49     >>> print s.status
50     solved in 408 steps
51     >>> print s.dump()
52     5 3 4   6 7 8   9 1 2
53     6 7 2   1 9 5   3 4 8
54     1 9 8   3 4 2   5 6 7
55     <BLANKLINE>
56     8 5 9   7 6 1   4 2 3
57     4 2 6   8 5 3   7 9 1
58     7 1 3   9 2 4   8 5 6
59     <BLANKLINE>
60     9 6 1   5 3 7   2 8 4
61     2 8 7   4 1 9   6 3 5
62     3 4 5   2 8 6   1 7 9
63     >>> s._num_solved()
64     81
65
66     >>> s._row(0)
67     array([5, 3, 4, 6, 7, 8, 9, 1, 2])
68     >>> s._col(0)
69     array([5, 6, 1, 8, 4, 7, 9, 2, 3])
70     >>> s._cell(0,0)
71     array([[5, 3, 4],
72            [6, 7, 2],
73            [1, 9, 8]])
74     """
75     def __init__(self):
76         self._puzzle = numpy.zeros((9,9), dtype=numpy.int)
77         self._empty = 0
78         self._external_empty = '-'
79         self.status = None
80         self.solvers = [self._direct_elimination_solver,
81                         self._slice_completion_solver]
82
83     def load(self, text):
84         row = 0
85         for line in text.splitlines():
86             if len(line) == 0 or line.startswith('#'):
87                 continue
88             assert row < 9, row
89             for col,value in enumerate(line.split()):
90                 assert col < 9, col
91                 self._puzzle[row,col] = self._convert_to_internal_value(value)
92             row += 1
93         self.status = 'loaded puzzle'
94
95     def _convert_to_internal_value(self, value):
96         """
97         >>> s = Sudoku()
98         >>> s._convert_to_internal_value(s._external_empty)
99         0
100         >>> s._convert_to_internal_value('0')
101         Traceback (most recent call last):
102           ...
103         AssertionError: 0
104         >>> s._convert_to_internal_value('1')
105         1
106         >>> s._convert_to_internal_value('9')
107         9
108         >>> s._convert_to_internal_value('10')
109         Traceback (most recent call last):
110           ...
111         AssertionError: 10
112         """
113         if value == self._external_empty:
114             value = self._empty
115         else:
116             value = int(value)
117             assert value >= 1 and value <= 9, value
118         return value
119
120     def dump(self):
121         lines = []
122         for row in range(9):
123             if row in [3, 6]:
124                 lines.append('')  # blank rows between cells
125             line = []
126             for col in range(9):
127                 if col in [3, 6]:
128                     line.append(' ')  # blank columns between cells
129                 line.append(self._convert_to_external_value(
130                         self._puzzle[row,col]))
131             lines.append(' '.join(line))
132         return '\n'.join(lines)
133
134     def _convert_to_external_value(self, value):
135         """
136         >>> s = Sudoku()
137         >>> s._convert_to_external_value(s._empty)
138         '-'
139         >>> s._convert_to_external_value(1)
140         '1'
141         >>> s._convert_to_external_value(9)
142         '9'
143         """
144         if value == self._empty:
145             value = self._external_empty
146         else:
147             assert value >= 1 and value <= 9, value
148             value = str(value)
149         return value
150
151     def _row(self, row):
152         return self._puzzle[row,:]
153
154     def _col(self, col):
155         return self._puzzle[:,col]
156
157     def _cell_bounds(self, cell_row, cell_col):
158         ri = cell_row * 3
159         rf = ri + 3
160         ci = cell_col * 3
161         cf = ci + 3
162         return (ri, rf, ci, cf)
163
164     def _cell(self, cell_row, cell_col):
165         ri,rf,ci,cf = self._cell_bounds(cell_row, cell_col)
166         return self._puzzle[ri:rf,ci:cf]
167
168     def _slices(self):
169         """
170         >>> s = Sudoku()
171         >>> s.load(TEST_PUZZLE_STRING)
172         >>> for type,slice,index in s._slices():
173         ...     print type,slice,index  # doctest: +ELLIPSIS
174         row [5 3 0 0 7 0 0 0 0] 0
175         ...
176         row [0 0 0 0 8 0 0 7 9] 8
177         col [5 6 0 8 4 7 0 0 0] 0
178         ...
179         col [0 0 0 3 1 6 0 5 9] 8
180         cell [[5 3 0]
181          [6 0 0]
182          [0 9 8]] (0, 0)
183         ...
184         cell [[2 8 0]
185          [0 0 5]
186          [0 7 9]] (2, 2)
187         """
188         for row in range(9):
189             yield ('row', self._row(row), row)
190         for col in range(9):
191             yield ('col', self._col(col), col)
192         for cell_row in range(3):
193             for cell_col in range(3):
194                 yield ('cell', self._cell(cell_row, cell_col),
195                        (cell_row, cell_col))
196
197     def _point_to_cell_coords(self, row, col):
198         """
199         >>> s = Sudoku()
200         >>> s._point_to_cell_coords(4, 6)
201         (1, 2, 3)
202
203         The point in question:
204
205         - - -   - - -   - - -
206         - 0 -   - 1 -   - 2 -
207         - - -   - - -   - - -
208         
209         - - -   - - -   0 1 2
210         - 1 -   - - -   * - -
211         - - -   - - -   - - -
212         
213         - - -   - - -   - - -
214         - - -   - - -   - - -
215         - - -   - - -   - - -
216         """
217         cell_row_residual = row % 3
218         cell_col_residual = col % 3
219         return (row/3, col/3,
220                 cell_row_residual*3 + cell_col_residual)
221
222     def _cell_to_point_coords(self, cell_row, cell_col, i):
223         """
224         >>> s = Sudoku()
225         >>> s._cell_to_point_coords(1, 2, 3)
226         (4, 6)
227
228         The point in question:
229
230         0 1 2   3 4 5   6 - -
231         1 - -   - - -   - - -
232         2 - -   - - -   - - -
233         
234         3 - -   - - -   - - -
235         4 - -   - - -   * - -
236         - - -   - - -   - - -
237         
238         - - -   - - -   - - -
239         - - -   - - -   - - -
240         - - -   - - -   - - -
241         """
242         row = cell_row * 3 + (i / 3)
243         col = cell_col * 3 + (i % 3)
244         return (row, col)
245
246     def _nonempty(self, values):
247         return [x for x in values if x != self._empty]
248
249     def _is_valid(self):
250         """
251         >>> s = Sudoku()
252         >>> s.load(TEST_PUZZLE_STRING)
253         >>> s._is_valid()
254         True
255
256         Test an invalid row.
257
258         >>> s._puzzle[0,3] = 5
259         >>> s._is_valid()
260         False
261         >>> s._puzzle[0,3] = s._empty
262
263         Test an invalid column.
264
265         >>> s._puzzle[8,0] = 5
266         >>> s._is_valid()
267         False
268         >>> s._puzzle[8,0] = s._empty
269
270         Test and invalid cell.
271
272         >>> s._puzzle[2,0] = 3
273         >>> s._is_valid()
274         False
275         >>> s._puzzle[2,0] = s._empty
276         """
277         for type,slice,index in self._slices():
278             values = self._nonempty(slice.flat)
279             if len(values) != len(set(values)):
280                 return False
281         return True
282
283     def _num_solved(self):
284         return len(self._nonempty(self._puzzle.flatten()))
285
286     def solve(self):
287         actions = 0
288         trials = self._setup_trials()
289         while True:
290             start_actions = actions
291             for solver in self.solvers:
292                 acts,trials = solver(trials)
293                 self._apply_trials(trials)
294                 actions += acts
295                 if acts > 0:
296                     break  # don't use slow solvers unless they're required
297             if self._num_solved() == 81:
298                 self.status = 'solved in %d steps' % actions
299                 return  # puzzle solved
300             elif actions == start_actions:
301                 self.status = 'aborted after %d steps' % actions
302                 return  # puzzle too hard to solve
303
304     def _setup_trials(self):
305         trials = numpy.zeros((9,9,9), dtype=numpy.int)
306         for row in range(9):
307             for col in range(9):
308                 if self._puzzle[row,col] == self._empty:
309                     trials[row,col,:] = range(1,10)
310                 else:
311                     x = self._puzzle[row,col]
312                     trials[row,col,x-1] = x
313         return trials
314
315     def _apply_trials(self, trials):
316         for row in range(9):
317             for col in range(9):
318                 if len(self._nonempty(trials[row,col,:])) == 1:
319                     self._puzzle[row][col] = (
320                         self._nonempty(trials[row,col,:])[0])
321                     assert self._is_valid(), (
322                         'error setting [%d,%d] to %d'
323                         % (row, col, self._puzzle[row][col]))
324
325     def _trial_slice(self, trials, type, index):
326         """Return a slice from the trials array.
327
328         >>> s = Sudoku()
329         >>> s.load(TEST_PUZZLE_STRING)
330         >>> trials = s._setup_trials()
331         >>> t = s._trial_slice(trials, 'row', 0)
332         >>> t  # doctest: +REPORT_UDIFF
333         array([[0, 0, 0, 0, 5, 0, 0, 0, 0],
334                [0, 0, 3, 0, 0, 0, 0, 0, 0],
335                [1, 2, 3, 4, 5, 6, 7, 8, 9],
336                [1, 2, 3, 4, 5, 6, 7, 8, 9],
337                [0, 0, 0, 0, 0, 0, 7, 0, 0],
338                [1, 2, 3, 4, 5, 6, 7, 8, 9],
339                [1, 2, 3, 4, 5, 6, 7, 8, 9],
340                [1, 2, 3, 4, 5, 6, 7, 8, 9],
341                [1, 2, 3, 4, 5, 6, 7, 8, 9]])
342
343         For `row` and `column` slices, the original `trials` array
344         responds to changes in `t`.
345
346         >>> trials[0,2,:]
347         array([1, 2, 3, 4, 5, 6, 7, 8, 9])
348         >>> t[2,:] = 0
349         >>> t[2,3] = 4
350         >>> trials[0,2,:]
351         array([0, 0, 0, 4, 0, 0, 0, 0, 0])
352
353         `cell` slices don't work with "flat" indexing, because the
354         stride would not be constant.  You'll have to push changes
355         back to `trials` by hand.
356
357         >>> t = s._trial_slice(trials, 'cell', (0, 0))
358         >>> t  # doctest: +REPORT_UDIFF
359         array([[0, 0, 0, 0, 5, 0, 0, 0, 0],
360                [0, 0, 3, 0, 0, 0, 0, 0, 0],
361                [0, 0, 0, 4, 0, 0, 0, 0, 0],
362                [0, 0, 0, 0, 0, 6, 0, 0, 0],
363                [1, 2, 3, 4, 5, 6, 7, 8, 9],
364                [1, 2, 3, 4, 5, 6, 7, 8, 9],
365                [1, 2, 3, 4, 5, 6, 7, 8, 9],
366                [0, 0, 0, 0, 0, 0, 0, 0, 9],
367                [0, 0, 0, 0, 0, 0, 0, 8, 0]])
368         >>> trials[1,1,:]
369         array([1, 2, 3, 4, 5, 6, 7, 8, 9])
370         >>> t[4,:] = 0
371         >>> t[4,6] = 7
372         >>> trials[1,1,:] = t[4,:]
373         >>> trials[1,1,:]
374         array([0, 0, 0, 0, 0, 0, 7, 0, 0])
375         """
376         if type == 'row':
377             t = trials[index,:,:]
378         elif type == 'col':
379             t = trials[:,index,:]
380         else:
381             assert type == 'cell', type
382             cell_row,cell_col = index
383             ri,rf,ci,cf = self._cell_bounds(cell_row, cell_col)
384             t = trials[ri:rf,ci:cf,:]
385             t = t.reshape((t.shape[0]*t.shape[1], t.shape[2])).copy()
386         if type in ['row', 'col']:
387             assert t.flags.owndata == False, t.flags.owndata
388             assert t.base is trials, t.base
389         else:
390             assert t.flags.owndata == True, t.flags.owndata
391         return t
392
393     def _direct_elimination_solver(self, trials):
394         r"""Eliminate trials if a point already has the trial digit in
395         its row/col/cell.
396
397         >>> puzzle = '\n'.join([
398         ...         '1 2 3   - - -   - - -',
399         ...         '4 5 6   - - -   - - -',
400         ...         '7 8 -   - - -   - - -',
401         ...         '',
402         ...         '- - -   - - -   - - -',
403         ...         '- - -   - - -   - - -',
404         ...         '- - -   - - -   - - -',
405         ...         '',
406         ...         '- - -   - - -   - - -',
407         ...         '- - -   - - -   - - -',
408         ...         '- - -   - - -   - - -',
409         ...         ])
410         >>> s = Sudoku()
411         >>> s.load(puzzle)
412         >>> trials = s._setup_trials()
413         >>> actions,trials = s._direct_elimination_solver(trials)
414
415         The solver eliminated three numbers for two columns and rows
416         and two numbers for three columns and rows, which makes for
417             104 =   8    # point 2,2 (the solved point)
418                   + 6*3  # row 0
419                   + 6*3  # row 1
420                   + 6*2  # row 2
421                   + 6*3  # col 0
422                   + 6*3  # col 1
423                   + 6*2  # col 2
424         eliminations.
425
426         >>> actions
427         104
428         >>> s._apply_trials(trials)
429         >>> print s.dump()
430         1 2 3   - - -   - - -
431         4 5 6   - - -   - - -
432         7 8 9   - - -   - - -
433         <BLANKLINE>
434         - - -   - - -   - - -
435         - - -   - - -   - - -
436         - - -   - - -   - - -
437         <BLANKLINE>
438         - - -   - - -   - - -
439         - - -   - - -   - - -
440         - - -   - - -   - - -
441         """
442         actions = 0
443         for row in range(9):
444             for col in range(9):
445                 if self._puzzle[row][col] == self._empty:
446                     for x in self._nonempty(trials[row,col,:]):
447                         self._puzzle[row,col] = x
448                         if not self._is_valid():
449                             actions += 1
450                             trials[row,col,x-1] = self._empty
451                         self._puzzle[row,col] = self._empty
452         return (actions, trials)
453
454     def _slice_completion_solver(self, trials):
455         r"""Eliminate trials if a set of N points have trials drawn
456         only from a list of N options.
457
458         For example, a slice like
459           [1, 2, 3, {4,5}, {4,6}, {4,5,6}, {4,5,6,7,8,9},
460            {4,5,6,7,8,9}, {4,5,6,7,8,9}]
461         has three points {4,5}, {4,6}, and {4,5,6}, with three possible
462         numbers: 4, 5, and 6.  That means, 4, 5, and 6 would definitely
463         occupy the three points and other points should not have those
464         numbers.
465
466         >>> puzzle = '\n'.join([
467         ...         '1 2 3   - - -   - - -',
468         ...         '- - -   7 - -   - - -',
469         ...         '- - -   - 8 9   - - -',
470         ...         '',
471         ...         '- - -   6 5 -   - - -',
472         ...         '- - -   - - -   - - -',
473         ...         '- - -   - - -   - - -',
474         ...         '',
475         ...         '- - -   - - -   - - -',
476         ...         '- - -   - - -   - - -',
477         ...         '- - -   - - -   - - -',
478         ...         ])
479         >>> s = Sudoku()
480         >>> s.load(puzzle)
481         >>> trials = s._setup_trials()
482
483         Take a first pass through `_direct_elimination_solver()` to
484         eliminate trial values in the rows/columns/cells blocked by
485         the initialized points.
486
487         >>> actions,trials = s._direct_elimination_solver(trials)
488
489         The solver eliminated the following possibilities (by number)
490             142 =   6 + 6 + 6  # number 1, cell+row+col
491                   + 6 + 6 + 6  # number 2, cell+row+col
492                   + 6 + 6 + 6  # number 3, cell+row+col
493                   + 7 + 6 + 5  # number 5, cell+row+col
494                   + 7 + 6 + 5  # number 6, cell+row+col
495                   + 6 + 6 + 5  # number 7, cell+row+col
496                   + 6 + 6 + 5  # number 8, cell+row+col
497                   + 6 + 6 + 6  # number 9, cell+row+col
498
499         >>> actions
500         142
501
502         However the direct solver was unable to actually solve any new
503         points.
504
505         >>> s._apply_trials(trials)
506         >>> print s.dump()
507         1 2 3   - - -   - - -
508         - - -   7 - -   - - -
509         - - -   - 8 9   - - -
510         <BLANKLINE>
511         - - -   6 5 -   - - -
512         - - -   - - -   - - -
513         - - -   - - -   - - -
514         <BLANKLINE>
515         - - -   - - -   - - -
516         - - -   - - -   - - -
517         - - -   - - -   - - -
518         >>> trials[0,:,:]  # doctest: +REPORT_UDIFF
519         array([[1, 0, 0, 0, 0, 0, 0, 0, 0],
520                [0, 2, 0, 0, 0, 0, 0, 0, 0],
521                [0, 0, 3, 0, 0, 0, 0, 0, 0],
522                [0, 0, 0, 4, 5, 0, 0, 0, 0],
523                [0, 0, 0, 4, 0, 6, 0, 0, 0],
524                [0, 0, 0, 4, 5, 6, 0, 0, 0],
525                [0, 0, 0, 4, 5, 6, 7, 8, 9],
526                [0, 0, 0, 4, 5, 6, 7, 8, 9],
527                [0, 0, 0, 4, 5, 6, 7, 8, 9]])
528
529         Now we proceed with the slice comparison solver.
530
531         >>> actions,trials = s._slice_completion_solver(trials)
532
533         The solver reduced trials in the following points
534             12 =   1  # point 0,6, removed 4,5,6  must be in center of row 0
535                  + 1  # point 0,7, removed 4,5,6  must be in center of row 0
536                  + 1  # point 0,8, removed 4,5,6  must be in center of row 0
537                  + 1  # point 1,4, removed 4,6    must be in top of cell 0,1
538                  + 1  # point 1,5, removed 4,5,6  must be in top of cell 0,1
539                  + 1  # point 2,3, removed 4,5    must be in top of cell 0,1
540                  + 1  # point 1,6, removed 8,9    must be in top of cell 0,2
541                  + 1  # point 1,7, removed 8,9    must be in top of cell 0,2
542                  + 1  # point 1,8, removed 8,9    must be in top of cell 0,2
543                  + 1  # point 2,6, removed 7      must be in top of cell 0,2
544                  + 1  # point 2,7, removed 7      must be in top of cell 0,2
545                  + 1  # point 2,8, removed 7      must be in top of cell 0,2
546
547         >>> actions
548         12
549         >>> trials[0,:,:]  # doctest: +REPORT_UDIFF
550         array([[1, 0, 0, 0, 0, 0, 0, 0, 0],
551                [0, 2, 0, 0, 0, 0, 0, 0, 0],
552                [0, 0, 3, 0, 0, 0, 0, 0, 0],
553                [0, 0, 0, 4, 5, 0, 0, 0, 0],
554                [0, 0, 0, 4, 0, 6, 0, 0, 0],
555                [0, 0, 0, 4, 5, 6, 0, 0, 0],
556                [0, 0, 0, 0, 0, 0, 7, 8, 9],
557                [0, 0, 0, 0, 0, 0, 7, 8, 9],
558                [0, 0, 0, 0, 0, 0, 7, 8, 9]])
559         """
560         actions = 0
561         for _type,slice,index in self._slices():
562             assert slice.size == 9, slice
563             trial_slice = self._trial_slice(trials, _type, index)
564             missing = set(self._nonempty(trial_slice.flat))
565             for possible in power_set(missing):
566                 possible = set(possible)
567                 if possible in [set(), missing]:
568                     continue
569                 points = []
570                 for k in range(slice.size):
571                     trial_set = set(self._nonempty(trial_slice[k,:]))
572                     if trial_set.issubset(possible):
573                         points.append(k)
574                 if len(points) == len(possible):
575                     possible_trial_slice = [0]*9
576                     for p in possible:
577                         possible_trial_slice[p-1] = p
578                     for k in range(slice.size):
579                         if k in points:
580                             ts = numpy.array(possible_trial_slice)
581                             for i,p in enumerate(ts):
582                                 if trial_slice[k,i] == self._empty:
583                                     ts[i] = self._empty
584                         else:
585                             ts = trial_slice[k,:].copy()
586                             for i,p in enumerate(ts):
587                                 if p in possible:
588                                     ts[i] = self._empty
589                         if _type == 'cell':
590                             cell_row,cell_col = index
591                             row,col = self._cell_to_point_coords(
592                                 cell_row, cell_col, k)
593                             if (ts != trials[row,col,:]).any():
594                                 actions += 1
595                                 #print _type, index, k, row, col, trial_slice[k,:], ts
596                                 trials[row,col,:] = ts
597                                 trial_slice[k,:] = ts
598                         else:  # row or column
599                             if (ts != trial_slice[k,:]).any():
600                                 actions += 1
601                                 #print _type, index, k, trial_slice[k,:], ts
602                                 trial_slice[k,:] = ts
603         return (actions, trials)
604
605
606 def test():
607     import doctest
608     doctest.testmod()
609
610 if __name__ == '__main__':
611     import optparse
612     import sys
613
614     p = optparse.OptionParser()
615     p.add_option('--test', dest='test', default=False, action='store_true',
616                  help='Run unit tests and exit.')
617     p.add_option('-d', '--disable-direct', dest='direct', default=True,
618                  action='store_false',
619                  help='Disable the direct elimination solver')
620     p.add_option('-c', '--disable-completion', dest='completion', default=True,
621                  action='store_false',
622                  help='Disable the slice completion solver')
623
624     options,args = p.parse_args()
625
626     if options.test:
627         test()
628         sys.exit(0)
629
630     s = Sudoku()
631     if not options.direct:
632         s.solvers.remove(s._direct_elimination_solver)
633     if not options.completion:
634         s.solvers.remove(s._slice_completion_solver)
635
636     puzzle = sys.stdin.read()
637     s.load(puzzle)
638     try:
639         s.solve()
640     except KeyboardInterrupt, e:
641         s.status = 'interrupted'
642     print >> sys.stderr, s.status
643     print s.dump()