Source code for argschema.utils

'''module that contains argschema functions for converting
marshmallow schemas to argparse and merging dictionaries from both systems
'''
import logging
import warnings
import ast
import argparse
from operator import add
import inspect
import json
import marshmallow as mm
from argschema import fields
import collections

# Explicit casting functions for field types whose command-line string value
# has to be parsed into a Python object. Any field type that is not listed
# here falls back to ``str`` (see ``get_type_from_field``).
FIELD_TYPE_MAP = {
    fields.Boolean: ast.literal_eval,
    fields.List: ast.literal_eval,
    fields.NumpyArray: ast.literal_eval,
}

def prune_dict_with_none(d):
    """Remove nested dictionaries whose values are all None.

    Parameters
    ----------
    d : dict
        dictionary to prune; nested dictionaries are pruned in place

    Returns
    -------
    dict
        an empty dictionary when every value of ``d`` is None (or ``d`` is
        empty); otherwise ``d`` itself, with every nested dictionary that
        pruned down to ``{}`` removed
    """
    # Identity check on purpose: ``== None`` can be answered arbitrarily by a
    # value's own ``__eq__`` (and is not even a plain bool for array-likes).
    if all(value is None for value in d.values()):
        return {}

    # Snapshot the keys first because entries are popped while looping.
    nested_keys = [key for key, value in d.items() if isinstance(value, dict)]
    for key in nested_keys:
        if prune_dict_with_none(d[key]) == {}:
            d.pop(key)
    return d
def get_type_from_field(field):
    """Pick the callable used to cast a command-line value for a field.

    Parameters
    ----------
    field : marshmallow.Field
        Field instance taken from the input schema

    Returns
    -------
    callable
        ``list`` for a List field that is not flagged with
        ``cli_as_single_argument`` in its metadata; otherwise the entry of
        ``FIELD_TYPE_MAP`` for the field's exact type, or ``str`` when the
        type has no entry
    """
    multi_argument_list = (
        isinstance(field, fields.List)
        and not field.metadata.get("cli_as_single_argument", False)
    )
    if multi_argument_list:
        return list
    return FIELD_TYPE_MAP.get(type(field), str)
def cli_error_dict(arg_path, field_type, index=0):
    """Construct a nested dictionary containing a casting error message.

    The result nests one dictionary per key of ``arg_path`` (starting at
    ``index``) and stores a one-element list with the message at the
    innermost key.

    Parameters
    ----------
    arg_path : list[str]
        List of nested keys
    field_type : str
        Name of the marshmallow.Field type
    index : int
        Index into arg_path at which to start (used for recursion)

    Returns
    -------
    dict
        Dictionary representing argument path, containing error.
    """
    at_leaf = index == len(arg_path) - 1
    key = arg_path[index]
    if not at_leaf:
        # not at the end of the path yet: wrap the rest one level deeper
        return {key: cli_error_dict(arg_path, field_type, index + 1)}
    message = "Command-line argument can't cast to {}".format(field_type)
    return {key: [message]}
def get_field_def_from_schema(parts, schema):
    """Look up the field definition addressed by a list of nested keys.

    Parameters
    ----------
    parts : list[str]
        the list of keys leading to the field; each key but the last is
        expected to name a Nested field
    schema : marshmallow.Schema
        the marshmallow schema to start the lookup from

    Returns
    -------
    marshmallow.Field or None
        the field in the schema if it exists, otherwise None. None is also
        returned when ``parts`` is empty or when a key is excluded by the
        ``only`` option of the schema it is looked up in.
    """
    # Start from None so that an empty ``parts`` list yields None instead of
    # referencing an unassigned local.
    field_def = None
    current_schema = schema
    for part in parts:
        if part not in current_schema.fields:
            return None
        if current_schema.only and part not in current_schema.only:
            # The key is excluded from this schema; stop here rather than
            # resolving the remaining keys against the wrong schema.
            return None
        field_def = current_schema.fields[part]
        if isinstance(field_def, fields.Nested):
            # descend into the nested schema for the next key
            current_schema = field_def.schema
    return field_def
