'''Module that contains the base class ArgSchemaParser which should be
subclassed when using this library
'''
import json
import logging
import copy
from . import schemas
from . import utils
import marshmallow as mm
from .sources.json_source import JsonSource, JsonSink
from .sources.yaml_source import YamlSource, YamlSink
from .sources.source import NotConfiguredSourceError, MultipleConfiguredSourceError, get_input_from_config
def contains_non_default_schemas(schema, schema_list=None):
    """returns True if this schema contains a schema which was not an instance of DefaultSchema

    Parameters
    ----------
    schema : marshmallow.Schema
        schema to check
    schema_list : list or None
        list of nested schema types already visited; it is extended in place
        while recursing.  Callers normally leave it out so that every
        top-level call starts from an empty list (Default value = None)

    Returns
    -------
    bool
        True if `schema` itself, or a nested schema reached during the
        traversal, is not an instance of schemas.DefaultSchema; False otherwise

    """
    # Create the bookkeeping list per call.  A mutable default ([]) is created
    # once at definition time and would carry schema types seen in earlier,
    # unrelated calls into this one.
    if schema_list is None:
        schema_list = []

    if not isinstance(schema, schemas.DefaultSchema):
        return True

    for field in schema.declared_fields.values():
        if not isinstance(field, mm.fields.Nested):
            continue
        nested_type = type(field.schema)
        if nested_type in schema_list:
            # NOTE(review): an already-visited nested type ends the scan of
            # this schema immediately, so any fields after it are not
            # inspected.  Kept as-is to preserve existing results.
            return False
        schema_list.append(nested_type)
        if contains_non_default_schemas(field.schema, schema_list):
            return True
    return False
def is_recursive_schema(schema, schema_list=None):
    """returns true if this schema contains recursive elements

    A nested schema type that is encountered a second time anywhere during
    the traversal is treated as recursive.

    Parameters
    ----------
    schema : marshmallow.Schema
        schema to check
    schema_list : list or None
        list of nested schema types already visited; it is extended in place
        while recursing.  Callers normally leave it out so that every
        top-level call starts from an empty list (Default value = None)

    Returns
    -------
    bool
        does this schema contain any recursively defined schemas

    """
    # Create the bookkeeping list per call.  A mutable default ([]) is created
    # once at definition time, so a second call on the very same schema would
    # find its nested types already recorded and wrongly report recursion.
    if schema_list is None:
        schema_list = []

    for field in schema.declared_fields.values():
        if not isinstance(field, mm.fields.Nested):
            continue
        nested_type = type(field.schema)
        if nested_type in schema_list:
            return True
        schema_list.append(nested_type)
        if is_recursive_schema(field.schema, schema_list):
            return True
    return False
def fill_defaults(schema, args):
    """DEPRECATED, function to fill in default values from schema into args

    bug: goes into an infinite loop when there is a recursively defined schema

    Parameters
    ----------
    schema : marshmallow.Schema
        schema to get defaults from
    args : dict
        dictionary of arguments; it is deep-copied, never modified in place

    Returns
    -------
    dict
        copy of `args` with missing default values filled in

    """
    # Pass 1: walk the schema tree with an explicit stack and remember the
    # key path and value of every non-nested field that declares a default.
    pending = [(schema, [])]
    found = []
    while pending:
        current, prefix = pending.pop()
        for name, field in current.declared_fields.items():
            if isinstance(field, mm.fields.Nested):
                pending.append((field.schema, prefix + [name]))
            elif field.default != mm.missing:
                found.append((prefix + [name], field.default))

    # Pass 2: write each default into a copy of the arguments, creating the
    # intermediate dictionaries on the way and never overwriting a key that
    # is already present.
    filled = copy.deepcopy(args)
    for key_path, value in found:
        leaf = key_path[-1]
        node = filled
        for parent in key_path[:-1]:
            node = node.setdefault(parent, {})
        if leaf not in node:
            node[leaf] = value
    return filled
