Source code for test_first_test

import pytest
import json
import logging
import marshmallow as mm
from argschema import ArgSchemaParser, ArgSchema
import argschema
from pathlib import Path


def test_bad_path():
    """A nonexistent ``input_json`` path must be rejected with a ValidationError."""
    bad_input = dict(
        input_json="a bad path",
        output_json="another example",
        log_level="DEBUG",
    )
    with pytest.raises(mm.ValidationError):
        ArgSchemaParser(input_data=bad_input, args=[])
def test_simple_example(tmpdir):
    """Parse existing input/output paths and check the requested log level is kept."""
    input_path = tmpdir.join('test_input_json.json')
    # Content is arbitrary; presumably only the path's existence matters
    # when input_json is passed via input_data (TODO confirm).
    input_path.write('nonesense')
    output_path = tmpdir.join('test_output.json')

    parser = ArgSchemaParser(
        input_data={
            "input_json": str(input_path),
            "output_json": str(output_path),
            "log_level": "CRITICAL",
        },
        args=[],
    )
    assert parser.args['log_level'] == 'CRITICAL'
def test_log_catch():
    """An invalid ``log_level`` choice must be rejected with a ValidationError."""
    example = {"log_level": "NOTACHOICE"}
    # Only the constructor call belongs inside the raises-context; anything
    # after it would be unreachable once the expected exception fires.
    with pytest.raises(mm.ValidationError):
        ArgSchemaParser(input_data=example, args=[])
class MyExtension(argschema.schemas.DefaultSchema):
    # Nested schema used by SimpleExtension.
    # 'a' is mandatory; 'c' declares a default of 10; 'b' and 'd' are optional.
    a = mm.fields.Str(
        description='a string',
        required=True,
    )
    b = mm.fields.Int(
        description='an integer',
    )
    c = mm.fields.Int(
        description='an integer',
        default=10,
    )
    d = mm.fields.List(
        mm.fields.Int,
        description='a list of integers',
    )
class SimpleExtension(ArgSchema):
    # ArgSchema whose extra field 'test' is a required MyExtension sub-object.
    test = mm.fields.Nested(
        MyExtension,
        required=True,
    )
def test_simple_extension_required():
    """Omitting the required nested ``test`` field must fail validation."""
    with pytest.raises(mm.ValidationError):
        ArgSchemaParser(
            input_data={},
            schema_type=SimpleExtension,
            args=[],
        )
# Does not match MyExtension's field types: 'a' is not a string and 'd'
# contains a non-integer entry.
SimpleExtension_example_invalid = {
    'test': {
        'a': 5,
        'b': 1,
        'd': ['a', 2, 3],
    },
}

# Matches MyExtension; 'c' is omitted so tests can check its default.
SimpleExtension_example_valid = {
    'test': {
        'a': "hello",
        'b': 1,
        'd': [1, 5, 4],
    },
}
@pytest.fixture(scope='module')
def simple_extension_file(tmpdir_factory):
    """Write ``SimpleExtension_example_valid`` to a JSON file shared across this module."""
    payload = json.dumps(SimpleExtension_example_valid)
    json_file = tmpdir_factory.mktemp('test').join('testinput.json')
    json_file.write(payload)
    return json_file
def test_simple_extension_fail():
    """Type-mismatched nested values must raise a ValidationError."""
    parser_kwargs = dict(
        input_data=SimpleExtension_example_invalid,
        schema_type=SimpleExtension,
        args=[],
    )
    with pytest.raises(mm.ValidationError):
        ArgSchemaParser(**parser_kwargs)
def test_simple_extension_pass():
    """Valid nested input loads, and the omitted ``c`` picks up its default."""
    mod = ArgSchemaParser(
        input_data=SimpleExtension_example_valid,
        schema_type=SimpleExtension,
        args=[],
    )
    nested = mod.args['test']
    assert nested['a'] == 'hello'
    assert nested['b'] == 1
    assert nested['c'] == 10
    assert len(nested['d']) == 3
def test_simple_extension_write_pass(simple_extension_file):
    """Load via ``--input_json`` and check values plus the logger level.

    No ``--log_level`` is supplied, so the logger is expected at ERROR.
    """
    cli_args = ['--input_json', str(simple_extension_file)]
    mod = ArgSchemaParser(
        input_data=SimpleExtension_example_valid,
        schema_type=SimpleExtension,
        args=cli_args,
    )
    nested = mod.args['test']
    assert nested['a'] == 'hello'
    assert nested['b'] == 1
    assert nested['c'] == 10
    assert len(nested['d']) == 3
    assert mod.logger.getEffectiveLevel() == logging.ERROR
