Bumped to version 0.3, adding translated functionality.
[splittable_kwargs.git] / splittable_kwargs.py
1 #!/usr/bin/python
2 """
3 splittable_kwargs allows the splitting of **kwargs arguments among
4 several functions.
5
6 Copyright (C)  W. Trevor King  2008, 2009
7 This code is released to the public domain.
8
9 Example usage (adapted from the unittests)
10
11     >>> from splittable_kwargs import *
12     >>> @splittableKwargsFunction()
13     ... def foo(x, y, z=2):
14     ...     return "foo: x "+str(x)+", y "+str(y)+", z "+str(z)+"\\n"
15     ... 
16     >>> @splittableKwargsFunction()
17     ... def bar(a, b, c=2):
18     ...     return "bar: a "+str(a)+", b "+str(b)+", c "+str(c)+"\\n"
19     >>> @splittableKwargsFunction({'fn':bar, 'mask':'a', 'translate':{'b':'b_bar'}})
20     ... def baz(d, **kwargs):
21     ...     barkw, = baz._splitargs(baz, kwargs) # for translation/masking
22     ...     string = bar(a=5, c=4, **barkw)
23     ...     return string + "baz: d "+str(d)+"\\n"
24     >>> @splittableKwargsFunction(foo, baz)
25     ... def simple_joint(**kwargs):
26     ...     fookw,bazkw = simple_joint._splitargs(simple_joint, kwargs)
27     ...     string = foo(**fookw)
28     ...     string += baz(**bazkw)
29     ...     return string.strip()
30     >>> print simple_joint(x=1,y=3,b_bar=6,d=7)
31     foo: x 1, y 3, z 2
32     bar: a 5, b 6, c 4
33     baz: d 7
34
35 If, say, simple_joint's children had not been defined (you wanted to
36 define baz after simple_joint in your module), you can skip the
37 decorator in simple_joint, and make it splittable later (after you
38 define baz) with
39
40    make_splittable_kwargs_function(simple_joint, foo, baz)
41
42 You can also get a list of the available named arguments with
43
44     >>> print baz._kwargs(baz)
45     ['d', 'b_bar', 'c']
46     >>> simple_joint._kwargs(simple_joint)
47     ['x', 'y', 'z', 'd', 'b_bar', 'c']
48
49 It may seem redundant to need to pass the function (here simple_joint)
50 to a method of simple_joint, but remember that simple_joint is a
51 _function_, not a class instance.  If it really bothers you, try
52 something like
53
54     >>> class ClassJoint (object):
55     ...     @splittableKwargsFunction(foo, baz)
56     ...     def __call__(self, **kwargs):
57     ...         fookw,bazkw = simple_joint._splitargs(simple_joint, kwargs)
58     ...         string = foo(**fookw)
59     ...         string += baz(**bazkw)
60     ...         return string.strip()
61     ...     def _kwargs(self):
62     ...         return self.__call__._kwargs(self.__call__)
63     >>> cj = ClassJoint()
64     >>> cj._kwargs()
65     ['self', 'x', 'y', 'z', 'd', 'b_bar', 'c']
66     >>> print cj(x=1,y=3,b_bar=6,d=7)
67     foo: x 1, y 3, z 2
68     bar: a 5, b 6, c 4
69     baz: d 7
70 """
71
72 import inspect
73 import doctest
74 import unittest
75
76 VERSION = "0.3"
77
78 class UnknownKwarg (KeyError):
79     def __init__(self, fn, kwarg, value):
80         if hasattr(fn, "_kwargs"):
81             fn_list = [fn]
82         else:
83             fn_list = fn
84         msg = "Unknown kwarg %s = %s.  Allowed:\n" \
85             % (kwarg, value)
86         for f in fn_list:
87             msg += "  %s %s\n" % (f.__name__, f._kwargs(f))
88         KeyError.__init__(self, msg)
89
90 def _parse_splittable(splittable):
91     """
92       splittable -> (splittable_fn, masked_args, translated_args)
93     """
94     if hasattr(splittable, "_kwargs"): # bare splittableKwargsFunction
95         return (splittable, [], {})
96     else: # masked/translated dict
97         return (splittable['fn'],
98                 splittable.get('mask', []),
99                 splittable.get('translate', {}))
100
101 def splitargs(kwargs, *internal_splittables):
102     """
103     where
104       *internal_splittables : a list of splittableKwargsFunctions
105           items that this function uses internally.  The items must be
106           parsable by _parse_splittable().
107     """
108     # sort the kwargs according to the appropriate function
109     fn_kwargs = [{} for splittable in internal_splittables]
110     for key,value in kwargs.items():
111         sorted = False
112         for i,splittable in enumerate(internal_splittables):
113             fn,masked,translated = _parse_splittable(splittable)
114             inv_translation = dict([(v,k) for k,v in translated.items()])
115             if key in inv_translation:
116                 child_key = inv_translation[key]
117                 sorted = True
118             elif key in fn._kwargs(fn) and key not in masked:
119                 child_key = key
120                 sorted = True
121             if sorted == True:
122                 fn_kwargs[i][child_key] = value
123                 break
124         if sorted != True: # couldn't find anywhere for that key.
125             fns = []
126             for splittable in internal_splittables:
127                 fn,masked,translated = _parse_splittable(splittable)
128                 fns.append(fn)
129             raise UnknownKwarg(fns, key, value)
130     return fn_kwargs
131
132 def _kwargs(self):
133     "Return a list of acceptable kwargs"
134     args,varargs,varkw,defaults = inspect.getargspec(self)
135     if varargs != None:
136         raise NotImplementedError, "\n %s" % varargs
137     if varkw != None:
138         for child in self._childSplittables:
139             child,masked,translated = _parse_splittable(child)
140             child_args = child._kwargs(child)
141             for arg in masked:
142                 try:
143                     child_args.remove(arg)
144                 except ValueError, e:
145                     msg = "%s not in %s" % (arg, child_args)
146                     raise ValueError(msg)
147             for original,new in translated.items():
148                 try:
149                     child_args[child_args.index(original)] = new
150                 except ValueError, e:
151                     msg = "%s not in %s" % (arg, child_args)
152                     raise ValueError(msg)
153             args.extend(child_args)
154     return args
155
156 def _declareInternalSplittableKwargsFunction(self, function):
157     """
158     FUNCTION can be either a bare splittableKwargsFuntion or a dict of
159     the form:
160       {'fn': splittableKwargsFuntion,
161        'mask':["argumentA",...],
162        'translate':{"argumentB":"argB_new_name", ...}}
163     
164     Example values for FUNCTION:
165       bar
166       {'fn':bar, mask:["a"]}
167       {'fn':bar, mask:["a", "b"], translate:{"c":"bar_c"}}}
168     """
169     self._childSplittables.append(function)
170
171 def make_splittable_kwargs_function(function, *internal_splittables):
172     function._kwargs = _kwargs
173     function._childSplittables = []
174     function._declareInternalSplittableKwargsFunction = \
175              _declareInternalSplittableKwargsFunction
176     def _splitargs(self, **kwargs):
177         return splitargs(kwargs, *self._childSplittables)
178     function._splitargs = lambda self,kwargs : splitargs(kwargs, *self._childSplittables)
179     
180     for splittable in internal_splittables:
181         function._declareInternalSplittableKwargsFunction(function, splittable)
182
183 class splittableKwargsFunction (object):
184     def __init__(self, *internal_splittables):
185         self.internal_splittables = internal_splittables
186     def __call__(self, function):
187         make_splittable_kwargs_function(function, *self.internal_splittables)
188         return function
189
190 @splittableKwargsFunction()
191 def _foo(x, y, z=2):
192     "foo function"
193     return "foo: x "+str(x)+", y "+str(y)+", z "+str(z)+"\n"
194
195 @splittableKwargsFunction()
196 def _bar(a, b, c=2):
197     "bar function"
198     return "bar: a "+str(a)+", b "+str(b)+", c "+str(c)+"\n"
199
200 @splittableKwargsFunction({'fn':_bar, 'mask':'a', 'translate':{'b':'b_bar'}})
201 def _baz(d, **kwargs):
202     "baz function"
203     barkw, = _baz._splitargs(_baz, kwargs) # for translation/masking
204     string = _bar(a=6, c=4, **barkw)
205     return string + "baz: d "+str(d)+"\n"
206
207 @splittableKwargsFunction(_foo, _bar)
208 def _fun(**kwargs):
209     fookw,barkw = _fun._splitargs(_fun, kwargs)
210     return(fookw, barkw)
211
212 class SplittableKwargsTestCase (unittest.TestCase):
213     @splittableKwargsFunction(_foo, _bar)
214     def foobar(self, **kwargs):
215         fookw,barkw = self.foobar._splitargs(self.foobar, kwargs)
216         string = _foo(**fookw)
217         string += _bar(**barkw)
218         return string
219
220     @staticmethod
221     def simple_joint(**kwargs):
222         string = "joining\n"
223         keys = kwargs.keys()
224         keys.sort()
225         string += "\n".join(["  %s : %s" % (k,kwargs[k]) for k in keys])
226         string += "\n"
227         fookw,barkw = splitargs(kwargs, _foo, _bar)
228         string += _foo(**fookw)
229         string += _bar(a=4, **barkw)
230         return string
231
232     @staticmethod
233     def deeper_joint(**kwargs):
234         string = "joining\n"
235         keys = kwargs.keys()
236         keys.sort()
237         string += "\n".join(["  %s : %s" % (k,kwargs[k]) for k in keys])
238         string += "\n"
239         fookw,bazkw = splitargs(kwargs, _foo, _baz)
240         string += _foo(**fookw)
241         string += _baz(**bazkw)
242         return string
243
244     def testDocPreserved(self):
245         self.failUnless(_foo.__doc__ == "foo function", _foo.__doc__)
246         self.failUnless(_bar.__doc__ == "bar function", _bar.__doc__)
247         self.failUnless(_baz.__doc__ == "baz function", _baz.__doc__)
248     def testSingleLevelMethod(self):
249         expected = """foo: x 1, y 2, z 3
250 bar: a 8, b 5, c 2
251 """
252         output = self.foobar(x=1, y=2, z=3, a=8, b=5)
253         self.failUnless(output == expected,
254                         "GOT\n%s\nEXPECTED\n%s" % (output, expected))
255     def testSingleLevelFunction(self):
256         simple_joint = SplittableKwargsTestCase.simple_joint
257         expected = """joining
258   b : 5
259   x : 1
260   y : 2
261   z : 3
262 foo: x 1, y 2, z 3
263 bar: a 4, b 5, c 2
264 """
265         output = simple_joint(x=1, y=2, z=3, b=5)
266         self.failUnless(output == expected,
267                         "GOT\n%s\nEXPECTED\n%s" % (output, expected))
268     def testSingleLevelUnknownKwarg(self):
269         simple_joint = SplittableKwargsTestCase.simple_joint
270         self.assertRaises(UnknownKwarg, simple_joint,
271                           y=2, z=3, b=5, unknown=6)
272     def testDoubleLevel(self):
273         deeper_joint = SplittableKwargsTestCase.deeper_joint
274         expected = """joining
275   b_bar : 5
276   d : 6
277   x : 1
278   y : 2
279   z : 3
280 foo: x 1, y 2, z 3
281 bar: a 6, b 5, c 4
282 baz: d 6
283 """
284         output = deeper_joint(x=1, y=2, z=3, b_bar=5, d=6)
285         self.failUnless(output == expected,
286                         "GOT\n%s\nEXPECTED\n%s" % (output, expected))
287     def testDoubleLevelOverrideRequired(self):
288         deeper_joint = SplittableKwargsTestCase.deeper_joint
289         self.assertRaises(UnknownKwarg, deeper_joint,
290                           x=8, y=2, z=3, b=5, a=1)
291     def testRecursiveSplitargsReference(self):
292         # Access the ._splitargs method using the defined function name
293         expected = ({'y':3, 'z':4}, {'a':1, 'b':2})
294         output = _fun(a=1,b=2,y=3,z=4)
295         self.failUnless(output == expected,
296                         "GOT\n%s\nEXPECTED\n%s" % (str(output), str(expected)))
297
298 if __name__ == "__main__":
299     import sys
300     
301     unitsuite = unittest.TestLoader().loadTestsFromTestCase( \
302                                                       SplittableKwargsTestCase)
303     unitsuite.addTest(doctest.DocTestSuite())
304     result = unittest.TextTestRunner(verbosity=2).run(unitsuite)
305     numErrors = len(result.errors)
306     numFailures = len(result.failures)
307     numBad = numErrors + numFailures
308     if numBad > 126:
309         numBad = 1
310     sys.exit(numBad)