Added better error handling to _kwargs() and added doctests.
[splittable_kwargs.git] / splittable_kwargs.py
1 #!/usr/bin/python
2 """
3 splittable_kwargs allows the splitting of **kwargs arguments among
4 several functions.
5
6 Copyright (C)  W. Trevor King  2008, 2009
7 This code is released to the public domain.
8
9 Example usage (adapted from the unittests)
10
11     >>> from splittable_kwargs import *
12     >>> 
13     >>> @splittableKwargsFunction()
14     ... def foo(x, y, z=2):
15     ...     return "foo: x "+str(x)+", y "+str(y)+", z "+str(z)+"\\n"
16     ... 
17     >>> @splittableKwargsFunction()
18     ... def bar(a, b, c=2):
19     ...     return "bar: a "+str(a)+", b "+str(b)+", c "+str(c)+"\\n"
20     >>> 
21     >>> @splittableKwargsFunction((bar, 'a'))
22     ... def baz(d, **kwargs):
23     ...     string = bar(a=5, c=4, **kwargs)
24     ...     return string + "baz: d "+str(d)+"\\n"
25     >>> 
26     >>> @splittableKwargsFunction(foo, baz)
27     ... def simple_joint(**kwargs):
28     ...     fookw,bazkw = simple_joint._splitargs(simple_joint, kwargs)
29     ...     string = foo(**fookw)
30     ...     string += baz(**bazkw)
31     ...     return string.strip()
32     >>> 
33     >>> print simple_joint(x=1,y=3,b=6,d=7)
34     foo: x 1, y 3, z 2
35     bar: a 5, b 6, c 4
36     baz: d 7
37
38 If, say, simple_joint's children had not been defined (you wanted to
39 define baz after simple_joint in your module), you can skip the
40 decorator in simple_joint, and make it splittable later (after you
41 define baz) with
42
43    make_splittable_kwargs_function(simple_joint, foo, baz)
44
45 You can also get a list of the available named arguments with
46
47     >>> simple_joint._kwargs(simple_joint)
48     ['x', 'y', 'z', 'd', 'b', 'c']
49
50 It may seem redundant to need to pass the function (here simple_joint)
51 to a method of simple_joint, but remember that simple_joint is a
52 _function_, not a class instance.  If it really bothers you, try
53 something like
54
55     >>> class ClassJoint (object):
56     ...     @splittableKwargsFunction(foo, baz)
57     ...     def __call__(self, **kwargs):
58     ...         fookw,bazkw = simple_joint._splitargs(simple_joint, kwargs)
59     ...         string = foo(**fookw)
60     ...         string += baz(**bazkw)
61     ...         return string.strip()
62     ...     def _kwargs(self):
63     ...         return self.__call__._kwargs(self.__call__)
64     >>> cj = ClassJoint()
65     >>> print cj(x=1,y=3,b=6,d=7)
66     foo: x 1, y 3, z 2
67     bar: a 5, b 6, c 4
68     baz: d 7
69     >>> cj._kwargs()
70     ['self', 'x', 'y', 'z', 'd', 'b', 'c']
71 """
72
73 import inspect
74 import doctest
75 import unittest
76
77 VERSION = "0.2"
78
79 class UnknownKwarg (KeyError):
80     def __init__(self, fn, kwarg, value):
81         if hasattr(fn, "_kwargs"):
82             fn_list = [fn]
83         else:
84             fn_list = fn
85         msg = "Unknown kwarg %s = %s.  Allowed:\n" \
86             % (kwarg, value)
87         for f in fn_list:
88             msg += "  %s %s\n" % (f.__name__, f._kwargs(f))
89         KeyError.__init__(self, msg)
90
91 def _parse_splittable(splittable):
92     """
93       splittable -> (splittable_fn, masked_args)
94     """
95     if hasattr(splittable, "_kwargs"): # bare splittableKwargsFunction
96         return (splittable, [])
97     else: # function followed by masked args
98         return (splittable[0], splittable[1:])
99
100 def splitargs(kwargs, *internal_splittables):
101     """
102     where
103       *internal_splittables : a list of splittableKwargsFunctions items
104           that this function uses internally.
105       the items can be either
106           a bare splittableKwargsFunction
107       or a tuple where the additional elements are arguments to mask
108           a (bare splittableKwargsFunction, masked argument, ...)
109     """
110     # sort the kwargs according to the appropriate function
111     fn_kwargs = [{} for splittable in internal_splittables]
112     for key,value in kwargs.items():
113         sorted = False
114         for i,splittable in enumerate(internal_splittables):
115             fn,masked = _parse_splittable(splittable)
116             if key in fn._kwargs(fn) and key not in masked:
117                 fn_kwargs[i][key] = value
118                 sorted = True
119                 break
120         if sorted != True:
121             fns = []
122             for splittable in internal_splittables:
123                 fn,masked = _parse_splittable(splittable)
124                 fns.append(fn)
125             raise UnknownKwarg(fns, key, value)
126     return fn_kwargs
127
128 def _kwargs(self):
129     "return a list of acceptable kwargs"
130     args,varargs,varkw,defaults = inspect.getargspec(self)
131     if varargs != None:
132         raise NotImplementedError, "\n %s" % varargs
133     if varkw != None:
134         for child in self._childSplittables:
135             child,masked = _parse_splittable(child)
136             child_args = child._kwargs(child)
137             for arg in masked:
138                 try:
139                     child_args.remove(arg)
140                 except ValueError, e:
141                     msg = "%s not in %s" % (arg, child_args)
142                     raise ValueError(msg)
143             args.extend(child_args)
144     return args
145
146 def _declareInternalSplittableKwargsFunction(self, function):
147     """
148     FUNCTION can be either a bare splittableKwargsFuntion or one
149     such function followed by a sequence of masked arguments.
150     
151     Example values for FUNCTION:
152     bar
153     (bar, "a")
154     (bar, "a", "b")
155     """
156     self._childSplittables.append(function)
157
158 def make_splittable_kwargs_function(function, *internal_splittables):
159     function._kwargs = _kwargs
160     function._childSplittables = []
161     function._declareInternalSplittableKwargsFunction = \
162              _declareInternalSplittableKwargsFunction
163     def _splitargs(self, **kwargs):
164         return splitargs(kwargs, *self._childSplittables)
165     function._splitargs = lambda self,kwargs : splitargs(kwargs, *self._childSplittables)
166     
167     for splittable in internal_splittables:
168         function._declareInternalSplittableKwargsFunction(function, splittable)
169
170 class splittableKwargsFunction (object):
171     def __init__(self, *internal_splittables):
172         self.internal_splittables = internal_splittables
173     def __call__(self, function):
174         make_splittable_kwargs_function(function, *self.internal_splittables)
175         return function
176
177 @splittableKwargsFunction()
178 def _foo(x, y, z=2):
179     "foo function"
180     return "foo: x "+str(x)+", y "+str(y)+", z "+str(z)+"\n"
181
182 @splittableKwargsFunction()
183 def _bar(a, b, c=2):
184     "bar function"
185     return "bar: a "+str(a)+", b "+str(b)+", c "+str(c)+"\n"
186
187 @splittableKwargsFunction((_bar, 'a'))
188 def _baz(d, **kwargs):
189     "baz function"
190     string = _bar(a=6, c=4, **kwargs)
191     return string + "baz: d "+str(d)+"\n"
192
193 @splittableKwargsFunction(_foo, _bar)
194 def _fun(**kwargs):
195     fookw,barkw = _fun._splitargs(_fun, kwargs)
196     return(fookw, barkw)
197
198 class SplittableKwargsTestCase (unittest.TestCase):
199     @splittableKwargsFunction(_foo, _bar)
200     def foobar(self, **kwargs):
201         fookw,barkw = self.foobar._splitargs(self.foobar, kwargs)
202         string = _foo(**fookw)
203         string += _bar(**barkw)
204         return string
205
206     @staticmethod
207     def simple_joint(**kwargs):
208         string = "joining\n"
209         keys = kwargs.keys()
210         keys.sort()
211         string += "\n".join(["  %s : %s" % (k,kwargs[k]) for k in keys])
212         string += "\n"
213         fookw,barkw = splitargs(kwargs, _foo, _bar)
214         string += _foo(**fookw)
215         string += _bar(a=4, **barkw)
216         return string
217
218     @staticmethod
219     def deeper_joint(**kwargs):
220         string = "joining\n"
221         keys = kwargs.keys()
222         keys.sort()
223         string += "\n".join(["  %s : %s" % (k,kwargs[k]) for k in keys])
224         string += "\n"
225         fookw,bazkw = splitargs(kwargs, _foo, _baz)
226         string += _foo(**fookw)
227         string += _baz(**bazkw)
228         return string
229
230     def testDocPreserved(self):
231         self.failUnless(_foo.__doc__ == "foo function", _foo.__doc__)
232         self.failUnless(_bar.__doc__ == "bar function", _bar.__doc__)
233         self.failUnless(_baz.__doc__ == "baz function", _baz.__doc__)
234     def testSingleLevelMethod(self):
235         expected = """foo: x 1, y 2, z 3
236 bar: a 8, b 5, c 2
237 """
238         output = self.foobar(x=1, y=2, z=3, a=8, b=5)
239         self.failUnless(output == expected,
240                         "GOT\n%s\nEXPECTED\n%s" % (output, expected))
241     def testSingleLevelFunction(self):
242         simple_joint = SplittableKwargsTestCase.simple_joint
243         expected = """joining
244   b : 5
245   x : 1
246   y : 2
247   z : 3
248 foo: x 1, y 2, z 3
249 bar: a 4, b 5, c 2
250 """
251         output = simple_joint(x=1, y=2, z=3, b=5)
252         self.failUnless(output == expected,
253                         "GOT\n%s\nEXPECTED\n%s" % (output, expected))
254     def testSingleLevelUnknownKwarg(self):
255         simple_joint = SplittableKwargsTestCase.simple_joint
256         self.assertRaises(UnknownKwarg, simple_joint,
257                           y=2, z=3, b=5, unknown=6)
258     def testDoubleLevel(self):
259         deeper_joint = SplittableKwargsTestCase.deeper_joint
260         expected = """joining
261   b : 5
262   d : 6
263   x : 1
264   y : 2
265   z : 3
266 foo: x 1, y 2, z 3
267 bar: a 6, b 5, c 4
268 baz: d 6
269 """
270         output = deeper_joint(x=1, y=2, z=3, b=5, d=6)
271         self.failUnless(output == expected,
272                         "GOT\n%s\nEXPECTED\n%s" % (output, expected))
273     def testDoubleLevelOverrideRequired(self):
274         deeper_joint = SplittableKwargsTestCase.deeper_joint
275         self.assertRaises(UnknownKwarg, deeper_joint,
276                           x=8, y=2, z=3, b=5, a=1)
277     def testRecursiveSplitargsReference(self):
278         # Access the ._splitargs method using the defined function name
279         expected = ({'y':3, 'z':4}, {'a':1, 'b':2})
280         output = _fun(a=1,b=2,y=3,z=4)
281         self.failUnless(output == expected,
282                         "GOT\n%s\nEXPECTED\n%s" % (str(output), str(expected)))
283
284 if __name__ == "__main__":
285     import sys
286     
287     unitsuite = unittest.TestLoader().loadTestsFromTestCase( \
288                                                       SplittableKwargsTestCase)
289     unitsuite.addTest(doctest.DocTestSuite())
290     result = unittest.TextTestRunner(verbosity=2).run(unitsuite)
291     numErrors = len(result.errors)
292     numFailures = len(result.failures)
293     numBad = numErrors + numFailures
294     if numBad > 126:
295         numBad = 1
296     sys.exit(numBad)