3 splittable_kwargs allows the splitting of **kwargs arguments among
4 several functions. This
6 Copyright (C) W. Trevor King 2008
7 This code is released to the public domain.
9 Example usage (adapted from the unittests)
11 @splittableKwargsFunction()
13 return "foo: x "+str(x)+", y "+str(y)+", z "+str(z)+"\n"
15 @splittableKwargsFunction()
17 return "bar: a "+str(a)+", b "+str(b)+", c "+str(c)+"\n"
19 @splittableKwargsFunction((bar, 'a'))
21 string = bar(c=4, **kwargs)
22 return string + "baz: d "+str(d)+"\n"
24 @splittableKwargsFunction(foo, bar)
25 def simple_joint(**kwargs):
26 fookw,bazkw = simple_joint._splitargs(simple_joint, kwargs)
28 string += baz(**bazkw)
31 simple_joint(y=3,c=1,d=5)
33 If, say, simple_joint's children had not been defined (you wanted to
34 define bar after simple_joint in your module), you can skip the
35 decorator in simple_joint, and make it splittable later (after you
38 make_splittable_kwargs_function(simple_joint, foo, bar)
44 class UnknownKwarg (KeyError):
45 def __init__(self, fn, kwarg, value):
46 if hasattr(fn, "_kwargs"):
50 msg = "Unknown kwarg %s = %s. Allowed:\n" \
53 msg += " %s %s\n" % (f.__name__, f._kwargs(f))
54 KeyError.__init__(self, msg)
56 def _parse_splittable(splittable):
58 splittable -> (splittable_fn, masked_args)
60 if hasattr(splittable, "_kwargs"): # bare splittableKwargsFunction
61 return (splittable, [])
62 else: # function followed by masked args
63 return (splittable[0], splittable[1:])
65 def splitargs(kwargs, *internal_splittables):
68 *internal_splittables : a list of splittableKwargsFunctions items
69 that this function uses internally.
70 the items can be either
71 a bare splittableKwargsFunction
72 or a tuple where the additional elements are arguments to mask
73 a (bare splittableKwargsFunction, masked argument, ...)
75 # sort the kwargs according to the appropriate function
76 fn_kwargs = [{} for splittable in internal_splittables]
77 for key,value in kwargs.items():
79 for i,splittable in enumerate(internal_splittables):
80 fn,masked = _parse_splittable(splittable)
81 if key in fn._kwargs(fn) and key not in masked:
82 fn_kwargs[i][key] = value
87 for splittable in internal_splittables:
88 fn,masked = _parse_splittable(splittable)
90 raise UnknownKwarg(fns, key, value)
94 "return a list of acceptable kwargs"
95 args,varargs,varkw,defaults = inspect.getargspec(self)
97 raise NotImplementedError, "\n %s" % varargs
99 for child in self._childSplittables:
100 child,masked = _parse_splittable(child)
101 child_args = child._kwargs(child)
103 child_args.remove(arg)
104 args.extend(child_args)
107 def _declareInternalSplittableKwargsFunction(self, function):
109 FUNCTION can be either a bare splittableKwargsFuntion or one
110 such function followed by a sequence of masked arguments.
112 Example values for FUNCTION:
117 self._childSplittables.append(function)
def make_splittable_kwargs_function(function, *internal_splittables):
    """Attach the splittable-kwargs machinery to FUNCTION in place.

    After this call FUNCTION carries the attributes
      _kwargs                -- the module-level ``_kwargs`` helper
      _childSplittables      -- a fresh, empty list of child splittables
      _declareInternalSplittableKwargsFunction
                             -- the module-level declaration helper
      _splitargs(self, kwargs)
                             -- ``splitargs(kwargs, *self._childSplittables)``
    and then each item of INTERNAL_SPLITTABLES is declared as a child of
    FUNCTION.  Items are either a bare splittableKwargsFunction or a
    (function, masked_arg, ...) sequence.

    Nothing is returned; FUNCTION is modified in place.
    """
    function._kwargs = _kwargs
    function._childSplittables = []
    function._declareInternalSplittableKwargsFunction = \
        _declareInternalSplittableKwargsFunction

    def _splitargs(self, kwargs):
        # The kwargs dict is taken positionally (not as **kwargs), because
        # callers invoke it as ``fn._splitargs(fn, kwargs)``.  This replaces
        # an unused ``**kwargs`` variant plus a lambda with a single helper.
        return splitargs(kwargs, *self._childSplittables)
    function._splitargs = _splitargs

    for splittable in internal_splittables:
        function._declareInternalSplittableKwargsFunction(function, splittable)
131 class splittableKwargsFunction (object):
132 def __init__(self, *internal_splittables):
133 self.internal_splittables = internal_splittables
134 def __call__(self, function):
135 make_splittable_kwargs_function(function, *self.internal_splittables)
138 @splittableKwargsFunction()
141 return "foo: x "+str(x)+", y "+str(y)+", z "+str(z)+"\n"
143 @splittableKwargsFunction()
146 return "bar: a "+str(a)+", b "+str(b)+", c "+str(c)+"\n"
148 @splittableKwargsFunction((_bar, 'a'))
149 def _baz(d, **kwargs):
151 string = _bar(a=6, c=4, **kwargs)
152 return string + "baz: d "+str(d)+"\n"
154 @splittableKwargsFunction(_foo, _bar)
156 fookw,barkw = _fun._splitargs(_fun, kwargs)
159 class SplittableKwargsTestCase (unittest.TestCase):
160 @splittableKwargsFunction(_foo, _bar)
161 def foobar(self, **kwargs):
162 fookw,barkw = self.foobar._splitargs(self.foobar, kwargs)
163 string = _foo(**fookw)
164 string += _bar(**barkw)
168 def simple_joint(**kwargs):
172 string += "\n".join([" %s : %s" % (k,kwargs[k]) for k in keys])
174 fookw,barkw = splitargs(kwargs, _foo, _bar)
175 string += _foo(**fookw)
176 string += _bar(a=4, **barkw)
180 def deeper_joint(**kwargs):
184 string += "\n".join([" %s : %s" % (k,kwargs[k]) for k in keys])
186 fookw,bazkw = splitargs(kwargs, _foo, _baz)
187 string += _foo(**fookw)
188 string += _baz(**bazkw)
191 def testDocPreserved(self):
192 self.failUnless(_foo.__doc__ == "foo function", _foo.__doc__)
193 self.failUnless(_bar.__doc__ == "bar function", _bar.__doc__)
194 self.failUnless(_baz.__doc__ == "baz function", _baz.__doc__)
195 def testSingleLevelMethod(self):
196 expected = """foo: x 1, y 2, z 3
199 output = self.foobar(x=1, y=2, z=3, a=8, b=5)
200 self.failUnless(output == expected,
201 "GOT\n%s\nEXPECTED\n%s" % (output, expected))
202 def testSingleLevelFunction(self):
203 simple_joint = SplittableKwargsTestCase.simple_joint
204 expected = """joining
212 output = simple_joint(x=1, y=2, z=3, b=5)
213 self.failUnless(output == expected,
214 "GOT\n%s\nEXPECTED\n%s" % (output, expected))
215 def testSingleLevelUnknownKwarg(self):
216 simple_joint = SplittableKwargsTestCase.simple_joint
217 self.assertRaises(UnknownKwarg, simple_joint,
218 y=2, z=3, b=5, unknown=6)
219 def testDoubleLevel(self):
220 deeper_joint = SplittableKwargsTestCase.deeper_joint
221 expected = """joining
231 output = deeper_joint(x=1, y=2, z=3, b=5, d=6)
232 self.failUnless(output == expected,
233 "GOT\n%s\nEXPECTED\n%s" % (output, expected))
234 def testDoubleLevelOverrideRequired(self):
235 deeper_joint = SplittableKwargsTestCase.deeper_joint
236 self.assertRaises(UnknownKwarg, deeper_joint,
237 x=8, y=2, z=3, b=5, a=1)
238 def testRecursiveSplitargsReference(self):
239 # Access the ._splitargs method using the defined function name
240 expected = ({'y':3, 'z':4}, {'a':1, 'b':2})
241 output = _fun(a=1,b=2,y=3,z=4)
242 self.failUnless(output == expected,
243 "GOT\n%s\nEXPECTED\n%s" % (str(output), str(expected)))
245 if __name__ == "__main__":
248 unitsuite = unittest.TestLoader().loadTestsFromTestCase( \
249 SplittableKwargsTestCase)
250 result = unittest.TextTestRunner(verbosity=2).run(unitsuite)
251 numErrors = len(result.errors)
252 numFailures = len(result.failures)
253 numBad = numErrors + numFailures