--- /dev/null
+Problem: how to split **kwargs among several functions?
+Which args should go to which functions?
+For example:
+
+ def foo(x, y, z=2):
+ print "foo: x", x, "y", y, "z", z
+ def bar(a, b, c=2):
+ print "bar: a", a, "b", b, "c", c
+ def baz(d, **kwargs):
+ bar(c=4, **kwargs)
+ print "baz: d", d
+ def simple_joint(**kwargs):
+ print "joining"
+ print '\n'.join([" %s : %s" % (k,v) for k,v in kwargs.items()])
+ fookw,barkw = splitargs([foo,bar], kwargs)
+ foo(**fookw)
+ bar(**barkw)
+ print ""
+ def deeper_joint(**kwargs):
+ print "joining"
+ print '\n'.join([" %s : %s" % (k,v) for k,v in kwargs.items()])
+ fookw,bazkw = splitargs([foo,baz], kwargs)
+ foo(**fookw)
+ baz(**bazkw)
+ print ""
+
+ # A single level is easily handled, since the list of arguments can be
+ # read off the output of inspect.getargspec().
+ simple_joint(x=1, y=2, z=3, a=4, b=5)
+ try:
+ simple_joint(x=1, y=2, z=3, a=4, b=5, unknown=5)
+ except Exception, e:
+ print e
+
+ # This is much more complicated. With multiple levels of **kwargs to
+ # drill down, we need to parse the definition of each function to see
+ # what arguments the first **kwargs can take.
+ deeper_joint(x=1, y=2, z=3, a=4, b=5, d=6)
+
+A first attempt at a solution (which works fairly well, but is a bit
+of a brain-melter) is given in parse.py. A much simpler and more
+robust solution is given in splittable_kwargs.py.
--- /dev/null
+#!/usr/bin/python
+#
+# Problem: how to split **kwargs among several functions?
+# Which args should go to which functions?
+#
+# The parsing code is inspired by and developed from Prashanth Ellina's
+# http://blog.prashanthellina.com/2007/11/14/generating-call-graphs-for-understanding-and-refactoring-python-code/
+#
+# See
+# http://docs.python.org/reference/grammar.html
+# for the Python grammar
+
+import parser
+import symbol
+import token
+import inspect
+
+import pprint
+import copy
+
+def get_parse_item_code(parse_item):
+ if isinstance(parse_item, list):
+ item_code = parse_item[0]
+ else:
+ item_code = parse_item
+ return item_code
+
+def code_to_string(parse_code):
+ if parse_code in symbol.sym_name:
+ code_string = symbol.sym_name[parse_code]
+ else:
+ code_string = token.tok_name[parse_code]
+ return code_string
+
+def annotate_parse_list(parse_list, toplevel=True):
+ if toplevel == True:
+ parse_list = copy.deepcopy(parse_list)
+ parse_list[0] = code_to_string(parse_list[0])
+
+ for index, item in enumerate(parse_list):
+ if index == 0: continue
+ if isinstance(item, list):
+ parse_list[index] = annotate_parse_list(item, toplevel=False)
+ return parse_list
+
+def match_item_code(parse_item, expected_code):
+ parse_code = get_parse_item_code(parse_item)
+ return parse_code == expected_code
+
+def assert_item_code(parse_item, expected_code):
+ if not match_item_code(parse_item, expected_code):
+ raise Exception, "Unexpected code %s != %s in\n%s" \
+ % (code_to_string(parse_code),
+ code_to_string(expected_code),
+ pprint.pformat(annotate_parse_list(parse_item)))
+
+def drill_for_item_code(parse_item, target_code):
+ """Depth first search for the target_code"""
+ parse_code = get_parse_item_code(parse_item)
+ if parse_code == target_code:
+ return parse_item
+ if isinstance(parse_item, list):
+ for child_item in parse_item[1:]:
+ ret = drill_for_item_code(child_item, target_code)
+ if ret != None:
+ return ret
+ return None
+
+def parse_atom(atom):
+ """
+ atom: ('(' [yield_expr|testlist_gexp] ')' |
+ '[' [listmaker] ']' |
+ '{' [dictmaker] '}' |
+ '`' testlist1 '`' |
+ NAME | NUMBER | STRING+)
+ """
+ assert_item_code(atom, symbol.atom)
+ first_child = atom[1]
+ first_child_code = first_child[0]
+ if first_child_code == token.NAME:
+ return first_child[1]
+ elif first_child_code == token.NUMBER:
+ return first_child[1]
+ elif first_child_code == token.STRING:
+ return first_child[1]
+ return None
+
+def find_test_name(parse_item):
+ """
+ Drill down to the atom naming this test item
+ """
+ assert_item_code(parse_item, symbol.test)
+ atom_item = drill_for_item_code(parse_item, symbol.atom)
+ if atom_item == None:
+ return None
+ return parse_atom(atom_item)
+
+def parse_argument(parse_item):
+ "argument: test [gen_for] | test '=' test # Really [keyword '='] test"
+ assert_item_code(parse_item, symbol.argument)
+ if len(parse_item) == 4: # test '=' test
+ arg_name = find_test_name(parse_item[1])
+ assert_item_code(parse_item[2], token.EQUAL)
+ arg_value = find_test_name(parse_item[3])
+ elif len(parse_item) == 2: # test
+ arg_name = None
+ arg_value = find_test_name(parse_item[1])
+ else: # test gen_for
+ raise NotImplementedError, \
+ '"test gen_for" argument\n%s' \
+ % pprint.pformat(annotate_parse_list(parse_item))
+ return (arg_name, arg_value)
+
+def parse_arglist(parse_item):
+ """
+ arglist: (argument ',')* (argument [',']
+ |'*' test (',' argument)* [',' '**' test]
+ |'**' test)
+ """
+ assert_item_code(parse_item, symbol.arglist)
+ args = []
+ kwargs = {}
+ varargs = None
+ varkw = None
+ i = 1
+ while i < len(parse_item):
+ item = parse_item[i]
+ argument_code = get_parse_item_code(item)
+ if argument_code == token.COMMA:
+ pass
+ elif argument_code == symbol.argument:
+ arg_name,arg_value = parse_argument(item)
+ if arg_name == None:
+ args.append(arg_value)
+ else:
+ kwargs[arg_name] = arg_value
+ elif argument_code == token.DOUBLESTAR:
+ i += 1
+ item = parse_item[i]
+ varkw = find_test_name(item)
+ else:
+ raise NotImplementedError, \
+ '"Unknown arglist item %s in\n%s' \
+ % (code_to_string(argument_code),
+ pprint.pformat(annotate_parse_list(parse_item)))
+ i += 1
+ return (args, kwargs, varargs, varkw)
+
+def parse_trailer(parse_item):
+ """
+ trailer: '(' [arglist] ')' | '[' subscriptlist ']' | '.' NAME
+ """
+ assert_item_code(parse_item, symbol.trailer)
+ if match_item_code(parse_item[1], token.LPAR):
+ assert_item_code(parse_item[-1], token.RPAR)
+ if len(parse_item) == 4: # '(' arglist ')'
+ return parse_arglist(parse_item[2])
+ else: # '(' ')'
+ assert len(parse_item) == 3, \
+ pprint.pformat(annotate_parse_list(parse_item))
+ return None
+ # Ignore other trailers (they can't contain arglists)
+ return None
+
+def parse_power_fn(parse_item):
+ """
+ power: atom trailer* ['**' factor]
+ """
+ assert_item_code(parse_item, symbol.power)
+ if len(parse_item) < 3: # "atom"
+ return None
+ fn_name = parse_atom(parse_item[1])
+ arglists = []
+ for item in parse_item[2:]:
+ if match_item_code(item, symbol.trailer):
+ ret = parse_trailer(item)
+ if ret != None:
+ arglists.append(ret)
+ # ignore the possible "'**' factor"
+ return (fn_name, arglists)
+
+def find_fn_calls(parse_item):
+ calls = []
+ if match_item_code(parse_item, symbol.power):
+ ret = parse_power_fn(parse_item)
+ if ret != None:
+ fn_name,arglists = ret
+ if fn_name not in dir(__builtins__):
+ calls.append(ret)
+
+ for item in parse_item[1:]:
+ if isinstance(item, list):
+ calls.extend(find_fn_calls(item))
+ return calls
+
+def get_called_functions(fn, witharg=None):
+ """
+ Return a list of functions called by the function FN if they are
+ passed an argument WITHARG.
+ """
+ called_fns = []
+ source = inspect.getsource(fn)
+ suite = parser.suite(source)
+ parse_list = parser.st2list(suite)
+ #annotated_parse_list = annotate_parse_list(parse_list)
+ #pprint.pprint(annotated_parse_list)
+ calls = find_fn_calls(parse_list)
+ for fn_name,arglists in calls:
+ for i,arglist in enumerate(arglists):
+ args,kwargs,varargs,varkw = arglist
+ has_witharg = False
+ if witharg in kwargs:
+ has_witharg = True
+ elif varargs != None and witharg == "*"+varargs:
+ has_witharg = True
+ elif varkw != None and witharg == "**"+varkw:
+ has_witharg = True
+ if has_witharg:
+ assert i == 0, \
+ "Don't know function name (stacked arglists)\n%s %s\n\n%s"\
+ % (fn_name, arglists, source)
+ if fn_name not in called_fns:
+ called_fns.append(fn_name)
+ return called_fns
+
+def get_fn_by_name(fn_name):
+ if fn_name in globals():
+ fn_instance = globals()[fn_name]
+ else:
+ raise Exception, "No source for function %s" % fn_name
+ return fn_instance
+
+def get_fn_args(fn, recurse=True, checked_fns=None):
+ """
+ Return a list of argument names accepted by FN. If RECURSE ==
+ True and FN takes a **kwargs style argument, then attempt to drill
+ down into the children functions called inside FN to find what
+ argument names that **kwargs argument will accept.
+
+ The recursive drilling is quite kludgy, and currently only supports
+ child functions where **kwargs is used directly. E.g. if you manually
+ pop a value from kwargs, and pass the popped value on to a child, this
+ function will not notice.
+
+ Child functions are looked up by name among this module's globals
+ (see get_fn_by_name), which raises Exception for unknown names.
+ Raises NotImplementedError if FN takes a *args style argument.
+ """
+ if checked_fns == None:
+ checked_fns = [fn]
+ # Debugging trace of the function being inspected.
+ print "get args for %s (checked %s)" % (fn, checked_fns)
+ args,varargs,varkw,defaults = inspect.getargspec(fn)
+ if varargs != None:
+ raise NotImplementedError, "\n %s" % varargs
+ if varkw != None and recurse == True:
+ # This step is probably not worth the trouble
+ children = get_called_functions(fn, witharg='**'+varkw)
+ for f in children:
+ f_instance = get_fn_by_name(f)
+ # NOTE(review): checked_fns is handed down, but nothing here adds to
+ # it or tests membership, so it does not stop repeated visits.
+ # RECURSE is not forwarded either, so children use its default.
+ args.extend(get_fn_args(f_instance, checked_fns=checked_fns))
+ return args
+
+def splitargs(fn_list, kwargs):
+ # get list of allowed arguments
+ fn_args = []
+ for fn in fn_list:
+ fn_args.append(get_fn_args(fn))
+
+ # sort the kwargs according to the appropriate function
+ fn_kwargs = [{} for fn in fn_list]
+ for key,value in kwargs.items():
+ sorted = False
+ for i,fn,args in zip(range(len(fn_list)), fn_list, fn_args):
+ if key in args:
+ fn_kwargs[i][key] = value
+ sorted = True
+ break
+ if sorted != True:
+ raise Exception, "Unrecognized argument %s = %s" % (key, value)
+ return fn_kwargs
+
+
+def foo(x, y, z=2):
+ # Demo target: prints the arguments it received.
+ print "foo: x", x, "y", y, "z", z
+
+def bar(a, b, c=2):
+ # Demo target: prints the arguments it received.
+ print "bar: a", a, "b", b, "c", c
+
+def baz(d, **kwargs):
+ # Demo target with a second level: everything except d is handed
+ # on to bar(), with c fixed at 4.
+ bar(c=4, **kwargs)
+ print "baz: d", d
+
+def simple_joint(**kwargs):
+ print "joining"
+ print '\n'.join([" %s : %s" % (k,v) for k,v in kwargs.items()])
+ fookw,barkw = splitargs([foo,bar], kwargs)
+ foo(**fookw)
+ bar(**barkw)
+ print ""
+
+def deeper_joint(**kwargs):
+ print "joining"
+ print '\n'.join([" %s : %s" % (k,v) for k,v in kwargs.items()])
+ fookw,bazkw = splitargs([foo,baz], kwargs)
+ foo(**fookw)
+ baz(**bazkw)
+ print ""
+
+
+# A single level is easily handled, since the list of arguments can be
+# read off the output of inspect.getargspec().
+simple_joint(x=1, y=2, z=3, a=4, b=5)
+# 'unknown' is not a parameter of foo or bar, so splitargs() raises;
+# the message is printed and the demo carries on.
+try:
+ simple_joint(x=1, y=2, z=3, a=4, b=5, unknown=5)
+except Exception, e:
+ print e
+
+# This is much more complicated. With multiple levels of **kwargs to
+# drill down, we need to parse the definition of each function to see
+# what arguments the first **kwargs can take.
+deeper_joint(x=1, y=2, z=3, a=4, b=5, d=6)
--- /dev/null
+#!/usr/bin/python
+"""
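+splittable_kwargs allows the splitting of **kwargs arguments among
+several functions. Each function declares the child functions that
+its **kwargs are forwarded to, so the keyword names it accepts can
+be worked out by following those declarations.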
+
+Copyright (C) W. Trevor King 2008
+This code is released to the public domain.
+
+Example usage (adapted from the unittests)
+
+ @splittableKwargsFunction()
+ def foo(x, y, z=2):
+ return "foo: x "+str(x)+", y "+str(y)+", z "+str(z)+"\n"
+
+ @splittableKwargsFunction()
+ def bar(a, b, c=2):
+ return "bar: a "+str(a)+", b "+str(b)+", c "+str(c)+"\n"
+
+ @splittableKwargsFunction((bar, 'a'))
+ def baz(d, **kwargs):
+ string = bar(a=6, c=4, **kwargs)
+ return string + "baz: d "+str(d)+"\n"
+
+ @splittableKwargsFunction(foo, bar)
+ def simple_joint(**kwargs):
+ fookw,barkw = simple_joint._splitargs(simple_joint, kwargs)
+ string = foo(**fookw)
+ string += bar(**barkw)
+ return string
+
+ simple_joint(x=1, y=3, a=4, b=2)
+
+If, say, simple_joint's children had not been defined (you wanted to
+define bar after simple_joint in your module), you can skip the
+decorator in simple_joint, and make it splittable later (after you
+define bar) with
+
+ make_splittable_kwargs_function(simple_joint, foo, bar)
+"""
+
+import inspect
+import unittest
+
+class UnknownKwarg (KeyError):
+ def __init__(self, fn, kwarg, value):
+ if hasattr(fn, "_kwargs"):
+ fn_list = [fn]
+ else:
+ fn_list = fn
+ msg = "Unknown kwarg %s = %s. Allowed:\n" \
+ % (kwarg, value)
+ for f in fn_list:
+ msg += " %s %s\n" % (f.__name__, f._kwargs(f))
+ KeyError.__init__(self, msg)
+
+def _parse_splittable(splittable):
+ """
+ splittable -> (splittable_fn, masked_args)
+ """
+ if hasattr(splittable, "_kwargs"): # bare splittableKwargsFunction
+ return (splittable, [])
+ else: # function followed by masked args
+ return (splittable[0], splittable[1:])
+
+def splitargs(kwargs, *internal_splittables):
+ """
+ where
+ *internal_splittables : a list of splittableKwargsFunctions items
+ that this function uses internally.
+ the items can be either
+ a bare splittableKwargsFunction
+ or a tuple where the additional elements are arguments to mask
+ a (bare splittableKwargsFunction, masked argument, ...)
+ """
+ # sort the kwargs according to the appropriate function
+ fn_kwargs = [{} for splittable in internal_splittables]
+ for key,value in kwargs.items():
+ sorted = False
+ for i,splittable in enumerate(internal_splittables):
+ fn,masked = _parse_splittable(splittable)
+ if key in fn._kwargs(fn) and key not in masked:
+ fn_kwargs[i][key] = value
+ sorted = True
+ break
+ if sorted != True:
+ fns = []
+ for splittable in internal_splittables:
+ fn,masked = _parse_splittable(splittable)
+ fns.append(fn)
+ raise UnknownKwarg(fns, key, value)
+ return fn_kwargs
+
+def _kwargs(self):
+ "return a list of acceptable kwargs"
+ args,varargs,varkw,defaults = inspect.getargspec(self)
+ if varargs != None:
+ raise NotImplementedError, "\n %s" % varargs
+ if varkw != None:
+ for child in self._childSplittables:
+ child,masked = _parse_splittable(child)
+ child_args = child._kwargs(child)
+ for arg in masked:
+ child_args.remove(arg)
+ args.extend(child_args)
+ return args
+
+def _declareInternalSplittableKwargsFunction(self, function):
+ """
+ FUNCTION can be either a bare splittableKwargsFuntion or one
+ such function followed by a sequence of masked arguments.
+
+ Example values for FUNCTION:
+ bar
+ (bar, "a")
+ (bar, "a", "b")
+ """
+ self._childSplittables.append(function)
+
+def make_splittable_kwargs_function(function, *internal_splittables):
+ function._kwargs = _kwargs
+ function._childSplittables = []
+ function._declareInternalSplittableKwargsFunction = \
+ _declareInternalSplittableKwargsFunction
+ def _splitargs(self, **kwargs):
+ return splitargs(kwargs, *self._childSplittables)
+ function._splitargs = lambda self,kwargs : splitargs(kwargs, *self._childSplittables)
+
+ for splittable in internal_splittables:
+ function._declareInternalSplittableKwargsFunction(function, splittable)
+
+class splittableKwargsFunction (object):
+ def __init__(self, *internal_splittables):
+ self.internal_splittables = internal_splittables
+ def __call__(self, function):
+ make_splittable_kwargs_function(function, *self.internal_splittables)
+ return function
+
+@splittableKwargsFunction()
+def _foo(x, y, z=2):
+ "foo function"
+ return "foo: x "+str(x)+", y "+str(y)+", z "+str(z)+"\n"
+
+@splittableKwargsFunction()
+def _bar(a, b, c=2):
+ "bar function"
+ return "bar: a "+str(a)+", b "+str(b)+", c "+str(c)+"\n"
+
+@splittableKwargsFunction((_bar, 'a'))
+def _baz(d, **kwargs):
+ "baz function"
+ string = _bar(a=6, c=4, **kwargs)
+ return string + "baz: d "+str(d)+"\n"
+
+@splittableKwargsFunction(_foo, _bar)
+def _fun(**kwargs):
+ fookw,barkw = _fun._splitargs(_fun, kwargs)
+ return(fookw, barkw)
+
+# Unit tests for splitargs() and the splittableKwargsFunction
+# decorator, built on the module-level _foo, _bar, _baz and _fun
+# fixtures.
+class SplittableKwargsTestCase (unittest.TestCase):
+ # Decorated method: its kwargs are divided between _foo and _bar.
+ @splittableKwargsFunction(_foo, _bar)
+ def foobar(self, **kwargs):
+ fookw,barkw = self.foobar._splitargs(self.foobar, kwargs)
+ string = _foo(**fookw)
+ string += _bar(**barkw)
+ return string
+
+ # Undecorated helper that calls splitargs() directly and supplies
+ # a=4 to _bar itself. The keys are sorted so the listing is
+ # reproducible.
+ @staticmethod
+ def simple_joint(**kwargs):
+ string = "joining\n"
+ keys = kwargs.keys()
+ keys.sort()
+ string += "\n".join([" %s : %s" % (k,kwargs[k]) for k in keys])
+ string += "\n"
+ fookw,barkw = splitargs(kwargs, _foo, _bar)
+ string += _foo(**fookw)
+ string += _bar(a=4, **barkw)
+ return string
+
+ # Two-level helper: splits between _foo and _baz, and _baz hands
+ # its remaining kwargs on to _bar.
+ @staticmethod
+ def deeper_joint(**kwargs):
+ string = "joining\n"
+ keys = kwargs.keys()
+ keys.sort()
+ string += "\n".join([" %s : %s" % (k,kwargs[k]) for k in keys])
+ string += "\n"
+ fookw,bazkw = splitargs(kwargs, _foo, _baz)
+ string += _foo(**fookw)
+ string += _baz(**bazkw)
+ return string
+
+ # The decorator must hand back the original function, docstring
+ # included.
+ def testDocPreserved(self):
+ self.failUnless(_foo.__doc__ == "foo function", _foo.__doc__)
+ self.failUnless(_bar.__doc__ == "bar function", _bar.__doc__)
+ self.failUnless(_baz.__doc__ == "baz function", _baz.__doc__)
+ def testSingleLevelMethod(self):
+ expected = """foo: x 1, y 2, z 3
+bar: a 8, b 5, c 2
+"""
+ output = self.foobar(x=1, y=2, z=3, a=8, b=5)
+ self.failUnless(output == expected,
+ "GOT\n%s\nEXPECTED\n%s" % (output, expected))
+ def testSingleLevelFunction(self):
+ simple_joint = SplittableKwargsTestCase.simple_joint
+ expected = """joining
+ b : 5
+ x : 1
+ y : 2
+ z : 3
+foo: x 1, y 2, z 3
+bar: a 4, b 5, c 2
+"""
+ output = simple_joint(x=1, y=2, z=3, b=5)
+ self.failUnless(output == expected,
+ "GOT\n%s\nEXPECTED\n%s" % (output, expected))
+ # 'unknown' is accepted by neither _foo nor _bar.
+ def testSingleLevelUnknownKwarg(self):
+ simple_joint = SplittableKwargsTestCase.simple_joint
+ self.assertRaises(UnknownKwarg, simple_joint,
+ y=2, z=3, b=5, unknown=6)
+ def testDoubleLevel(self):
+ deeper_joint = SplittableKwargsTestCase.deeper_joint
+ expected = """joining
+ b : 5
+ d : 6
+ x : 1
+ y : 2
+ z : 3
+foo: x 1, y 2, z 3
+bar: a 6, b 5, c 4
+baz: d 6
+"""
+ output = deeper_joint(x=1, y=2, z=3, b=5, d=6)
+ self.failUnless(output == expected,
+ "GOT\n%s\nEXPECTED\n%s" % (output, expected))
+ # 'a' is masked in _baz's declaration, so passing it is rejected.
+ def testDoubleLevelOverrideRequired(self):
+ deeper_joint = SplittableKwargsTestCase.deeper_joint
+ self.assertRaises(UnknownKwarg, deeper_joint,
+ x=8, y=2, z=3, b=5, a=1)
+ def testRecursiveSplitargsReference(self):
+ # Access the ._splitargs method using the defined function name
+ expected = ({'y':3, 'z':4}, {'a':1, 'b':2})
+ output = _fun(a=1,b=2,y=3,z=4)
+ self.failUnless(output == expected,
+ "GOT\n%s\nEXPECTED\n%s" % (str(output), str(expected)))
+
+if __name__ == "__main__":
+ import sys
+
+ unitsuite = unittest.TestLoader().loadTestsFromTestCase( \
+ SplittableKwargsTestCase)
+ result = unittest.TextTestRunner(verbosity=2).run(unitsuite)
+ numErrors = len(result.errors)
+ numFailures = len(result.failures)
+ numBad = numErrors + numFailures
+ if numBad > 126:
+ numBad = 1
+ sys.exit(numBad)