From: W. Trevor King Date: Sun, 21 Dec 2008 03:51:15 +0000 (-0500) Subject: Began versioning X-Git-Tag: 0.2~2 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=a38d3c473dae1b4b40f17bd68b464162b22023d4;p=splittable_kwargs.git Began versioning --- a38d3c473dae1b4b40f17bd68b464162b22023d4 diff --git a/README b/README new file mode 100644 index 0000000..fc11b35 --- /dev/null +++ b/README @@ -0,0 +1,42 @@ +Problem: how to split **kwargs among several functions? +Which args should go to which functions? +For example: + + def foo(x, y, z=2): + print "foo: x", x, "y", y, "z", z + def bar(a, b, c=2): + print "bar: a", a, "b", b, "c", c + def baz(d, **kwargs): + bar(c=4, **kwargs) + print "baz: d", d + def simple_joint(**kwargs): + print "joining" + print '\n'.join([" %s : %s" % (k,v) for k,v in kwargs.items()]) + fookw,barkw = splitargs([foo,bar], kwargs) + foo(**fookw) + bar(**barkw) + print "" + def deeper_joint(**kwargs): + print "joining" + print '\n'.join([" %s : %s" % (k,v) for k,v in kwargs.items()]) + fookw,bazkw = splitargs([foo,baz], kwargs) + foo(**fookw) + baz(**bazkw) + print "" + + # A single level is easily handled, since the list of arguments can be + # read off the output of inspect.getargspec(). + simple_joint(x=1, y=2, z=3, a=4, b=5) + try: + simple_joint(x=1, y=2, z=3, a=4, b=5, unknown=5) + except Exception, e: + print e + + # This is much more complicated. With multiple levels of **kwargs to + # drill down, we need to parse the definition of each function to see + # what arguments the first **kwargs can take. + deeper_joint(x=1, y=2, z=3, a=4, b=5, d=6) + +A first attempt at a solution (which works fairly well, but is a bit +of a brain-melter) is given in parse.py. A much simpler and more +robust solution is given in splittable_kwargs.py. diff --git a/parse.py b/parse.py new file mode 100755 index 0000000..414b50c --- /dev/null +++ b/parse.py @@ -0,0 +1,317 @@ +#!/usr/bin/python +# +# Problem: how to split **kwargs among several functions? +# Which args should go to which functions? +# +# The parsing code is inspired by and developed from Prashanth Ellina's +# http://blog.prashanthellina.com/2007/11/14/generating-call-graphs-for-understanding-and-refactoring-python-code/ +# +# See +# http://docs.python.org/reference/grammar.html +# for the Python grammar + +import parser +import symbol +import token +import inspect + +import pprint +import copy + +def get_parse_item_code(parse_item): + if isinstance(parse_item, list): + item_code = parse_item[0] + else: + item_code = parse_item + return item_code + +def code_to_string(parse_code): + if parse_code in symbol.sym_name: + code_string = symbol.sym_name[parse_code] + else: + code_string = token.tok_name[parse_code] + return code_string + +def annotate_parse_list(parse_list, toplevel=True): + if toplevel == True: + parse_list = copy.deepcopy(parse_list) + parse_list[0] = code_to_string(parse_list[0]) + + for index, item in enumerate(parse_list): + if index == 0: continue + if isinstance(item, list): + parse_list[index] = annotate_parse_list(item, toplevel=False) + return parse_list + +def match_item_code(parse_item, expected_code): + parse_code = get_parse_item_code(parse_item) + return parse_code == expected_code + +def assert_item_code(parse_item, expected_code): + if not match_item_code(parse_item, expected_code): + raise Exception, "Unexpected code %s != %s in\n%s" \ + % (code_to_string(parse_code), + code_to_string(expected_code), + pprint.pformat(annotate_parse_list(parse_item))) + +def drill_for_item_code(parse_item, target_code): + """Depth first search for the target_code""" + parse_code = get_parse_item_code(parse_item) + if parse_code == target_code: + return parse_item + if isinstance(parse_item, list): + for child_item in parse_item[1:]: + ret = drill_for_item_code(child_item, target_code) + if ret != None: + return ret + return None + +def parse_atom(atom): + """ + atom: ('(' [yield_expr|testlist_gexp] ')' | + '[' [listmaker] ']' | + '{' [dictmaker] '}' | + '`' testlist1 '`' | + NAME | NUMBER | STRING+) + """ + assert_item_code(atom, symbol.atom) + first_child = atom[1] + first_child_code = first_child[0] + if first_child_code == token.NAME: + return first_child[1] + elif first_child_code == token.NUMBER: + return first_child[1] + elif first_child_code == token.STRING: + return first_child[1] + return None + +def find_test_name(parse_item): + """ + Drill down to the atom naming this test item + """ + assert_item_code(parse_item, symbol.test) + atom_item = drill_for_item_code(parse_item, symbol.atom) + if atom_item == None: + return None + return parse_atom(atom_item) + +def parse_argument(parse_item): + "argument: test [gen_for] | test '=' test # Really [keyword '='] test" + assert_item_code(parse_item, symbol.argument) + if len(parse_item) == 4: # test '=' test + arg_name = find_test_name(parse_item[1]) + assert_item_code(parse_item[2], token.EQUAL) + arg_value = find_test_name(parse_item[3]) + elif len(parse_item) == 2: # test + arg_name = None + arg_value = find_test_name(parse_item[1]) + else: # test gen_for + raise NotImplementedError, \ + '"test gen_for" argument\n%s' \ + % pprint.pformat(annotate_parse_list(parse_item)) + return (arg_name, arg_value) + +def parse_arglist(parse_item): + """ + arglist: (argument ',')* (argument [','] + |'*' test (',' argument)* [',' '**' test] + |'**' test) + """ + assert_item_code(parse_item, symbol.arglist) + args = [] + kwargs = {} + varargs = None + varkw = None + i = 1 + while i < len(parse_item): + item = parse_item[i] + argument_code = get_parse_item_code(item) + if argument_code == token.COMMA: + pass + elif argument_code == symbol.argument: + arg_name,arg_value = parse_argument(item) + if arg_name == None: + args.append(arg_value) + else: + kwargs[arg_name] = arg_value + elif argument_code == token.DOUBLESTAR: + i += 1 + item = parse_item[i] + varkw = find_test_name(item) + else: + raise NotImplementedError, \ + '"Unknown arglist item %s in\n%s' \ + % (code_to_string(argument_code), + pprint.pformat(annotate_parse_list(parse_item))) + i += 1 + return (args, kwargs, varargs, varkw) + +def parse_trailer(parse_item): + """ + trailer: '(' [arglist] ')' | '[' subscriptlist ']' | '.' NAME + """ + assert_item_code(parse_item, symbol.trailer) + if match_item_code(parse_item[1], token.LPAR): + assert_item_code(parse_item[-1], token.RPAR) + if len(parse_item) == 4: # '(' arglist ')' + return parse_arglist(parse_item[2]) + else: # '(' ')' + assert len(parse_item) == 3, \ + pprint.pformat(annotate_parse_list(parse_item)) + return None + # Ignore other trailers (they can't contain arglists) + return None + +def parse_power_fn(parse_item): + """ + power: atom trailer* ['**' factor] + """ + assert_item_code(parse_item, symbol.power) + if len(parse_item) < 3: # "atom" + return None + fn_name = parse_atom(parse_item[1]) + arglists = [] + for item in parse_item[2:]: + if match_item_code(item, symbol.trailer): + ret = parse_trailer(item) + if ret != None: + arglists.append(ret) + # ignore the possible "'**' factor" + return (fn_name, arglists) + +def find_fn_calls(parse_item): + calls = [] + if match_item_code(parse_item, symbol.power): + ret = parse_power_fn(parse_item) + if ret != None: + fn_name,arglists = ret + if fn_name not in dir(__builtins__): + calls.append(ret) + + for item in parse_item[1:]: + if isinstance(item, list): + calls.extend(find_fn_calls(item)) + return calls + +def get_called_functions(fn, witharg=None): + """ + Return a list of functions called by the function FN if they are + passed an argument WITHARG. + """ + called_fns = [] + source = inspect.getsource(fn) + suite = parser.suite(source) + parse_list = parser.st2list(suite) + #annotated_parse_list = annotate_parse_list(parse_list) + #pprint.pprint(annotated_parse_list) + calls = find_fn_calls(parse_list) + for fn_name,arglists in calls: + for i,arglist in enumerate(arglists): + args,kwargs,varargs,varkw = arglist + has_witharg = False + if witharg in kwargs: + has_witharg = True + elif varargs != None and witharg == "*"+varargs: + has_witharg = True + elif varkw != None and witharg == "**"+varkw: + has_witharg = True + if has_witharg: + assert i == 0, \ + "Don't know function name (stacked arglists)\n%s %s\n\n%s"\ + % (fn_name, arglists, source) + if fn_name not in called_fns: + called_fns.append(fn_name) + return called_fns + +def get_fn_by_name(fn_name): + if fn_name in globals(): + fn_instance = globals()[fn_name] + else: + raise Exception, "No source for function %s" % fn_name + return fn_instance + +def get_fn_args(fn, recurse=True, checked_fns=None): + """ + Return a list of argument names accepted by FN. If RECURSE == + True and FN takes a **kwargs style argument, then attempt to drill + down into the children functions called inside FN to find what + argument names that **kwargs argument will accept. + + The recursive drilling is quite cludgy, and currently only supports + child functions where **kwargs is used directly. E.g. if you manually + pop a value from kwargs, and pass the popped value on to a child, this + function will not notice. + """ + if checked_fns == None: + checked_fns = [fn] + print "get args for %s (checked %s)" % (fn, checked_fns) + args,varargs,varkw,defaults = inspect.getargspec(fn) + if varargs != None: + raise NotImplementedError, "\n %s" % varargs + if varkw != None and recurse == True: + # This step is probably not worth the trouble + children = get_called_functions(fn, witharg='**'+varkw) + for f in children: + f_instance = get_fn_by_name(f) + args.extend(get_fn_args(f_instance, checked_fns=checked_fns)) + return args + +def splitargs(fn_list, kwargs): + # get list of allowed arguments + fn_args = [] + for fn in fn_list: + fn_args.append(get_fn_args(fn)) + + # sort the kwargs according to the appropriate function + fn_kwargs = [{} for fn in fn_list] + for key,value in kwargs.items(): + sorted = False + for i,fn,args in zip(range(len(fn_list)), fn_list, fn_args): + if key in args: + fn_kwargs[i][key] = value + sorted = True + break + if sorted != True: + raise Exception, "Unrecognized argument %s = %s" % (key, value) + return fn_kwargs + + +def foo(x, y, z=2): + print "foo: x", x, "y", y, "z", z + +def bar(a, b, c=2): + print "bar: a", a, "b", b, "c", c + +def baz(d, **kwargs): + bar(c=4, **kwargs) + print "baz: d", d + +def simple_joint(**kwargs): + print "joining" + print '\n'.join([" %s : %s" % (k,v) for k,v in kwargs.items()]) + fookw,barkw = splitargs([foo,bar], kwargs) + foo(**fookw) + bar(**barkw) + print "" + +def deeper_joint(**kwargs): + print "joining" + print '\n'.join([" %s : %s" % (k,v) for k,v in kwargs.items()]) + fookw,bazkw = splitargs([foo,baz], kwargs) + foo(**fookw) + baz(**bazkw) + print "" + + +# A single level is easily handled, since the list of arguments can be +# read off the output of inspect.getargspec(). +simple_joint(x=1, y=2, z=3, a=4, b=5) +try: + simple_joint(x=1, y=2, z=3, a=4, b=5, unknown=5) +except Exception, e: + print e + +# This is much more complicated. With multiple levels of **kwargs to +# drill down, we need to parse the definition of each function to see +# what arguments the first **kwargs can take. +deeper_joint(x=1, y=2, z=3, a=4, b=5, d=6) diff --git a/splittable_kwargs.py b/splittable_kwargs.py new file mode 100755 index 0000000..732ad37 --- /dev/null +++ b/splittable_kwargs.py @@ -0,0 +1,256 @@ +#!/usr/bin/python +""" +splittable_kwargs allows the splitting of **kwargs arguments among +several functions. This + +Copyright (C) W. Trevor King 2008 +This code is released to the public domain. + +Example usage (adapted from the unittests) + + @splittableKwargsFunction() + def foo(x, y, z=2): + return "foo: x "+str(x)+", y "+str(y)+", z "+str(z)+"\n" + + @splittableKwargsFunction() + def bar(a, b, c=2): + return "bar: a "+str(a)+", b "+str(b)+", c "+str(c)+"\n" + + @splittableKwargsFunction((bar, 'a')) + def baz(d, **kwargs): + string = bar(c=4, **kwargs) + return string + "baz: d "+str(d)+"\n" + + @splittableKwargsFunction(foo, bar) + def simple_joint(**kwargs): + fookw,bazkw = simple_joint._splitargs(simple_joint, kwargs) + string = foo(**fookw) + string += baz(**bazkw) + return string + + simple_joint(y=3,c=1,d=5) + +If, say, simple_joint's children had not been defined (you wanted to +define bar after simple_joint in your module), you can skip the +decorator in simple_joint, and make it splittable later (after you +define bar) with + + make_splittable_kwargs_function(simple_joint, foo, bar) +""" + +import inspect +import unittest + +class UnknownKwarg (KeyError): + def __init__(self, fn, kwarg, value): + if hasattr(fn, "_kwargs"): + fn_list = [fn] + else: + fn_list = fn + msg = "Unknown kwarg %s = %s. Allowed:\n" \ + % (kwarg, value) + for f in fn_list: + msg += " %s %s\n" % (f.__name__, f._kwargs(f)) + KeyError.__init__(self, msg) + +def _parse_splittable(splittable): + """ + splittable -> (splittable_fn, masked_args) + """ + if hasattr(splittable, "_kwargs"): # bare splittableKwargsFunction + return (splittable, []) + else: # function followed by masked args + return (splittable[0], splittable[1:]) + +def splitargs(kwargs, *internal_splittables): + """ + where + *internal_splittables : a list of splittableKwargsFunctions items + that this function uses internally. + the items can be either + a bare splittableKwargsFunction + or a tuple where the additional elements are arguments to mask + a (bare splittableKwargsFunction, masked argument, ...) + """ + # sort the kwargs according to the appropriate function + fn_kwargs = [{} for splittable in internal_splittables] + for key,value in kwargs.items(): + sorted = False + for i,splittable in enumerate(internal_splittables): + fn,masked = _parse_splittable(splittable) + if key in fn._kwargs(fn) and key not in masked: + fn_kwargs[i][key] = value + sorted = True + break + if sorted != True: + fns = [] + for splittable in internal_splittables: + fn,masked = _parse_splittable(splittable) + fns.append(fn) + raise UnknownKwarg(fns, key, value) + return fn_kwargs + +def _kwargs(self): + "return a list of acceptable kwargs" + args,varargs,varkw,defaults = inspect.getargspec(self) + if varargs != None: + raise NotImplementedError, "\n %s" % varargs + if varkw != None: + for child in self._childSplittables: + child,masked = _parse_splittable(child) + child_args = child._kwargs(child) + for arg in masked: + child_args.remove(arg) + args.extend(child_args) + return args + +def _declareInternalSplittableKwargsFunction(self, function): + """ + FUNCTION can be either a bare splittableKwargsFuntion or one + such function followed by a sequence of masked arguments. + + Example values for FUNCTION: + bar + (bar, "a") + (bar, "a", "b") + """ + self._childSplittables.append(function) + +def make_splittable_kwargs_function(function, *internal_splittables): + function._kwargs = _kwargs + function._childSplittables = [] + function._declareInternalSplittableKwargsFunction = \ + _declareInternalSplittableKwargsFunction + def _splitargs(self, **kwargs): + return splitargs(kwargs, *self._childSplittables) + function._splitargs = lambda self,kwargs : splitargs(kwargs, *self._childSplittables) + + for splittable in internal_splittables: + function._declareInternalSplittableKwargsFunction(function, splittable) + +class splittableKwargsFunction (object): + def __init__(self, *internal_splittables): + self.internal_splittables = internal_splittables + def __call__(self, function): + make_splittable_kwargs_function(function, *self.internal_splittables) + return function + +@splittableKwargsFunction() +def _foo(x, y, z=2): + "foo function" + return "foo: x "+str(x)+", y "+str(y)+", z "+str(z)+"\n" + +@splittableKwargsFunction() +def _bar(a, b, c=2): + "bar function" + return "bar: a "+str(a)+", b "+str(b)+", c "+str(c)+"\n" + +@splittableKwargsFunction((_bar, 'a')) +def _baz(d, **kwargs): + "baz function" + string = _bar(a=6, c=4, **kwargs) + return string + "baz: d "+str(d)+"\n" + +@splittableKwargsFunction(_foo, _bar) +def _fun(**kwargs): + fookw,barkw = _fun._splitargs(_fun, kwargs) + return(fookw, barkw) + +class SplittableKwargsTestCase (unittest.TestCase): + @splittableKwargsFunction(_foo, _bar) + def foobar(self, **kwargs): + fookw,barkw = self.foobar._splitargs(self.foobar, kwargs) + string = _foo(**fookw) + string += _bar(**barkw) + return string + + @staticmethod + def simple_joint(**kwargs): + string = "joining\n" + keys = kwargs.keys() + keys.sort() + string += "\n".join([" %s : %s" % (k,kwargs[k]) for k in keys]) + string += "\n" + fookw,barkw = splitargs(kwargs, _foo, _bar) + string += _foo(**fookw) + string += _bar(a=4, **barkw) + return string + + @staticmethod + def deeper_joint(**kwargs): + string = "joining\n" + keys = kwargs.keys() + keys.sort() + string += "\n".join([" %s : %s" % (k,kwargs[k]) for k in keys]) + string += "\n" + fookw,bazkw = splitargs(kwargs, _foo, _baz) + string += _foo(**fookw) + string += _baz(**bazkw) + return string + + def testDocPreserved(self): + self.failUnless(_foo.__doc__ == "foo function", _foo.__doc__) + self.failUnless(_bar.__doc__ == "bar function", _bar.__doc__) + self.failUnless(_baz.__doc__ == "baz function", _baz.__doc__) + def testSingleLevelMethod(self): + expected = """foo: x 1, y 2, z 3 +bar: a 8, b 5, c 2 +""" + output = self.foobar(x=1, y=2, z=3, a=8, b=5) + self.failUnless(output == expected, + "GOT\n%s\nEXPECTED\n%s" % (output, expected)) + def testSingleLevelFunction(self): + simple_joint = SplittableKwargsTestCase.simple_joint + expected = """joining + b : 5 + x : 1 + y : 2 + z : 3 +foo: x 1, y 2, z 3 +bar: a 4, b 5, c 2 +""" + output = simple_joint(x=1, y=2, z=3, b=5) + self.failUnless(output == expected, + "GOT\n%s\nEXPECTED\n%s" % (output, expected)) + def testSingleLevelUnknownKwarg(self): + simple_joint = SplittableKwargsTestCase.simple_joint + self.assertRaises(UnknownKwarg, simple_joint, + y=2, z=3, b=5, unknown=6) + def testDoubleLevel(self): + deeper_joint = SplittableKwargsTestCase.deeper_joint + expected = """joining + b : 5 + d : 6 + x : 1 + y : 2 + z : 3 +foo: x 1, y 2, z 3 +bar: a 6, b 5, c 4 +baz: d 6 +""" + output = deeper_joint(x=1, y=2, z=3, b=5, d=6) + self.failUnless(output == expected, + "GOT\n%s\nEXPECTED\n%s" % (output, expected)) + def testDoubleLevelOverrideRequired(self): + deeper_joint = SplittableKwargsTestCase.deeper_joint + self.assertRaises(UnknownKwarg, deeper_joint, + x=8, y=2, z=3, b=5, a=1) + def testRecursiveSplitargsReference(self): + # Access the ._splitargs method using the defined function name + expected = ({'y':3, 'z':4}, {'a':1, 'b':2}) + output = _fun(a=1,b=2,y=3,z=4) + self.failUnless(output == expected, + "GOT\n%s\nEXPECTED\n%s" % (str(output), str(expected))) + +if __name__ == "__main__": + import sys + + unitsuite = unittest.TestLoader().loadTestsFromTestCase( \ + SplittableKwargsTestCase) + result = unittest.TextTestRunner(verbosity=2).run(unitsuite) + numErrors = len(result.errors) + numFailures = len(result.failures) + numBad = numErrors + numFailures + if numBad > 126: + numBad = 1 + sys.exit(numBad)