from argschema import ArgSchemaParser
from argschema.schemas import DefaultSchema
from argschema.fields import Str, Int, NumpyArray
import json
import numpy as np
import pytest
import marshmallow as mm
import os
class MyOutputSchema(DefaultSchema):
    """Output schema shared by the output tests in this module.

    Declares three fields: a required string ``a``, an integer ``b``
    declared with ``default=5``, and a required numpy array ``M``.
    """

    # Required string field.
    a = Str(description="a simple string", required=True)
    # Integer field; the tests below omit it and expect 5 in the result.
    b = Int(description="a default integer", default=5)
    # Required array field; the tests pass it in as a numpy array.
    M = NumpyArray(description="a numpy array of answers", required=True)
def test_output(tmpdir):
    """Write a schema-checked output and verify the JSON found on disk.

    The payload omits ``b`` and supplies ``M`` as a numpy array; the file
    read back is expected to hold ``b == 5`` and ``M`` as nested lists.

    Parameters
    ----------
    tmpdir
        pytest's per-test temporary directory fixture.
    """
    json_path = str(tmpdir.join('test_output.json'))
    parser = ArgSchemaParser(
        input_data={'output_json': json_path},
        output_schema_type=MyOutputSchema,
        args=[],
    )

    matrix = [[5, 5], [7, 2]]
    parser.output({"a": "example", "M": np.array(matrix)})

    with open(json_path, 'r') as handle:
        written = json.load(handle)

    assert written == {"a": "example", "b": 5, "M": matrix}
def test_output_unvalidated(tmpdir):
    """Write output with no output schema and expect it back unchanged.

    The parser is built without ``output_schema_type``; the dictionary
    handed to ``output`` must equal what is loaded from the JSON file.

    Parameters
    ----------
    tmpdir
        pytest's per-test temporary directory fixture.
    """
    json_path = str(tmpdir.join('test_output_unvalidated.json'))
    parser = ArgSchemaParser(input_data={'output_json': json_path}, args=[])

    payload = {"a": "example"}
    parser.output(payload)

    with open(json_path, 'r') as handle:
        written = json.load(handle)

    assert written == payload
def test_bad_output(tmpdir):
    """Expect a marshmallow ``ValidationError`` for a schema-violating output.

    ``b`` is declared as an integer on ``MyOutputSchema`` but is given the
    string "not a number" here.

    Parameters
    ----------
    tmpdir
        pytest's per-test temporary directory fixture.
    """
    json_path = str(tmpdir.join('test_output_bad.json'))
    parser = ArgSchemaParser(
        input_data={'output_json': json_path},
        output_schema_type=MyOutputSchema,
        args=[],
    )

    invalid_payload = {
        "a": "example",
        "b": "not a number",
        "M": np.array([[5, 5], [7, 2]]),
    }

    # Only the output call sits inside the raises-context.
    with pytest.raises(mm.ValidationError):
        parser.output(invalid_payload)
def test_alt_output(tmpdir):
    """Write output to an explicitly supplied path and verify its contents.

    The parser is configured with one ``output_json`` path, but ``output``
    is called with a second path; the second file is the one read back.

    Parameters
    ----------
    tmpdir
        pytest's per-test temporary directory fixture.
    """
    configured_path = str(tmpdir.join('test_output.json'))
    alternate_path = str(tmpdir.join('test_output_2.json'))

    parser = ArgSchemaParser(
        input_data={'output_json': configured_path},
        output_schema_type=MyOutputSchema,
        args=[],
    )

    matrix = [[5, 5], [7, 2]]
    parser.output({"a": "example", "M": np.array(matrix)}, alternate_path)

    with open(alternate_path, 'r') as handle:
        written = json.load(handle)

    assert written == {"a": "example", "b": 5, "M": matrix}
def test_tmp_output_cleanput(tmpdir):
    """Check that constructing the parser leaves the output directory empty.

    The parser is only constructed, never asked to write output; afterwards
    the temporary directory must contain no files.

    Parameters
    ----------
    tmpdir
        pytest's per-test temporary directory fixture.
    """
    # NOTE(review): presumably guards against scratch files left over from
    # validating the output path -- confirm against argschema.
    json_path = str(tmpdir.join('test_output.json'))
    ArgSchemaParser(
        input_data={'output_json': json_path},
        output_schema_type=MyOutputSchema,
        args=[],
    )

    leftovers = os.listdir(str(tmpdir))
    assert not leftovers