Began versioning
[splittable_kwargs.git] / splittable_kwargs.py
1 #!/usr/bin/python
2 """
3 splittable_kwargs allows the splitting of **kwargs arguments among
4 several functions.  This
5
6 Copyright (C)  W. Trevor King  2008
7 This code is released to the public domain.
8
9 Example usage (adapted from the unittests)
10
11   @splittableKwargsFunction()
12   def foo(x, y, z=2):
13       return "foo: x "+str(x)+", y "+str(y)+", z "+str(z)+"\n"
14   
15   @splittableKwargsFunction()
16   def bar(a, b, c=2):
17       return "bar: a "+str(a)+", b "+str(b)+", c "+str(c)+"\n"
18   
19   @splittableKwargsFunction((bar, 'a'))
20   def baz(d, **kwargs):
21       string = bar(c=4, **kwargs)
22       return string + "baz: d "+str(d)+"\n"
23   
24   @splittableKwargsFunction(foo, bar)
25   def simple_joint(**kwargs):
26       fookw,bazkw = simple_joint._splitargs(simple_joint, kwargs)
27       string = foo(**fookw)
28       string += baz(**bazkw)
29       return string
30   
31   simple_joint(y=3,c=1,d=5)
32
33 If, say, simple_joint's children had not been defined (you wanted to
34 define bar after simple_joint in your module), you can skip the
35 decorator in simple_joint, and make it splittable later (after you
36 define bar) with
37
38   make_splittable_kwargs_function(simple_joint, foo, bar)
39 """
40
41 import inspect
42 import unittest
43
44 class UnknownKwarg (KeyError):
45     def __init__(self, fn, kwarg, value):
46         if hasattr(fn, "_kwargs"):
47             fn_list = [fn]
48         else:
49             fn_list = fn
50         msg = "Unknown kwarg %s = %s.  Allowed:\n" \
51             % (kwarg, value)
52         for f in fn_list:
53             msg += "  %s %s\n" % (f.__name__, f._kwargs(f))
54         KeyError.__init__(self, msg)
55
56 def _parse_splittable(splittable):
57     """
58       splittable -> (splittable_fn, masked_args)
59     """
60     if hasattr(splittable, "_kwargs"): # bare splittableKwargsFunction
61         return (splittable, [])
62     else: # function followed by masked args
63         return (splittable[0], splittable[1:])
64
65 def splitargs(kwargs, *internal_splittables):
66     """
67     where
68       *internal_splittables : a list of splittableKwargsFunctions items
69           that this function uses internally.
70       the items can be either
71           a bare splittableKwargsFunction
72       or a tuple where the additional elements are arguments to mask
73           a (bare splittableKwargsFunction, masked argument, ...)
74     """
75     # sort the kwargs according to the appropriate function
76     fn_kwargs = [{} for splittable in internal_splittables]
77     for key,value in kwargs.items():
78         sorted = False
79         for i,splittable in enumerate(internal_splittables):
80             fn,masked = _parse_splittable(splittable)
81             if key in fn._kwargs(fn) and key not in masked:
82                 fn_kwargs[i][key] = value
83                 sorted = True
84                 break
85         if sorted != True:
86             fns = []
87             for splittable in internal_splittables:
88                 fn,masked = _parse_splittable(splittable)
89                 fns.append(fn)
90             raise UnknownKwarg(fns, key, value)
91     return fn_kwargs
92
93 def _kwargs(self):
94     "return a list of acceptable kwargs"
95     args,varargs,varkw,defaults = inspect.getargspec(self)
96     if varargs != None:
97         raise NotImplementedError, "\n %s" % varargs
98     if varkw != None:
99         for child in self._childSplittables:
100             child,masked = _parse_splittable(child)
101             child_args = child._kwargs(child)
102             for arg in masked:
103                 child_args.remove(arg)
104             args.extend(child_args)
105     return args
106
107 def _declareInternalSplittableKwargsFunction(self, function):
108     """
109     FUNCTION can be either a bare splittableKwargsFuntion or one
110     such function followed by a sequence of masked arguments.
111     
112     Example values for FUNCTION:
113     bar
114     (bar, "a")
115     (bar, "a", "b")
116     """
117     self._childSplittables.append(function)
118
119 def make_splittable_kwargs_function(function, *internal_splittables):
120     function._kwargs = _kwargs
121     function._childSplittables = []
122     function._declareInternalSplittableKwargsFunction = \
123              _declareInternalSplittableKwargsFunction
124     def _splitargs(self, **kwargs):
125         return splitargs(kwargs, *self._childSplittables)
126     function._splitargs = lambda self,kwargs : splitargs(kwargs, *self._childSplittables)
127     
128     for splittable in internal_splittables:
129         function._declareInternalSplittableKwargsFunction(function, splittable)
130
131 class splittableKwargsFunction (object):
132     def __init__(self, *internal_splittables):
133         self.internal_splittables = internal_splittables
134     def __call__(self, function):
135         make_splittable_kwargs_function(function, *self.internal_splittables)
136         return function
137
138 @splittableKwargsFunction()
139 def _foo(x, y, z=2):
140     "foo function"
141     return "foo: x "+str(x)+", y "+str(y)+", z "+str(z)+"\n"
142
143 @splittableKwargsFunction()
144 def _bar(a, b, c=2):
145     "bar function"
146     return "bar: a "+str(a)+", b "+str(b)+", c "+str(c)+"\n"
147
148 @splittableKwargsFunction((_bar, 'a'))
149 def _baz(d, **kwargs):
150     "baz function"
151     string = _bar(a=6, c=4, **kwargs)
152     return string + "baz: d "+str(d)+"\n"
153
154 @splittableKwargsFunction(_foo, _bar)
155 def _fun(**kwargs):
156     fookw,barkw = _fun._splitargs(_fun, kwargs)
157     return(fookw, barkw)
158
159 class SplittableKwargsTestCase (unittest.TestCase):
160     @splittableKwargsFunction(_foo, _bar)
161     def foobar(self, **kwargs):
162         fookw,barkw = self.foobar._splitargs(self.foobar, kwargs)
163         string = _foo(**fookw)
164         string += _bar(**barkw)
165         return string
166
167     @staticmethod
168     def simple_joint(**kwargs):
169         string = "joining\n"
170         keys = kwargs.keys()
171         keys.sort()
172         string += "\n".join(["  %s : %s" % (k,kwargs[k]) for k in keys])
173         string += "\n"
174         fookw,barkw = splitargs(kwargs, _foo, _bar)
175         string += _foo(**fookw)
176         string += _bar(a=4, **barkw)
177         return string
178
179     @staticmethod
180     def deeper_joint(**kwargs):
181         string = "joining\n"
182         keys = kwargs.keys()
183         keys.sort()
184         string += "\n".join(["  %s : %s" % (k,kwargs[k]) for k in keys])
185         string += "\n"
186         fookw,bazkw = splitargs(kwargs, _foo, _baz)
187         string += _foo(**fookw)
188         string += _baz(**bazkw)
189         return string
190
191     def testDocPreserved(self):
192         self.failUnless(_foo.__doc__ == "foo function", _foo.__doc__)
193         self.failUnless(_bar.__doc__ == "bar function", _bar.__doc__)
194         self.failUnless(_baz.__doc__ == "baz function", _baz.__doc__)
195     def testSingleLevelMethod(self):
196         expected = """foo: x 1, y 2, z 3
197 bar: a 8, b 5, c 2
198 """
199         output = self.foobar(x=1, y=2, z=3, a=8, b=5)
200         self.failUnless(output == expected,
201                         "GOT\n%s\nEXPECTED\n%s" % (output, expected))
202     def testSingleLevelFunction(self):
203         simple_joint = SplittableKwargsTestCase.simple_joint
204         expected = """joining
205   b : 5
206   x : 1
207   y : 2
208   z : 3
209 foo: x 1, y 2, z 3
210 bar: a 4, b 5, c 2
211 """
212         output = simple_joint(x=1, y=2, z=3, b=5)
213         self.failUnless(output == expected,
214                         "GOT\n%s\nEXPECTED\n%s" % (output, expected))
215     def testSingleLevelUnknownKwarg(self):
216         simple_joint = SplittableKwargsTestCase.simple_joint
217         self.assertRaises(UnknownKwarg, simple_joint,
218                           y=2, z=3, b=5, unknown=6)
219     def testDoubleLevel(self):
220         deeper_joint = SplittableKwargsTestCase.deeper_joint
221         expected = """joining
222   b : 5
223   d : 6
224   x : 1
225   y : 2
226   z : 3
227 foo: x 1, y 2, z 3
228 bar: a 6, b 5, c 4
229 baz: d 6
230 """
231         output = deeper_joint(x=1, y=2, z=3, b=5, d=6)
232         self.failUnless(output == expected,
233                         "GOT\n%s\nEXPECTED\n%s" % (output, expected))
234     def testDoubleLevelOverrideRequired(self):
235         deeper_joint = SplittableKwargsTestCase.deeper_joint
236         self.assertRaises(UnknownKwarg, deeper_joint,
237                           x=8, y=2, z=3, b=5, a=1)
238     def testRecursiveSplitargsReference(self):
239         # Access the ._splitargs method using the defined function name
240         expected = ({'y':3, 'z':4}, {'a':1, 'b':2})
241         output = _fun(a=1,b=2,y=3,z=4)
242         self.failUnless(output == expected,
243                         "GOT\n%s\nEXPECTED\n%s" % (str(output), str(expected)))
244
245 if __name__ == "__main__":
246     import sys
247     
248     unitsuite = unittest.TestLoader().loadTestsFromTestCase( \
249                                                       SplittableKwargsTestCase)
250     result = unittest.TextTestRunner(verbosity=2).run(unitsuite)
251     numErrors = len(result.errors)
252     numFailures = len(result.failures)
253     numBad = numErrors + numFailures
254     if numBad > 126:
255         numBad = 1
256     sys.exit(numBad)