class ArgSchemaParser(object):
    """The main class you should sub-class to write your own argschema module.

    Takes input_data, a reference to an input_json and the command line inputs,
    parses out the parameters and validates them against the schema_type specified.

    To subclass this and make a new schema the default, simply override the
    default_schema and default_output_schema attributes of this class.

    Parameters
    ----------
    input_data : dict or None
        dictionary parameters to fall back on if no source is given or configured via command line
    schema_type : schemas.ArgSchema
        the schema to use to validate the parameters
    output_schema_type : marshmallow.Schema
        the schema to use to validate the output, used by self.output
    input_source : argschema.sources.source.Source
        a generic source of a dictionary
    output_sink : argschema.sources.source.Source
        a generic sink to write output dictionary to
    args : list or None
        command line arguments passed to the module, if None use argparse to parse the command line,
        set to [] if you want to bypass command line parsing
    logger_name : str
        name of logger from the logging module you want to instantiate,
        defaults to the name of this module

    Raises
    -------
    marshmallow.ValidationError
        If the combination of input_json, input_data and command line arguments do not pass
        the validation of the schema

    """
    default_schema = schemas.ArgSchema
    default_output_schema = None
    # source / sink classes whose ConfigSchema is exposed on the command line
    default_configurable_sources = [JsonSource]
    default_configurable_sinks = [JsonSink]

    def __init__(self,
                 input_data=None,  # dictionary input as option instead of --input_json
                 schema_type=None,  # schema for parsing arguments
                 output_schema_type=None,  # schema for parsing output_json
                 args=None,
                 input_source=None,
                 output_sink=None,
                 logger_name=__name__):

        if schema_type is None:
            schema_type = self.default_schema
        if output_schema_type is None:
            output_schema_type = self.default_output_schema

        self.schema = schema_type()
        # provisional logger; it is re-initialized with the validated
        # 'log_level' argument at the end of construction
        self.logger = self.initialize_logger(logger_name, 'WARNING')
        self.logger.debug('input_data is {}'.format(input_data))

        # consolidate a list of the input and output source
        # command line configuration schemas (sources first, then sinks)
        io_schemas = [in_cfg.ConfigSchema()
                      for in_cfg in self.default_configurable_sources]
        io_schemas += [out_cfg.ConfigSchema()
                       for out_cfg in self.default_configurable_sinks]

        # build a command line parser from the input schemas and configurations
        p = utils.schema_argparser(self.schema, io_schemas)
        argsobj = p.parse_args(args)
        argsdict = utils.args_to_dict(argsobj, [self.schema] + io_schemas)
        self.logger.debug('argsdict is {}'.format(argsdict))

        # if you received an input_source, get the dictionary from there
        if input_source is not None:
            input_data = input_source.get_dict()
        else:
            # see if the input_data itself contains an InputSource
            # configuration and, if so, use that
            config_data = self.__get_input_data_from_config(input_data)
            input_data = config_data if config_data is not None else input_data

        # check whether the command line arguments contain an input
        # configuration and use that
        config_data = self.__get_input_data_from_config(
            utils.smart_merge({}, argsdict))
        input_data = config_data if config_data is not None else input_data

        # merge the command line dictionary into the input json
        args = utils.smart_merge(input_data, argsdict)
        self.logger.debug('args after merge {}'.format(args))

        # if the output sink was not passed in, see if there is a
        # configuration in the combined args
        if output_sink is None:
            output_sink = self.__get_output_sink_from_config(args)
        # save the output sink for later
        self.output_sink = output_sink

        # validate with load!
        result = self.load_schema_with_defaults(self.schema, args)
        if result.errors:
            raise mm.ValidationError(json.dumps(result.errors, indent=2))

        self.schema_args = result
        self.args = result.data
        self.output_schema_type = output_schema_type
        self.logger = self.initialize_logger(
            logger_name, self.args.get('log_level'))

    def __get_output_sink_from_config(self, d):
        """private function to check for ArgSink configuration in a dictionary and return a configured ArgSink

        Parameters
        ----------
        d : dict
            dictionary to look for ArgSink Configuration parameters in

        Returns
        -------
        ArgSink or None
            A configured argsink, or None when no sink in
            default_configurable_sinks is configured by `d`

        Raises
        ------
        MultipleConfiguredSourceError
            If more than one Sink is configured

        """
        output_set = False
        output_sink = None
        for OutputSink in self.default_configurable_sinks:
            try:
                output_config_d = OutputSink.get_config(
                    OutputSink.ConfigSchema, d)
                if output_set:
                    raise MultipleConfiguredSourceError(
                        "more then one OutputSink configuration present in {}".format(d))
                output_sink = OutputSink(**output_config_d)
                output_set = True
            except NotConfiguredSourceError:
                # this sink type is simply not configured in d, try the next
                pass
        return output_sink

    def __get_input_data_from_config(self, d):
        """private function to check for ArgSource configurations in a dictionary
        and return the data if it exists

        Parameters
        ----------
        d : dict
            dictionary to look for InputSource configuration parameters in

        Returns
        -------
        dict or None
            dictionary of InputData if it found a valid configuration, None otherwise

        Raises
        ------
        MultipleConfiguredSourceError
            if more than one InputSource is configured

        """
        input_set = False
        input_data = None
        for InputSource in self.default_configurable_sources:
            try:
                input_data = get_input_from_config(InputSource, d)
                if input_set:
                    raise MultipleConfiguredSourceError(
                        "more then one InputSource configuration present in {}".format(d))
                input_set = True
            except NotConfiguredSourceError:
                # this source type is simply not configured in d, try the next
                pass
        return input_data

    def get_output_json(self, d):
        """method for getting the output_json pushed through validation
        if validation exists

        Parameters
        ----------
        d : dict
            output dictionary to output

        Returns
        -------
        dict
            validated and serialized version of the dictionary, or `d` itself
            when no output_schema_type is defined

        Raises
        ------
        marshmallow.ValidationError
            If any of the output dictionary doesn't meet the output schema

        """
        if self.output_schema_type is not None:
            schema = self.output_schema_type()
            (output_json, errors) = schema.dump(d)
            if errors:
                raise mm.ValidationError(json.dumps(errors))
        else:
            # one clean message: a backslash continuation inside the literal
            # would embed the source indentation in the logged text
            self.logger.warning(
                "output_schema_type is not defined, "
                "the output won't be validated")
            output_json = d
        return output_json

    def output(self, d, output_path=None, sink=None, **sink_options):
        """method for outputing dictionary to the configured sink after
        validating it through the output_schema_type

        Parameters
        ----------
        d : dict
            output dictionary to output
        output_path : str
            DEPRECATED, path to save the output to as json; when given, a
            JsonSink for this path is used in place of `sink`
        sink : argschema.sources.source.ArgSink
            output sink to output to (optional, defaults to self.output_sink)
        **sink_options :
            will be passed through to sink.put_dict

        Raises
        ------
        marshmallow.ValidationError
            If any of the output dictionary doesn't meet the output schema

        """
        output_d = self.get_output_json(d)
        if output_path is not None:
            self.logger.warning('DEPRECATED, pass output_sink instead')
            sink = JsonSink(output_json=output_path)
        # NOTE(review): when neither `sink` nor self.output_sink is set, the
        # call below fails on None -- confirm callers always configure a sink.
        target = sink if sink is not None else self.output_sink
        target.put_dict(output_d, **sink_options)

    def load_schema_with_defaults(self, schema, args):
        """method for deserializing the arguments dictionary (args)
        given the schema (schema) making sure that the default values have
        been filled in.

        Parameters
        ----------
        args : dict
            a dictionary of input arguments
        schema : marshmallow.Schema
            schema used to deserialize `args`

        Returns
        -------
        object
            the value returned by schema.load(args); __init__ reads its
            `errors` and `data` attributes

        Raises
        ------
        marshmallow.ValidationError
            If this schema contains nested schemas that don't subclass argschema.DefaultSchema
            because these won't work with loading defaults.

        """
        # Pass a fresh bookkeeping list to each helper so the outcome never
        # depends on schemas inspected by earlier calls in this process.
        is_recursive = is_recursive_schema(schema, [])
        is_non_default = contains_non_default_schemas(schema, [])

        if is_non_default:
            if is_recursive:
                raise mm.ValidationError(
                    'Recursive schemas need to subclass argschema.DefaultSchema else defaults will not work')
            # throw a warning
            self.logger.warning(
                "DEPRECATED:You are using a Schema which contains\n"
                "a Schema which is not subclassed from argschema.DefaultSchema,\n"
                "default values will not work correctly in this case,\n"
                "this use is deprecated, and future versions will not fill in default\n"
                "values when you use non-DefaultSchema subclasses")
            args = fill_defaults(schema, args)

        # load the dictionary via the schema
        result = schema.load(args)
        return result

    @staticmethod
    def initialize_logger(name, log_level):
        """initializes the logger to a level with a name
        logger = initialize_logger(name, log_level)

        Parameters
        ----------
        name : str
            name of the logger
        log_level : str
            name of the logging level (e.g. 'WARNING'); it is translated
            with logging.getLevelName before being applied

        Returns
        -------
        logging.Logger
            a logger set with the name and level specified

        """
        level = logging.getLevelName(log_level)
        # basicConfig is a no-op when the root logger already has handlers
        logging.basicConfig()
        logger = logging.getLogger(name)
        logger.setLevel(level=level)
        return logger
class ArgSchemaYamlParser(ArgSchemaParser):
    """ArgSchemaParser whose configurable source and sink are YamlSource and YamlSink.

    Only the two class-level lists are overridden; all parsing and output
    logic is inherited from ArgSchemaParser.
    """
    # source / sink classes whose ConfigSchema is exposed on the command line
    default_configurable_sources = [YamlSource]
    default_configurable_sinks = [YamlSink]