Began versioning
authorW. Trevor King <wking@drexel.edu>
Sun, 21 Dec 2008 03:51:15 +0000 (22:51 -0500)
committerW. Trevor King <wking@drexel.edu>
Sun, 21 Dec 2008 03:51:15 +0000 (22:51 -0500)
README [new file with mode: 0644]
parse.py [new file with mode: 0755]
splittable_kwargs.py [new file with mode: 0755]

diff --git a/README b/README
new file mode 100644 (file)
index 0000000..fc11b35
--- /dev/null
+++ b/README
@@ -0,0 +1,42 @@
+Problem: how to split **kwargs among several functions?
+Which args should go to which functions?
+For example:
+
+  def foo(x, y, z=2):
+      print "foo: x", x, "y", y, "z", z
+  def bar(a, b, c=2):
+      print "bar: a", a, "b", b, "c", c
+  def baz(d, **kwargs):
+      bar(c=4, **kwargs)
+      print "baz: d", d
+  def simple_joint(**kwargs):
+      print "joining"
+      print '\n'.join(["  %s : %s" % (k,v) for k,v in kwargs.items()])
+      fookw,barkw = splitargs([foo,bar], kwargs)
+      foo(**fookw)
+      bar(**barkw)
+      print ""
+  def deeper_joint(**kwargs):
+      print "joining"
+      print '\n'.join(["  %s : %s" % (k,v) for k,v in kwargs.items()])
+      fookw,bazkw = splitargs([foo,baz], kwargs)
+      foo(**fookw)
+      baz(**bazkw)
+      print ""
+  
+  # A single level is easily handled, since the list of arguments can be
+  # read off the output of inspect.getargspec().
+  simple_joint(x=1, y=2, z=3, a=4, b=5)
+  try:
+      simple_joint(x=1, y=2, z=3, a=4, b=5, unknown=5)
+  except Exception, e:
+      print e
+  
+  # This is much more complicated.  With multiple levels of **kwargs to
+  # drill down, we need to parse the definition of each function to see
+  # what arguments the first **kwargs can take.
+  deeper_joint(x=1, y=2, z=3, a=4, b=5, d=6)
+
+A first attempt at a solution (which works fairly well, but is a bit
+of a brain-melter) is given in parse.py.  A much simpler and more
+robust solution is given in splittable_kwargs.py.
diff --git a/parse.py b/parse.py
new file mode 100755 (executable)
index 0000000..414b50c
--- /dev/null
+++ b/parse.py
@@ -0,0 +1,317 @@
+#!/usr/bin/python
+#
+# Problem: how to split **kwargs among several functions?
+# Which args should go to which functions?
+#
+# The parsing code is inspired by and developed from Prashanth Ellina's
+#   http://blog.prashanthellina.com/2007/11/14/generating-call-graphs-for-understanding-and-refactoring-python-code/
+#
+# See
+#   http://docs.python.org/reference/grammar.html
+# for the Python grammar
+
+import parser
+import symbol
+import token
+import inspect
+
+import pprint
+import copy
+
+def get_parse_item_code(parse_item):
+    if isinstance(parse_item, list):
+        item_code = parse_item[0]
+    else:
+        item_code = parse_item
+    return item_code
+
+def code_to_string(parse_code):
+    if parse_code in symbol.sym_name:
+        code_string = symbol.sym_name[parse_code]
+    else:
+        code_string = token.tok_name[parse_code]
+    return code_string
+
+def annotate_parse_list(parse_list, toplevel=True):
+    if toplevel == True:
+        parse_list = copy.deepcopy(parse_list)
+    parse_list[0] = code_to_string(parse_list[0])
+
+    for index, item in enumerate(parse_list):
+        if index == 0: continue
+        if isinstance(item, list):
+            parse_list[index] = annotate_parse_list(item, toplevel=False)
+    return parse_list
+
+def match_item_code(parse_item, expected_code):
+    parse_code = get_parse_item_code(parse_item)
+    return parse_code == expected_code
+
+def assert_item_code(parse_item, expected_code):
+    if not match_item_code(parse_item, expected_code):
+        raise Exception, "Unexpected code %s != %s in\n%s" \
+            % (code_to_string(parse_code),
+               code_to_string(expected_code),
+               pprint.pformat(annotate_parse_list(parse_item)))
+
+def drill_for_item_code(parse_item, target_code):
+    """Depth first search for the target_code"""
+    parse_code = get_parse_item_code(parse_item)
+    if parse_code == target_code:
+        return parse_item
+    if isinstance(parse_item, list):
+        for child_item in parse_item[1:]:
+            ret = drill_for_item_code(child_item, target_code)
+            if ret != None:
+                return ret
+    return None
+
+def parse_atom(atom):
+    """
+    atom: ('(' [yield_expr|testlist_gexp] ')' |
+       '[' [listmaker] ']' |
+       '{' [dictmaker] '}' |
+       '`' testlist1 '`' |
+       NAME | NUMBER | STRING+)
+    """
+    assert_item_code(atom, symbol.atom)
+    first_child = atom[1]
+    first_child_code = first_child[0]
+    if first_child_code == token.NAME:
+        return first_child[1]
+    elif first_child_code == token.NUMBER:
+        return first_child[1]
+    elif first_child_code == token.STRING:
+        return first_child[1]
+    return None
+
+def find_test_name(parse_item):
+    """
+    Drill down to the atom naming this test item
+    """
+    assert_item_code(parse_item, symbol.test)
+    atom_item = drill_for_item_code(parse_item, symbol.atom)
+    if atom_item == None:
+        return None
+    return parse_atom(atom_item)
+
+def parse_argument(parse_item):
+    "argument: test [gen_for] | test '=' test  # Really [keyword '='] test"
+    assert_item_code(parse_item, symbol.argument)
+    if len(parse_item) == 4: # test '=' test
+        arg_name = find_test_name(parse_item[1])
+        assert_item_code(parse_item[2], token.EQUAL)
+        arg_value = find_test_name(parse_item[3])
+    elif len(parse_item) == 2: # test
+        arg_name = None
+        arg_value = find_test_name(parse_item[1])
+    else: # test gen_for
+        raise NotImplementedError, \
+            '"test gen_for" argument\n%s' \
+            % pprint.pformat(annotate_parse_list(parse_item))
+    return (arg_name, arg_value)
+
+def parse_arglist(parse_item):
+    """
+    arglist: (argument ',')* (argument [',']
+                         |'*' test (',' argument)* [',' '**' test] 
+                         |'**' test)
+    """
+    assert_item_code(parse_item, symbol.arglist)
+    args = []
+    kwargs = {}
+    varargs = None
+    varkw = None
+    i = 1
+    while i < len(parse_item):
+        item = parse_item[i]
+        argument_code = get_parse_item_code(item)
+        if argument_code == token.COMMA:
+            pass
+        elif argument_code == symbol.argument:
+            arg_name,arg_value = parse_argument(item)
+            if arg_name == None:
+                args.append(arg_value)
+            else:
+                kwargs[arg_name] = arg_value
+        elif argument_code == token.DOUBLESTAR:
+            i += 1
+            item = parse_item[i]
+            varkw = find_test_name(item)
+        else:
+            raise NotImplementedError, \
+                '"Unknown arglist item %s in\n%s' \
+                % (code_to_string(argument_code),
+                   pprint.pformat(annotate_parse_list(parse_item)))
+        i += 1
+    return (args, kwargs, varargs, varkw)
+
+def parse_trailer(parse_item):
+    """
+    trailer: '(' [arglist] ')' | '[' subscriptlist ']' | '.' NAME
+    """
+    assert_item_code(parse_item, symbol.trailer)
+    if match_item_code(parse_item[1], token.LPAR):
+        assert_item_code(parse_item[-1], token.RPAR)
+        if len(parse_item) == 4: # '(' arglist ')'
+            return parse_arglist(parse_item[2])
+        else: # '(' ')'
+            assert len(parse_item) == 3, \
+                pprint.pformat(annotate_parse_list(parse_item))
+            return None
+    # Ignore other trailers (they can't contain arglists)
+    return None
+
+def parse_power_fn(parse_item):
+    """
+    power: atom trailer* ['**' factor]
+    """
+    assert_item_code(parse_item, symbol.power)
+    if len(parse_item) < 3: # "atom"
+        return None
+    fn_name = parse_atom(parse_item[1])
+    arglists = []
+    for item in parse_item[2:]:
+        if match_item_code(item, symbol.trailer):
+            ret = parse_trailer(item)
+            if ret != None:
+                arglists.append(ret)
+        # ignore the possible "'**' factor"
+    return (fn_name, arglists)
+
+def find_fn_calls(parse_item):
+    calls = []
+    if match_item_code(parse_item, symbol.power):
+        ret = parse_power_fn(parse_item)
+        if ret != None:
+            fn_name,arglists = ret
+            if fn_name not in dir(__builtins__):
+                calls.append(ret)
+
+    for item in parse_item[1:]:
+        if isinstance(item, list):
+            calls.extend(find_fn_calls(item))
+    return calls
+
+def get_called_functions(fn, witharg=None):
+    """
+    Return a list of functions called by the function FN if they are
+    passed an argument WITHARG.
+    """
+    called_fns = []
+    source = inspect.getsource(fn)
+    suite = parser.suite(source)
+    parse_list = parser.st2list(suite)
+    #annotated_parse_list = annotate_parse_list(parse_list)
+    #pprint.pprint(annotated_parse_list)
+    calls = find_fn_calls(parse_list)
+    for fn_name,arglists in calls:
+        for i,arglist in enumerate(arglists):
+            args,kwargs,varargs,varkw = arglist
+            has_witharg = False
+            if witharg in kwargs:
+                has_witharg = True
+            elif varargs != None and witharg == "*"+varargs:
+                has_witharg = True
+            elif varkw != None and witharg == "**"+varkw:
+                has_witharg = True
+            if has_witharg:
+                assert i == 0, \
+                    "Don't know function name (stacked arglists)\n%s %s\n\n%s"\
+                    % (fn_name, arglists, source)
+                if fn_name not in called_fns:
+                    called_fns.append(fn_name)
+    return called_fns
+
+def get_fn_by_name(fn_name):
+    if fn_name in globals():
+        fn_instance = globals()[fn_name]
+    else:
+        raise Exception, "No source for function %s" % fn_name
+    return fn_instance
+
+def get_fn_args(fn, recurse=True, checked_fns=None):
+    """
+    Return a list of argument names accepted by FN.  If RECURSE ==
+    True and FN takes a **kwargs style argument, then attempt to drill
+    down into the children functions called inside FN to find what
+    argument names that **kwargs argument will accept.
+    
+    The recursive drilling is quite cludgy, and currently only supports
+    child functions where **kwargs is used directly.  E.g. if you manually
+    pop a value from kwargs, and pass the popped value on to a child, this
+    function will not notice.
+    """
+    if checked_fns == None:
+        checked_fns = [fn]
+    print "get args for %s (checked %s)" % (fn, checked_fns)
+    args,varargs,varkw,defaults = inspect.getargspec(fn)
+    if varargs != None:
+       raise NotImplementedError, "\n %s" % varargs
+    if varkw != None and recurse == True:
+        # This step is probably not worth the trouble
+        children = get_called_functions(fn, witharg='**'+varkw)
+        for f in children:
+            f_instance = get_fn_by_name(f)
+            args.extend(get_fn_args(f_instance, checked_fns=checked_fns))
+    return args
+
+def splitargs(fn_list, kwargs):
+    # get list of allowed arguments
+    fn_args = []
+    for fn in fn_list:
+        fn_args.append(get_fn_args(fn))
+
+    # sort the kwargs according to the appropriate function
+    fn_kwargs = [{} for fn in fn_list]
+    for key,value in kwargs.items():
+        sorted = False
+        for i,fn,args in zip(range(len(fn_list)), fn_list, fn_args):
+            if key in args:
+                fn_kwargs[i][key] = value
+                sorted = True
+                break
+        if sorted != True:
+            raise Exception, "Unrecognized argument %s = %s" % (key, value)
+    return fn_kwargs
+
+
+def foo(x, y, z=2):
+    print "foo: x", x, "y", y, "z", z
+
+def bar(a, b, c=2):
+    print "bar: a", a, "b", b, "c", c
+
+def baz(d, **kwargs):
+    bar(c=4, **kwargs)
+    print "baz: d", d
+
+def simple_joint(**kwargs):
+    print "joining"
+    print '\n'.join(["  %s : %s" % (k,v) for k,v in kwargs.items()])
+    fookw,barkw = splitargs([foo,bar], kwargs)
+    foo(**fookw)
+    bar(**barkw)
+    print ""
+
+def deeper_joint(**kwargs):
+    print "joining"
+    print '\n'.join(["  %s : %s" % (k,v) for k,v in kwargs.items()])
+    fookw,bazkw = splitargs([foo,baz], kwargs)
+    foo(**fookw)
+    baz(**bazkw)
+    print ""
+
+
+# A single level is easily handled, since the list of arguments can be
+# read off the output of inspect.getargspec().
+simple_joint(x=1, y=2, z=3, a=4, b=5)
+try:
+    simple_joint(x=1, y=2, z=3, a=4, b=5, unknown=5)
+except Exception, e:
+    print e
+
+# This is much more complicated.  With multiple levels of **kwargs to
+# drill down, we need to parse the definition of each function to see
+# what arguments the first **kwargs can take.
+deeper_joint(x=1, y=2, z=3, a=4, b=5, d=6)
diff --git a/splittable_kwargs.py b/splittable_kwargs.py
new file mode 100755 (executable)
index 0000000..732ad37
--- /dev/null
@@ -0,0 +1,256 @@
+#!/usr/bin/python
+"""
+splittable_kwargs allows the splitting of **kwargs arguments among
+several functions.  This
+
+Copyright (C)  W. Trevor King  2008
+This code is released to the public domain.
+
+Example usage (adapted from the unittests)
+
+  @splittableKwargsFunction()
+  def foo(x, y, z=2):
+      return "foo: x "+str(x)+", y "+str(y)+", z "+str(z)+"\n"
+  
+  @splittableKwargsFunction()
+  def bar(a, b, c=2):
+      return "bar: a "+str(a)+", b "+str(b)+", c "+str(c)+"\n"
+  
+  @splittableKwargsFunction((bar, 'a'))
+  def baz(d, **kwargs):
+      string = bar(c=4, **kwargs)
+      return string + "baz: d "+str(d)+"\n"
+  
+  @splittableKwargsFunction(foo, bar)
+  def simple_joint(**kwargs):
+      fookw,bazkw = simple_joint._splitargs(simple_joint, kwargs)
+      string = foo(**fookw)
+      string += baz(**bazkw)
+      return string
+  
+  simple_joint(y=3,c=1,d=5)
+
+If, say, simple_joint's children had not been defined (you wanted to
+define bar after simple_joint in your module), you can skip the
+decorator in simple_joint, and make it splittable later (after you
+define bar) with
+
+  make_splittable_kwargs_function(simple_joint, foo, bar)
+"""
+
+import inspect
+import unittest
+
+class UnknownKwarg (KeyError):
+    def __init__(self, fn, kwarg, value):
+        if hasattr(fn, "_kwargs"):
+            fn_list = [fn]
+        else:
+            fn_list = fn
+        msg = "Unknown kwarg %s = %s.  Allowed:\n" \
+            % (kwarg, value)
+        for f in fn_list:
+            msg += "  %s %s\n" % (f.__name__, f._kwargs(f))
+        KeyError.__init__(self, msg)
+
+def _parse_splittable(splittable):
+    """
+      splittable -> (splittable_fn, masked_args)
+    """
+    if hasattr(splittable, "_kwargs"): # bare splittableKwargsFunction
+        return (splittable, [])
+    else: # function followed by masked args
+        return (splittable[0], splittable[1:])
+
+def splitargs(kwargs, *internal_splittables):
+    """
+    where
+      *internal_splittables : a list of splittableKwargsFunctions items
+          that this function uses internally.
+      the items can be either
+          a bare splittableKwargsFunction
+      or a tuple where the additional elements are arguments to mask
+          a (bare splittableKwargsFunction, masked argument, ...)
+    """
+    # sort the kwargs according to the appropriate function
+    fn_kwargs = [{} for splittable in internal_splittables]
+    for key,value in kwargs.items():
+        sorted = False
+        for i,splittable in enumerate(internal_splittables):
+            fn,masked = _parse_splittable(splittable)
+            if key in fn._kwargs(fn) and key not in masked:
+                fn_kwargs[i][key] = value
+                sorted = True
+                break
+        if sorted != True:
+            fns = []
+            for splittable in internal_splittables:
+                fn,masked = _parse_splittable(splittable)
+                fns.append(fn)
+            raise UnknownKwarg(fns, key, value)
+    return fn_kwargs
+
+def _kwargs(self):
+    "return a list of acceptable kwargs"
+    args,varargs,varkw,defaults = inspect.getargspec(self)
+    if varargs != None:
+        raise NotImplementedError, "\n %s" % varargs
+    if varkw != None:
+        for child in self._childSplittables:
+            child,masked = _parse_splittable(child)
+            child_args = child._kwargs(child)
+            for arg in masked:
+                child_args.remove(arg)
+            args.extend(child_args)
+    return args
+
+def _declareInternalSplittableKwargsFunction(self, function):
+    """
+    FUNCTION can be either a bare splittableKwargsFuntion or one
+    such function followed by a sequence of masked arguments.
+    
+    Example values for FUNCTION:
+    bar
+    (bar, "a")
+    (bar, "a", "b")
+    """
+    self._childSplittables.append(function)
+
+def make_splittable_kwargs_function(function, *internal_splittables):
+    function._kwargs = _kwargs
+    function._childSplittables = []
+    function._declareInternalSplittableKwargsFunction = \
+             _declareInternalSplittableKwargsFunction
+    def _splitargs(self, **kwargs):
+        return splitargs(kwargs, *self._childSplittables)
+    function._splitargs = lambda self,kwargs : splitargs(kwargs, *self._childSplittables)
+    
+    for splittable in internal_splittables:
+        function._declareInternalSplittableKwargsFunction(function, splittable)
+
+class splittableKwargsFunction (object):
+    def __init__(self, *internal_splittables):
+        self.internal_splittables = internal_splittables
+    def __call__(self, function):
+        make_splittable_kwargs_function(function, *self.internal_splittables)
+        return function
+
+@splittableKwargsFunction()
+def _foo(x, y, z=2):
+    "foo function"
+    return "foo: x "+str(x)+", y "+str(y)+", z "+str(z)+"\n"
+
+@splittableKwargsFunction()
+def _bar(a, b, c=2):
+    "bar function"
+    return "bar: a "+str(a)+", b "+str(b)+", c "+str(c)+"\n"
+
+@splittableKwargsFunction((_bar, 'a'))
+def _baz(d, **kwargs):
+    "baz function"
+    string = _bar(a=6, c=4, **kwargs)
+    return string + "baz: d "+str(d)+"\n"
+
+@splittableKwargsFunction(_foo, _bar)
+def _fun(**kwargs):
+    fookw,barkw = _fun._splitargs(_fun, kwargs)
+    return(fookw, barkw)
+
+class SplittableKwargsTestCase (unittest.TestCase):
+    @splittableKwargsFunction(_foo, _bar)
+    def foobar(self, **kwargs):
+        fookw,barkw = self.foobar._splitargs(self.foobar, kwargs)
+        string = _foo(**fookw)
+        string += _bar(**barkw)
+        return string
+
+    @staticmethod
+    def simple_joint(**kwargs):
+        string = "joining\n"
+        keys = kwargs.keys()
+        keys.sort()
+        string += "\n".join(["  %s : %s" % (k,kwargs[k]) for k in keys])
+        string += "\n"
+        fookw,barkw = splitargs(kwargs, _foo, _bar)
+        string += _foo(**fookw)
+        string += _bar(a=4, **barkw)
+        return string
+
+    @staticmethod
+    def deeper_joint(**kwargs):
+        string = "joining\n"
+        keys = kwargs.keys()
+        keys.sort()
+        string += "\n".join(["  %s : %s" % (k,kwargs[k]) for k in keys])
+        string += "\n"
+        fookw,bazkw = splitargs(kwargs, _foo, _baz)
+        string += _foo(**fookw)
+        string += _baz(**bazkw)
+        return string
+
+    def testDocPreserved(self):
+        self.failUnless(_foo.__doc__ == "foo function", _foo.__doc__)
+        self.failUnless(_bar.__doc__ == "bar function", _bar.__doc__)
+        self.failUnless(_baz.__doc__ == "baz function", _baz.__doc__)
+    def testSingleLevelMethod(self):
+        expected = """foo: x 1, y 2, z 3
+bar: a 8, b 5, c 2
+"""
+        output = self.foobar(x=1, y=2, z=3, a=8, b=5)
+        self.failUnless(output == expected,
+                        "GOT\n%s\nEXPECTED\n%s" % (output, expected))
+    def testSingleLevelFunction(self):
+        simple_joint = SplittableKwargsTestCase.simple_joint
+        expected = """joining
+  b : 5
+  x : 1
+  y : 2
+  z : 3
+foo: x 1, y 2, z 3
+bar: a 4, b 5, c 2
+"""
+        output = simple_joint(x=1, y=2, z=3, b=5)
+        self.failUnless(output == expected,
+                        "GOT\n%s\nEXPECTED\n%s" % (output, expected))
+    def testSingleLevelUnknownKwarg(self):
+        simple_joint = SplittableKwargsTestCase.simple_joint
+        self.assertRaises(UnknownKwarg, simple_joint,
+                          y=2, z=3, b=5, unknown=6)
+    def testDoubleLevel(self):
+        deeper_joint = SplittableKwargsTestCase.deeper_joint
+        expected = """joining
+  b : 5
+  d : 6
+  x : 1
+  y : 2
+  z : 3
+foo: x 1, y 2, z 3
+bar: a 6, b 5, c 4
+baz: d 6
+"""
+        output = deeper_joint(x=1, y=2, z=3, b=5, d=6)
+        self.failUnless(output == expected,
+                        "GOT\n%s\nEXPECTED\n%s" % (output, expected))
+    def testDoubleLevelOverrideRequired(self):
+        deeper_joint = SplittableKwargsTestCase.deeper_joint
+        self.assertRaises(UnknownKwarg, deeper_joint,
+                          x=8, y=2, z=3, b=5, a=1)
+    def testRecursiveSplitargsReference(self):
+        # Access the ._splitargs method using the defined function name
+        expected = ({'y':3, 'z':4}, {'a':1, 'b':2})
+        output = _fun(a=1,b=2,y=3,z=4)
+        self.failUnless(output == expected,
+                        "GOT\n%s\nEXPECTED\n%s" % (str(output), str(expected)))
+
+if __name__ == "__main__":
+    import sys
+    
+    unitsuite = unittest.TestLoader().loadTestsFromTestCase( \
+                                                      SplittableKwargsTestCase)
+    result = unittest.TextTestRunner(verbosity=2).run(unitsuite)
+    numErrors = len(result.errors)
+    numFailures = len(result.failures)
+    numBad = numErrors + numFailures
+    if numBad > 126:
+        numBad = 1
+    sys.exit(numBad)