def test_simple_extension_write_debug_level(simple_extension_file):
    """``--log_level DEBUG`` on the command line sets the logger to DEBUG."""
    mod = ArgSchemaParser(
        schema_type=SimpleExtension,
        args=[
            '--input_json', str(simple_extension_file),
            '--log_level', 'DEBUG',
        ],
    )
    assert mod.logger.getEffectiveLevel() == logging.DEBUG
def test_simple_extension_write_overwrite(simple_extension_file):
    """A dotted ``--test.b`` flag overrides the value read from the input JSON."""
    cli = ['--input_json', str(simple_extension_file)]
    cli += ['--test.b', '5']
    mod = ArgSchemaParser(schema_type=SimpleExtension, args=cli)
    assert mod.args['test']['b'] == 5
def test_simple_extension_write_overwrite_list(simple_extension_file):
    """A multi-value ``--test.d`` flag replaces the list from the input JSON."""
    cli = ['--input_json', str(simple_extension_file)]
    cli.extend(['--test.d', '6', '7', '8', '9'])
    mod = ArgSchemaParser(schema_type=SimpleExtension, args=cli)
    # Four values on the command line versus three in the JSON file.
    assert len(mod.args['test']['d']) == 4
def test_bad_input_json_argparse():
    """A nonexistent ``--input_json`` path must raise a ValidationError."""
    missing_path = 'not_a_file.json'
    cli = ['--input_json', missing_path]
    with pytest.raises(mm.ValidationError):
        ArgSchemaParser(schema_type=SimpleExtension, args=cli)
# TESTS DEMONSTRATING BAD BEHAVIOR OF DEFAULT LOADING
class MyExtensionOld(mm.Schema):
    # Same fields as MyExtension, but built on plain marshmallow Schema
    # (not argschema's DefaultSchema), and 'a' is not required.
    a = mm.fields.Str(
        description='a string',
    )
    b = mm.fields.Int(
        description='an integer',
    )
    c = mm.fields.Int(
        description='an integer',
        default=10,
    )
    d = mm.fields.List(
        mm.fields.Int,
        description='a list of integers',
    )
class SimpleExtensionOld(ArgSchema):
    # Counterpart of SimpleExtension that nests the plain-marshmallow
    # MyExtensionOld; the field is required yet also declares default=None.
    test = mm.fields.Nested(
        MyExtensionOld,
        default=None,
        required=True,
    )
def test_simple_extension_old_pass():
    """Valid input loads through the plain-marshmallow nested schema."""
    mod = ArgSchemaParser(
        input_data=SimpleExtension_example_valid,
        schema_type=SimpleExtensionOld,
        args=[],
    )
    nested = mod.args['test']
    assert nested['a'] == 'hello'
    assert nested['b'] == 1
    assert nested['c'] == 10
    assert len(nested['d']) == 3
class RecursiveSchema(argschema.schemas.DefaultSchema):
    # Tree node: a name (default "anonymous") and an optional list of
    # child nodes that use this same schema.
    children = mm.fields.Nested(
        "self",
        many=True,
        description='children of this node',
    )
    name = mm.fields.Str(
        default="anonymous",
        description='name of this node',
    )
class ExampleRecursiveSchema(ArgSchema):
    # Top-level schema holding a required RecursiveSchema tree.
    tree = mm.fields.Nested(
        RecursiveSchema,
        required=True,
    )
# Two-level tree. The two empty child dicts under "branch1" carry no name,
# so the schema default is expected to fill it in.
recursive_data = {
    'tree': {
        'name': 'root',
        'children': [
            {"name": 'child1'},
            {
                "name": "branch1",
                "children": [
                    {"name": "subchild1"},
                    {},
                    {},
                ],
            },
        ],
    },
}
def test_recursive_schema():
    """Self-nested schemas load, and nameless nodes get the default name."""
    mod = ArgSchemaParser(
        input_data=recursive_data,
        schema_type=ExampleRecursiveSchema,
        args=[],
    )
    root = mod.args['tree']
    assert root['name'] == 'root'
    assert len(root['children']) == 2
    assert root['children'][0]['name'] == 'child1'

    branch = root['children'][1]
    assert branch['name'] == 'branch1'
    assert len(branch['children']) == 3
    assert branch['children'][2]['name'] == 'anonymous'
class BadRecursiveSchema(mm.Schema):
    # Same shape as RecursiveSchema, but derived from plain marshmallow
    # Schema instead of argschema's DefaultSchema.
    children = mm.fields.Nested(
        "self",
        many=True,
        description='children of this node',
    )
    name = mm.fields.Str(
        default="anonymous",
        description='name of this node',
    )
class BadExampleRecursiveSchema(ArgSchema):
    # Top-level schema holding a required BadRecursiveSchema tree.
    tree = mm.fields.Nested(
        BadRecursiveSchema,
        required=True,
    )
