# Source code for fields.test_numpyarray

import pytest
from argschema import ArgSchemaParser, ArgSchema
from argschema.fields import NumpyArray
from argschema.utils import dump
import marshmallow as mm
import numpy as np

# Well-formed input fixture: key 'a' maps to a 2x2 nested list of integers.
numpy_array_test = {'a': [[1, 2], [3, 4]]}


class NumpyFileuint16(ArgSchema):
    """Schema exposing one required field, ``a``, declared as a uint16 NumpyArray."""

    # Required field; the NumpyArray field is configured with dtype 'uint16'.
    a = NumpyArray(
        required=True,
        dtype='uint16',
        description='list of lists representing a uint16 numpy array',
    )
def test_numpy():
    """Parse the module-level 2x2 fixture and check the shape and dtype of ``a``."""
    parser = ArgSchemaParser(
        schema_type=NumpyFileuint16,
        input_data=numpy_array_test,
        args=[],
    )
    parsed = parser.args['a']
    assert parsed.shape == (2, 2)
    assert parsed.dtype == 'uint16'
def test_bad_shape():
    """Expect a marshmallow ValidationError when the rows of ``a`` differ in length."""
    ragged_input = {'a': [[1, 2], [3]]}
    with pytest.raises(mm.ValidationError):
        # Only the raised exception matters, so the parser is not bound to a
        # name (this removes the unused local and its F841 suppression).
        ArgSchemaParser(
            input_data=ragged_input,
            schema_type=NumpyFileuint16,
            args=[])
def test_bad_data():
    """Expect a marshmallow ValidationError when ``a`` holds non-numeric strings."""
    # The shape here is fine; the element values are what is invalid.
    non_numeric_input = {'a': [['a', 'b']]}
    with pytest.raises(mm.ValidationError):
        # Only the raised exception matters, so the parser is not bound to a
        # name (this removes the unused local and its F841 suppression).
        ArgSchemaParser(
            input_data=non_numeric_input,
            schema_type=NumpyFileuint16,
            args=[])
def test_serialize():
    """Dump a numpy array through the schema and check it comes back as a plain list."""
    schema = NumpyFileuint16()
    object_dict = {'a': np.array([1, 2])}
    json_dict = dump(schema, object_dict)
    # isinstance is the idiomatic type check (type(x) == list trips E721).
    assert isinstance(json_dict['a'], list)
    assert json_dict['a'] == object_dict['a'].tolist()