'''module that contains argschema functions for converting
marshmallow schemas to argparse and merging dictionaries from both systems
'''
import logging
import argparse
from operator import add
import inspect
import marshmallow as mm
import collections
FIELD_TYPE_MAP = {v: k for k, v in mm.Schema.TYPE_MAPPING.items()}
[docs]def args_to_dict(argsobj):
"""function to convert namespace returned by argsparse into a nested dictionary
Parameters
----------
argsobj : argparse.Namespace
Namespace object returned by standard argparse.parse function
Returns
-------
dict
dictionary of namespace values where nesting elements uses '.' to denote nesting of keys
"""
d = {}
argsdict = vars(argsobj)
for field in argsdict.keys():
parts = field.split('.')
root = d
for i in range(len(parts)):
if i == (len(parts) - 1):
root[parts[i]] = argsdict.get(field)
else:
if parts[i] not in root.keys():
root[parts[i]] = {}
root = root[parts[i]]
return d
[docs]def merge_value(a, b, key, func=add):
"""attempt to merge these dictionaries using function defined by
func (default to add) raise an exception if this fails
Parameters
----------
a : dict
one dictionary
b : dict
second dictionary
key : key
key to merge dictionary values on
func : function(x
function that merges two values of this key
Returns (Default value = add)
func : a[key]
merged version of values (Default value = add)
Returns
-------
"""
try:
return func(a[key], b[key])
except:
raise Exception("Cannot merge this key {},\
for values {} and {} of types {} and {}".format
(key, a[key], b[key], type(a[key]), type(b[key])))
[docs]def smart_merge(a, b, path=None, merge_keys=None, overwrite_with_none=False):
"""updates dictionary a with values in dictionary b
being careful not to write things with None, and performing a merge on merge_keys
Parameters
----------
a : dict
dictionary to perform update on
b : dict
dictionary to perform update with
path : list
list of nested keys traversed so far (used for recursion) (Default value = None)
merge_keys : list
list of keys to do merging on (default None)
overwrite_with_none :
(Default value = False)
Returns
-------
dict
a dictionary that is a updated with b's values
"""
a = {} if a is None else a
b = {} if b is None else b
path = [] if path is None else path
#simplifies code to have empty list rather than None
#might allow some crazy dynamic merging in future
if merge_keys is None:
merge_keys = []
for key in b:
if key in a:
if isinstance(a[key], dict) and isinstance(b[key], dict):
# recursively merge these leafs
smart_merge(a[key], b[key], path + [str(key)], merge_keys)
elif a[key] == b[key]:
pass # same leaf value, so don't bother
elif b[key] is None:
if overwrite_with_none:
a[key] = b[key]
else:
# in this case we are potentially overwriting a's value with b's
# determine if we should try to merge
if key in merge_keys:
# attempt to merge leafs
a[key] = merge_value(a, b, key)
else: # otherwise replace leafs
a[key] = b[key]
else: # there is no corresponding leaf in a
if b[key] is None:
if overwrite_with_none:
a[key] = b[key]
else:
# otherwise replace entire leaf with b
a[key] = b[key]
return a
[docs]def get_description_from_field(field):
"""get the description for this marshmallow field
Parameters
----------
field : marshmallow.fields.field
field to get description
Returns
-------
str
description string (or None)
"""
#look for description
if 'description' in field.metadata:
desc = field.metadata.get('description')
#also look to see if description was added in metadata
else:
md = field.metadata.get('metadata', {})
if 'description' in md:
desc = md['description']
else:
desc = None
return desc
[docs]def build_schema_arguments(schema, arguments=None, path=None, description =None):
"""given a jsonschema, create a dictionary of argparse arguments,
by navigating down the Nested schema tree. (recursive function)
Parameters
----------
schema : marshmallow.Schema
schema with field.description filled in with help values
arguments : list or None
list of argument group dictionaries to add to (see Returns) (Default value = None)
path : list or None
list of strings denoted where you are in the tree (Default value = None)
description: str or None
description for the argument group at this level of the tree
Returns
-------
list
List of argument group dictionaries, with keys ['title','description','args']
which contain the arguments for argparse. 'args' is an OrderedDict of dictionaries
with keys of the argument names with kwargs to build an argparse argument
"""
path = [] if path is None else path
arguments = [] if arguments is None else arguments
arggroup = {}
#name this argument group by the path, or the schema class name if it's the root
if len(path)==0:
arggroup['title']=schema.__class__.__name__
else:
arggroup['title']='.'.join(path)
arggroup['args']=collections.OrderedDict()
#assume the description has been handed down
arggroup['description']=description
#sort the fields first by required, then by default values present or not
for field_name, field in sorted(schema.declared_fields.items(),
key= lambda x: 2*x[1].required+1*(x[1].default==mm.missing),
reverse=True):
#get this field's description
desc = get_description_from_field(field)
#if its nested, then we want to recusively follow this link
if isinstance(field, mm.fields.Nested):
if field.many:
logging.warning("many=True not supported from argparse")
else:
build_schema_arguments(field.schema,
arguments,
path + [field_name],
description = desc)
else:
# it's not nested then let's build the argument
arg = {}
arg_name = '--' + '.'.join(path + [field_name])
if desc is not None:
arg['help']=desc
else:
arg['help']=''
#programatically add helpful notes to help string
if field.default is not mm.missing:
arg['help']+= " (default={})".format(field.default)
if field.required:
arg['help']+= " (REQUIRED)"
for validator in field.validators:
if isinstance(validator,mm.validate.ContainsOnly):
arg['help']+= " (constrained list)"
if isinstance(validator,mm.validate.OneOf):
arg['help']+= " (valid options are {})".format(validator.choices)
#catch lists to figure out desired type
field_type = type(field)
if isinstance(field, mm.fields.List):
arg['nargs'] = '*'
container_type = type(field.container)
parent_classes = inspect.getmro(container_type)[1:]
# recurse to up the class tree to find out if this is a supported type
while (container_type not in FIELD_TYPE_MAP and
len(parent_classes)):
container_type = parent_classes[0]
parent_classes = parent_classes[1:]
if container_type in FIELD_TYPE_MAP:
arg['type'] = FIELD_TYPE_MAP[container_type]
else:
logging.warning("List contains unsupported type: %s" % str(
type(field.container)))
#otherwise look up the desired type in FIELD_TYPE_MAP
elif type(field) in FIELD_TYPE_MAP:
# it's a simple type, apply the mapping
arg['type'] = FIELD_TYPE_MAP[field_type]
#DON'T WANT TO USE DEFAULT VALUES AS ARGPARSE OVERRULES JSON
# if field.default != mm.missing:
# arg['default'] = field.default
#add this argument to the arggroup
arggroup['args'][arg_name] = arg
#tack on this arggroup to the list and return
arguments.append(arggroup)
return arguments
[docs]def schema_argparser(schema):
"""given a jsonschema, build an argparse.ArgumentParser
Parameters
----------
schema : argschema.schemas.ArgSchema
schema to build an argparser from
Returns
-------
argparse.ArgumentParser
the represents the schema
"""
#build up a list of argument groups using recursive function
#to traverse the tree, root node gets the description given by doc string
#of the schema
arguments = build_schema_arguments(schema,description=schema.__doc__)
#make the root schema appeear first rather than last
arguments = [arguments[-1]]+arguments[0:-1]
parser = argparse.ArgumentParser()
for arg_group in arguments:
group=parser.add_argument_group(arg_group['title'],arg_group['description'])
for arg_name,arg in arg_group['args'].items():
group.add_argument(arg_name, **arg)
return parser