def args_to_dict(argsobj, schemas=None):
    """Convert the namespace returned by argparse into a nested dictionary.

    Parameters
    ----------
    argsobj : argparse.Namespace
        Namespace object returned by standard argparse.parse function
    schemas : list[marshmallow.Schema]
        Optional list of schemas which will be used to cast fields via
        `FIELD_TYPE_MAP`. When omitted, no field definition is found and
        values go through the default cast of ``get_type_from_field``.

    Returns
    -------
    dict
        dictionary of namespace values, where a '.' in an argument name
        denotes nesting of keys; nested dictionaries holding only None
        values are pruned

    Raises
    ------
    marshmallow.ValidationError
        if one or more values cannot be cast; the message is a JSON dump of
        the nested error dictionary
    """
    # ``schemas`` defaults to None, which cannot be iterated over below.
    schemas = [] if schemas is None else schemas

    d = {}
    errors = {}
    for field, value in vars(argsobj).items():
        parts = field.split('.')

        # walk (and create as needed) the dictionaries above the leaf key
        root = d
        for part in parts[:-1]:
            root = root.setdefault(part, {})

        # the first schema that knows this argument decides how to cast it
        field_def = None
        for schema in schemas:
            field_def = get_field_def_from_schema(parts, schema)
            if field_def is not None:
                break

        if value is not None:
            try:
                value = get_type_from_field(field_def)(value)
            except (ValueError, SyntaxError):
                # ast.literal_eval signals malformed input with either
                # exception type; both mean the value cannot be cast
                typename = field_def.__class__.__name__
                # nested merge so errors sharing a parent key all survive
                smart_merge(errors, cli_error_dict(parts, typename))
        root[parts[-1]] = value

    if errors:
        raise mm.ValidationError(json.dumps(errors, indent=2))
    return prune_dict_with_none(d)
def merge_value(a, b, key, func=add):
    """Merge the values two dictionaries hold under the same key.

    Parameters
    ----------
    a : dict
        one dictionary
    b : dict
        second dictionary
    key : key
        key to merge dictionary values on
    func : callable
        function of two arguments that merges ``a[key]`` and ``b[key]``
        (Default value = add)

    Returns
    -------
    object
        ``func(a[key], b[key])``

    Raises
    ------
    KeyError
        if ``key`` is missing from ``a`` or ``b``
    Exception
        if ``func`` fails; the original error is attached as the cause
    """
    left = a[key]
    right = b[key]
    try:
        return func(left, right)
    except Exception as err:
        # ``except Exception`` rather than a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not converted into merge errors
        raise Exception(
            "Cannot merge this key {}, for values {} and {} "
            "of types {} and {}".format(
                key, left, right, type(left), type(right))) from err
def smart_merge(a, b, path=None, merge_keys=None, overwrite_with_none=False):
    """Update dictionary a with the values of dictionary b, in place.

    None values in b are skipped unless ``overwrite_with_none`` is set,
    nested dictionaries are merged recursively, and keys listed in
    ``merge_keys`` are combined with ``merge_value`` instead of replaced.

    Parameters
    ----------
    a : dict
        dictionary to perform update on (None is treated as ``{}``)
    b : dict
        dictionary to perform update with (None is treated as ``{}``)
    path : list
        list of nested keys traversed so far (used for recursion)
        (Default value = None)
    merge_keys : list
        list of keys to do merging on (default None)
    overwrite_with_none : bool
        whether a None in b replaces (or creates) the value in a, at every
        nesting level (Default value = False)

    Returns
    -------
    dict
        a dictionary that is a updated with b's values
    """
    a = {} if a is None else a
    b = {} if b is None else b
    # empty lists rather than None keep the recursion below simple
    path = [] if path is None else path
    merge_keys = [] if merge_keys is None else merge_keys

    for key, b_value in b.items():
        if key not in a:
            # there is no corresponding leaf in a
            if b_value is None:
                if overwrite_with_none:
                    a[key] = None
            elif isinstance(b_value, dict):
                # build the sub-dictionary through the same rules so that
                # None handling applies to the new branch as well
                a[key] = {}
                smart_merge(a[key], b_value, path + [str(key)],
                            merge_keys, overwrite_with_none)
            else:
                a[key] = b_value
            continue

        a_value = a[key]
        if isinstance(a_value, dict) and isinstance(b_value, dict):
            # recursively merge; the None policy must follow the recursion
            smart_merge(a_value, b_value, path + [str(key)],
                        merge_keys, overwrite_with_none)
        elif a_value == b_value:
            pass  # same leaf value, nothing to do
        elif b_value is None:
            if overwrite_with_none:
                a[key] = None
        elif key in merge_keys:
            # attempt to merge leafs
            a[key] = merge_value(a, b, key)
        else:
            # otherwise replace the leaf with b's value
            a[key] = b_value
    return a
def get_description_from_field(field):
    """Return the description attached to a marshmallow field.

    The field's ``metadata`` mapping is checked for a ``'description'``
    entry first, then for a ``'description'`` entry inside a nested
    ``'metadata'`` entry.

    Parameters
    ----------
    field : marshmallow.fields.field
        field to get description

    Returns
    -------
    str
        description string (or None)
    """
    top_level = field.metadata
    if 'description' in top_level:
        return top_level.get('description')

    # the description may also have been passed inside a metadata entry
    inner = top_level.get('metadata', {})
    if 'description' in inner:
        return inner['description']
    return None