def bad_test_recursive_schema():
    """Expect loading recursive_data through BadExampleRecursiveSchema to fail.

    NOTE(review): the name does not start with ``test``, so pytest's default
    collection rules will not run this function. Presumably it was disabled
    on purpose; confirm before renaming.
    """
    with pytest.raises(mm.ValidationError):
        ArgSchemaParser(
            input_data=recursive_data,
            schema_type=BadExampleRecursiveSchema,
            args=[],
        )
class ModelFit(argschema.schemas.DefaultSchema):
    # One model-fit record: a type label plus two InputFile paths.
    fit_type = argschema.fields.Str(
        description="",
    )
    hof_fit = argschema.fields.InputFile(
        description="",
    )
    hof = argschema.fields.InputFile(
        description="",
    )
class PopulationSelectionPaths(argschema.schemas.DefaultSchema):
    # A list of ModelFit records.
    fits = argschema.fields.Nested(
        ModelFit,
        description="",
        many=True,
    )
class PopulationSelectionParameters(argschema.ArgSchema):
    # Top-level schema: a single nested PopulationSelectionPaths under 'paths'.
    paths = argschema.fields.Nested(
        PopulationSelectionPaths,
    )
@pytest.fixture
def david_data(tmpdir):
    """Create four empty files and yield args for PopulationSelectionParameters.

    The yielded dict has two 'fits' entries. Each points its ``hof_fit`` and
    ``hof`` InputFile fields at its own file, so all four paths are distinct.
    """
    # Name each file by its index so the four entries refer to separate files
    # rather than all aliasing the same one.
    files = [Path(tmpdir / "file{}.txt".format(i)) for i in range(4)]
    for path in files:
        path.touch()
    dict_args = {
        'paths': {
            'fits': [
                {
                    'fit_type': 'test',
                    'hof_fit': str(files[0]),
                    'hof': str(files[1]),
                },
                {
                    'fit_type': 'test2',
                    'hof_fit': str(files[2]),
                    'hof': str(files[3]),
                },
            ]
        }
    }
    yield dict_args
def test_david_example(tmpdir_factory, david_data):
    """A list of nested InputFile records loads from an ``--input_json`` file."""
    input_file = tmpdir_factory.mktemp('test').join('testinput.json')
    input_file.write(json.dumps(david_data))
    mod = argschema.ArgSchemaParser(
        schema_type=PopulationSelectionParameters,
        args=['--input_json', str(input_file)],
    )
    fits = mod.args['paths']['fits']
    assert len(fits) == 2
class MyShorterExtension(ArgSchema):
    # MyExtension's fields declared directly on an ArgSchema (no nesting);
    # none are required, and 'c' declares a default of 10.
    a = mm.fields.Str(
        description='a string',
    )
    b = mm.fields.Int(
        description='an integer',
    )
    c = mm.fields.Int(
        description='an integer',
        default=10,
    )
    d = mm.fields.List(
        mm.fields.Int,
        description='a list of integers',
    )
def test_simple_description():
    """Flat (non-nested) ArgSchema fields load, and ``c`` gets its default.

    The original version only checked that parsing did not raise; the
    assertions below also pin the parsed values.
    """
    d = {
        'a': "hello",
        'b': 1,
        'd': [1, 5, 4],
    }
    mod = argschema.ArgSchemaParser(
        input_data=d, schema_type=MyShorterExtension, args=[])
    assert mod.args['a'] == 'hello'
    assert mod.args['b'] == 1
    # 'c' was not supplied; expect the declared default (as the nested
    # MyExtension tests also assert).
    assert mod.args['c'] == 10
    assert list(mod.args['d']) == [1, 5, 4]
class MySchemaPostLoad(ArgSchema):
    # ArgSchema with one required integer and a pass-through post_load hook.
    xid = argschema.fields.Int(
        required=True,
    )

    @mm.post_load
    def my_post(self, data, **kwargs):
        """Return the deserialized data unchanged; extra hook kwargs are ignored."""
        return data
class MyPostLoadClass(ArgSchemaParser):
    # Parser bound to MySchemaPostLoad; run() just echoes the parsed args.
    default_schema = MySchemaPostLoad

    def run(self):
        """Print the parsed argument dict to stdout."""
        parsed = self.args
        print(parsed)
def test_post_load_schema():
    """A schema with a post_load hook parses input and keeps the supplied value."""
    example1 = {
        'xid': 1,
    }
    mb = MyPostLoadClass(input_data=example1, args=[])
    # my_post returns its input unchanged, so the parsed value should survive;
    # the original test only checked that construction did not raise.
    assert mb.args['xid'] == 1
    mb.run()