3 splittable_kwargs allows the splitting of **kwargs arguments among
6 Copyright (C) W. Trevor King 2008, 2009
7 This code is released to the public domain.
9 Example usage (adapted from the unittests)
11 >>> from splittable_kwargs import *
12 >>> @splittableKwargsFunction()
13 ... def foo(x, y, z=2):
14 ... return "foo: x "+str(x)+", y "+str(y)+", z "+str(z)+"\\n"
16 >>> @splittableKwargsFunction()
17 ... def bar(a, b, c=2):
18 ... return "bar: a "+str(a)+", b "+str(b)+", c "+str(c)+"\\n"
19 >>> @splittableKwargsFunction({'fn':bar, 'mask':'a', 'translate':{'b':'b_bar'}})
20 ... def baz(d, **kwargs):
21 ... barkw, = baz._splitargs(baz, kwargs) # for translation/masking
22 ... string = bar(a=5, c=4, **barkw)
23 ... return string + "baz: d "+str(d)+"\\n"
24 >>> @splittableKwargsFunction(foo, baz)
25 ... def simple_joint(**kwargs):
26 ... fookw,bazkw = simple_joint._splitargs(simple_joint, kwargs)
27 ... string = foo(**fookw)
28 ... string += baz(**bazkw)
29 ... return string.strip()
30 >>> print simple_joint(x=1,y=3,b_bar=6,d=7)
35 If, say, simple_joint's children had not been defined (you wanted to
36 define baz after simple_joint in your module), you can skip the
37 decorator in simple_joint, and make it splittable later (after you
40 make_splittable_kwargs_function(simple_joint, foo, baz)
42 You can also get a list of the available named arguments with
44 >>> print baz._kwargs(baz)
46 >>> simple_joint._kwargs(simple_joint)
47 ['x', 'y', 'z', 'd', 'b_bar', 'c']
49 It may seem redundant to need to pass the function (here simple_joint)
50 to a method of simple_joint, but remember that simple_joint is a
51 _function_, not a class instance. If it really bothers you, try
54 >>> class ClassJoint (object):
55 ... @splittableKwargsFunction(foo, baz)
56 ... def __call__(self, **kwargs):
57 ... fookw,bazkw = simple_joint._splitargs(simple_joint, kwargs)
58 ... string = foo(**fookw)
59 ... string += baz(**bazkw)
60 ... return string.strip()
61 ... def _kwargs(self):
62 ... return self.__call__._kwargs(self.__call__)
65 ['self', 'x', 'y', 'z', 'd', 'b_bar', 'c']
66 >>> print cj(x=1,y=3,b_bar=6,d=7)
78 class UnknownKwarg (KeyError):
79 def __init__(self, fn, kwarg, value):
80 if hasattr(fn, "_kwargs"):
84 msg = "Unknown kwarg %s = %s. Allowed:\n" \
87 msg += " %s %s\n" % (f.__name__, f._kwargs(f))
88 KeyError.__init__(self, msg)
def _parse_splittable(splittable):
    """Normalize a child-splittable specification.

    splittable -> (splittable_fn, masked_args, translated_args)

    SPLITTABLE is either a bare splittableKwargsFunction (anything with a
    ``_kwargs`` attribute), returned with no mask and no translation, or a
    dict of the form
      {'fn': splittableKwargsFunction,
       'mask': ["argumentA", ...],                        # optional
       'translate': {"argumentB": "argB_new_name", ...}}  # optional

    A 'mask' given as a single string (e.g. ``'mask': 'a'``) is treated as
    ONE argument name.  Otherwise a bare string would act as a sequence of
    characters (``key not in masked`` in splitargs would become a substring
    test).  A missing or None 'mask' / 'translate' becomes an empty
    list / dict.
    """
    if hasattr(splittable, "_kwargs"):  # bare splittableKwargsFunction
        return (splittable, [], {})
    # masked/translated dict.  Look up 'fn' first so malformed input fails
    # exactly as before (KeyError / TypeError on the subscript).
    fn = splittable['fn']
    masked = splittable.get('mask')
    if masked is None:
        masked = []
    elif isinstance(masked, str):
        masked = [masked]  # a single argument name, not a char sequence
    translated = splittable.get('translate')
    if translated is None:
        translated = {}
    return (fn, masked, translated)
101 def splitargs(kwargs, *internal_splittables):
104 *internal_splittables : a list of splittableKwargsFunctions
105 items that this function uses internally. The items must be
106 parsable by _parse_splittable().
108 # sort the kwargs according to the appropriate function
109 fn_kwargs = [{} for splittable in internal_splittables]
110 for key,value in kwargs.items():
112 for i,splittable in enumerate(internal_splittables):
113 fn,masked,translated = _parse_splittable(splittable)
114 inv_translation = dict([(v,k) for k,v in translated.items()])
115 if key in inv_translation:
116 child_key = inv_translation[key]
118 elif key in fn._kwargs(fn) and key not in masked:
122 fn_kwargs[i][child_key] = value
124 if sorted != True: # couldn't find anywhere for that key.
126 for splittable in internal_splittables:
127 fn,masked,translated = _parse_splittable(splittable)
129 raise UnknownKwarg(fns, key, value)
133 "Return a list of acceptable kwargs"
134 args,varargs,varkw,defaults = inspect.getargspec(self)
136 raise NotImplementedError, "\n %s" % varargs
138 for child in self._childSplittables:
139 child,masked,translated = _parse_splittable(child)
140 child_args = child._kwargs(child)
143 child_args.remove(arg)
144 except ValueError, e:
145 msg = "%s not in %s" % (arg, child_args)
146 raise ValueError(msg)
147 for original,new in translated.items():
149 child_args[child_args.index(original)] = new
150 except ValueError, e:
151 msg = "%s not in %s" % (arg, child_args)
152 raise ValueError(msg)
153 args.extend(child_args)
def _declareInternalSplittableKwargsFunction(self, function):
    """Register FUNCTION as an internal (child) splittable of SELF.

    SELF is normally a function prepared by make_splittable_kwargs_function(),
    which gives it the ``_childSplittables`` list that FUNCTION is appended to.

    FUNCTION can be either a bare splittableKwargsFunction or a dict of the
    form
      {'fn': splittableKwargsFunction,
       'mask': ["argumentA", ...],
       'translate': {"argumentB": "argB_new_name", ...}}

    Example values for FUNCTION:
      bar
      {'fn': bar, 'mask': ["a"]}
      {'fn': bar, 'mask': ["a", "b"], 'translate': {"c": "bar_c"}}
    """
    children = self._childSplittables
    children.append(function)
def make_splittable_kwargs_function(function, *internal_splittables):
    """Attach the splittable-kwargs machinery to FUNCTION in place.

    FUNCTION gains these attributes:
      _kwargs(fn)                -- return a list of acceptable kwargs
      _childSplittables          -- list of registered internal splittables
      _declareInternalSplittableKwargsFunction(fn, splittable)
                                 -- register one more internal splittable
      _splitargs(fn, kwargs)     -- split the dict KWARGS among the children

    Each entry of INTERNAL_SPLITTABLES (a bare splittableKwargsFunction or a
    {'fn': ..., 'mask': ..., 'translate': ...} dict) is registered as a
    child.  Call this directly instead of using the splittableKwargsFunction
    decorator when the children are only defined after FUNCTION.

    Returns FUNCTION, so the call can also be used as an expression;
    statement-style calls that ignore the result are unaffected.
    """
    function._kwargs = _kwargs
    function._childSplittables = []
    function._declareInternalSplittableKwargsFunction = \
        _declareInternalSplittableKwargsFunction

    def _splitargs(self, kwargs):
        # KWARGS is a dict passed positionally, as in
        # ``fn._splitargs(fn, kwargs)``; it is not **-expanded.
        return splitargs(kwargs, *self._childSplittables)
    function._splitargs = _splitargs

    for splittable in internal_splittables:
        function._declareInternalSplittableKwargsFunction(function, splittable)
    return function
183 class splittableKwargsFunction (object):
184 def __init__(self, *internal_splittables):
185 self.internal_splittables = internal_splittables
186 def __call__(self, function):
187 make_splittable_kwargs_function(function, *self.internal_splittables)
190 @splittableKwargsFunction()
193 return "foo: x "+str(x)+", y "+str(y)+", z "+str(z)+"\n"
195 @splittableKwargsFunction()
198 return "bar: a "+str(a)+", b "+str(b)+", c "+str(c)+"\n"
200 @splittableKwargsFunction({'fn':_bar, 'mask':'a', 'translate':{'b':'b_bar'}})
201 def _baz(d, **kwargs):
203 barkw, = _baz._splitargs(_baz, kwargs) # for translation/masking
204 string = _bar(a=6, c=4, **barkw)
205 return string + "baz: d "+str(d)+"\n"
207 @splittableKwargsFunction(_foo, _bar)
209 fookw,barkw = _fun._splitargs(_fun, kwargs)
212 class SplittableKwargsTestCase (unittest.TestCase):
213 @splittableKwargsFunction(_foo, _bar)
214 def foobar(self, **kwargs):
215 fookw,barkw = self.foobar._splitargs(self.foobar, kwargs)
216 string = _foo(**fookw)
217 string += _bar(**barkw)
221 def simple_joint(**kwargs):
225 string += "\n".join([" %s : %s" % (k,kwargs[k]) for k in keys])
227 fookw,barkw = splitargs(kwargs, _foo, _bar)
228 string += _foo(**fookw)
229 string += _bar(a=4, **barkw)
233 def deeper_joint(**kwargs):
237 string += "\n".join([" %s : %s" % (k,kwargs[k]) for k in keys])
239 fookw,bazkw = splitargs(kwargs, _foo, _baz)
240 string += _foo(**fookw)
241 string += _baz(**bazkw)
244 def testDocPreserved(self):
245 self.failUnless(_foo.__doc__ == "foo function", _foo.__doc__)
246 self.failUnless(_bar.__doc__ == "bar function", _bar.__doc__)
247 self.failUnless(_baz.__doc__ == "baz function", _baz.__doc__)
248 def testSingleLevelMethod(self):
249 expected = """foo: x 1, y 2, z 3
252 output = self.foobar(x=1, y=2, z=3, a=8, b=5)
253 self.failUnless(output == expected,
254 "GOT\n%s\nEXPECTED\n%s" % (output, expected))
255 def testSingleLevelFunction(self):
256 simple_joint = SplittableKwargsTestCase.simple_joint
257 expected = """joining
265 output = simple_joint(x=1, y=2, z=3, b=5)
266 self.failUnless(output == expected,
267 "GOT\n%s\nEXPECTED\n%s" % (output, expected))
268 def testSingleLevelUnknownKwarg(self):
269 simple_joint = SplittableKwargsTestCase.simple_joint
270 self.assertRaises(UnknownKwarg, simple_joint,
271 y=2, z=3, b=5, unknown=6)
272 def testDoubleLevel(self):
273 deeper_joint = SplittableKwargsTestCase.deeper_joint
274 expected = """joining
284 output = deeper_joint(x=1, y=2, z=3, b_bar=5, d=6)
285 self.failUnless(output == expected,
286 "GOT\n%s\nEXPECTED\n%s" % (output, expected))
287 def testDoubleLevelOverrideRequired(self):
288 deeper_joint = SplittableKwargsTestCase.deeper_joint
289 self.assertRaises(UnknownKwarg, deeper_joint,
290 x=8, y=2, z=3, b=5, a=1)
291 def testRecursiveSplitargsReference(self):
292 # Access the ._splitargs method using the defined function name
293 expected = ({'y':3, 'z':4}, {'a':1, 'b':2})
294 output = _fun(a=1,b=2,y=3,z=4)
295 self.failUnless(output == expected,
296 "GOT\n%s\nEXPECTED\n%s" % (str(output), str(expected)))
298 if __name__ == "__main__":
301 unitsuite = unittest.TestLoader().loadTestsFromTestCase( \
302 SplittableKwargsTestCase)
303 unitsuite.addTest(doctest.DocTestSuite())
304 result = unittest.TextTestRunner(verbosity=2).run(unitsuite)
305 numErrors = len(result.errors)
306 numFailures = len(result.failures)
307 numBad = numErrors + numFailures