def build_schema_arguments(schema, arguments=None, path=None, description=None):
    """Build argparse argument groups for a schema (recursive).

    One argument group is created for ``schema`` and, through recursion,
    one for every Nested field with ``many`` unset. Nested fields with
    ``many`` set and Dict fields are skipped with a logged warning.

    Parameters
    ----------
    schema : marshmallow.Schema
        schema with field.description filled in with help values
    arguments : list or None
        list of argument group dictionaries to add to (see Returns)
        (Default value = None)
    path : list or None
        list of strings denoting where you are in the tree
        (Default value = None)
    description : str or None
        description for the argument group at this level of the tree

    Returns
    -------
    list
        List of argument group dictionaries, with keys
        ['title','description','args'] which contain the arguments for
        argparse. 'args' is an OrderedDict of dictionaries with keys of
        the argument names with kwargs to build an argparse argument.
        The group for ``schema`` itself is appended after the groups of
        its nested schemas.
    """
    if path is None:
        path = []
    if arguments is None:
        arguments = []

    def _rank(item):
        # required fields weigh 2, fields without a default weigh 1
        declared = item[1]
        return 2*declared.required + 1*(declared.default == mm.missing)

    # the group is named by the path, or the schema class name at the root
    title = '.'.join(path) if len(path) != 0 else schema.__class__.__name__
    group_args = collections.OrderedDict()
    # the description is assumed to have been handed down by the caller
    arggroup = {'title': title, 'args': group_args, 'description': description}

    # highest rank first: required, then missing-default fields
    ordered_fields = sorted(schema.declared_fields.items(),
                            key=_rank, reverse=True)
    for field_name, field in ordered_fields:
        desc = get_description_from_field(field)

        if isinstance(field, mm.fields.Nested):
            # nested schemas become their own argument group via recursion
            if field.many:
                logging.warning("many=True not supported from argparse")
            else:
                build_schema_arguments(field.schema,
                                       arguments,
                                       path + [field_name],
                                       description=desc)
            continue
        if isinstance(field, fields.Dict):
            logging.warning("setting Dict fields not supported from argparse")
            continue

        # a plain field: assemble the keyword arguments for add_argument
        arg_name = '--' + '.'.join(path + [field_name])
        help_text = '' if desc is None else desc

        # programmatically add helpful notes to the help string
        if field.default is not mm.missing:
            help_text += " (default={})".format(field.default)
        if field.required:
            help_text += " (REQUIRED)"
        for validator in field.validators:
            if isinstance(validator, mm.validate.ContainsOnly):
                help_text += " (constrained list)"
            if isinstance(validator, mm.validate.OneOf):
                help_text += " (valid options are {})".format(validator.choices)

        arg = {'help': help_text}
        uses_old_list_syntax = (
            isinstance(field, mm.fields.List)
            and not field.metadata.get("cli_as_single_argument", False)
        )
        if uses_old_list_syntax:
            # NOTE(review): the text between "See" and "master/user/..."
            # looks truncated (a URL prefix seems to be missing) - confirm
            # against the upstream source before relying on this message.
            warn_msg = ("'{}' is using old-style command-line syntax with "
                        "each element as a separate argument. This will "
                        "not be supported in argschema after "
                        "2.0. See"
                        "master/user/intro.html#command-line-specification"
                        " for details.").format(arg_name)
            warnings.warn(warn_msg, FutureWarning)
            arg['nargs'] = '*'

        # type mapping happens after parsing so validation errors can be
        # raised; defaults are deliberately not passed to argparse, since
        # an argparse default would overrule the value coming from JSON
        arg['type'] = str
        group_args[arg_name] = arg

    # tack this group onto the shared list and hand it back
    arguments.append(arggroup)
    return arguments
def schema_argparser(schema, additional_schemas=None):
    """Build an argparse.ArgumentParser from one or more schemas.

    Parameters
    ----------
    schema : argschema.schemas.ArgSchema
        schema to build an argparser from
    additional_schemas : list[marshmallow.schema]
        list of additional schemas to add to the command line arguments
        (any iterable of schemas is accepted)

    Returns
    -------
    argparse.ArgumentParser
        parser with one argument group per schema level; each schema's root
        group is described by that schema's own doc string
    """
    schema_list = [schema]
    if additional_schemas is not None:
        schema_list.extend(additional_schemas)

    parser = argparse.ArgumentParser()
    for s in schema_list:
        # Build the argument groups by walking the schema tree; the root
        # group takes its description from the doc string of the schema
        # being processed (not from the primary schema).
        arguments = build_schema_arguments(s, description=s.__doc__)

        # the recursion appends the root group last; show it first
        arguments = arguments[-1:] + arguments[:-1]
        for arg_group in arguments:
            group = parser.add_argument_group(arg_group['title'],
                                              arg_group['description'])
            for arg_name, arg in arg_group['args'].items():
                group.add_argument(arg_name, **arg)
    return parser