Spaces:
Sleeping
Sleeping
import os
import tempfile
import unittest

import h5py
import numpy as np

from stnn.data.preprocessing import get_data_from_file, load_data, load_training_data
| class TestGetDataFromFile(unittest.TestCase): | |
| def setUp(self): | |
| self.temp_file = tempfile.NamedTemporaryFile(delete = False) | |
| self.nx1, self.nx2, self.nx3 = 30, 20, 16 | |
| self.Nsamples = 10 | |
| with h5py.File(self.temp_file.name, 'w') as f: | |
| f.create_dataset('ell', data = np.random.rand(self.Nsamples)) | |
| f.create_dataset('a1', data = np.random.rand(self.Nsamples)) | |
| f.create_dataset('a2', data = np.random.rand(self.Nsamples)) | |
| f.create_dataset('rho', data = np.random.rand(self.Nsamples, self.nx1, self.nx2)) | |
| f.create_dataset('ibf', data = np.random.rand(self.Nsamples, self.nx2, self.nx3 // 2)) | |
| f.create_dataset('obf', data = np.random.rand(self.Nsamples, self.nx2, self.nx3 // 2)) | |
| self.temp_file1 = tempfile.NamedTemporaryFile(delete = False) | |
| with h5py.File(self.temp_file1.name, 'w') as f: | |
| f.create_dataset('ell', data = np.random.rand(self.Nsamples)) | |
| f.create_dataset('a1', data = np.random.rand(self.Nsamples)) | |
| f.create_dataset('a2', data = np.random.rand(self.Nsamples)) | |
| f.create_dataset('rho', data = np.random.rand(self.Nsamples, self.nx1, self.nx2)) | |
| f.create_dataset('ibf', data = np.random.rand(self.Nsamples, self.nx2, self.nx3 // 2)) | |
| f.create_dataset('obf', data = np.random.rand(self.Nsamples, self.nx2, self.nx3 // 2)) | |
| self.bad_file = tempfile.NamedTemporaryFile(delete = False) | |
| with h5py.File(self.bad_file.name, 'w') as f: | |
| f.create_dataset('ell', data = np.random.rand(self.Nsamples)) | |
| f.create_dataset('a1', data = np.random.rand(self.Nsamples)) | |
| f.create_dataset('a2', data = np.random.rand(self.Nsamples)) | |
| f.create_dataset('rho', data = np.random.rand(self.Nsamples, self.nx1, self.nx2)) | |
| def tearDown(self): | |
| self.temp_file.close() | |
| self.temp_file1.close() | |
| self.bad_file.close() | |
| def test_missing_datasets(self): | |
| with self.assertRaises(ValueError): | |
| get_data_from_file(self.bad_file.name, self.nx2, self.nx2) | |
| def test_data_extraction_shapes(self): | |
| result = get_data_from_file(self.temp_file.name, self.nx2, self.nx3) | |
| self.assertEqual(result[0].shape, (self.Nsamples,)) | |
| self.assertEqual(result[1].shape, (self.Nsamples,)) | |
| self.assertEqual(result[2].shape, (self.Nsamples,)) | |
| self.assertEqual(result[3].shape, (self.Nsamples, 2 * self.nx2, self.nx3 // 2)) | |
| self.assertEqual(result[4].shape, (self.Nsamples, self.nx1, self.nx2)) | |
| def test_nrange_parameter(self): | |
| Nrange = (2, 5) | |
| result = get_data_from_file(self.temp_file.name, self.nx2, self.nx3, Nrange = Nrange) | |
| expected_size = Nrange[1] - Nrange[0] | |
| self.assertEqual(result[0].shape, (expected_size,)) | |
| self.assertEqual(result[0].shape, (expected_size,)) | |
| def test_list_input(self): | |
| file_list = [self.temp_file.name, self.temp_file1.name] | |
| Nrange_list = [(0, -1), (0, -1)] | |
| with self.assertRaises(TypeError): | |
| # noinspection PyTypeChecker | |
| _ = get_data_from_file(file_list, self.nx2, self.nx3, Nrange = Nrange_list) | |
| def test_invalid_Nrange(self): | |
| Nrange_list = [(0, -1), (0, -1)] | |
| with self.assertRaises(TypeError): | |
| _ = get_data_from_file(self.temp_file.name, self.nx2, self.nx3, Nrange = Nrange_list) | |
| for Nrange in [(0, 1, 1), 1, (1), (1.5, 3), (3, 1.5), (1.5, 1.5), 'x']: | |
| with self.assertRaises(TypeError): | |
| _ = get_data_from_file(self.temp_file.name, self.nx2, self.nx3, Nrange = Nrange) | |
| with self.assertRaises(TypeError): | |
| _ = get_data_from_file(self.temp_file.name, self.nx2, self.nx3, Nrange = list(Nrange)) | |
| def test_good_data_load(self): | |
| files = [self.temp_file.name, self.temp_file1.name] | |
| Nrange_list = [(0, None), (0, self.Nsamples)] | |
| ell1, ell2, a1, a2 = 0.1, 2.0, 1.0, 5.0 | |
| params, bf, rho = load_data(files, self.nx2, self.nx3, ell1, ell2, a1, a2, Nrange_list = Nrange_list) | |
| self.assertEqual(params.shape, (2 * self.Nsamples, 3)) | |
| self.assertEqual(bf.shape, (2 * self.Nsamples, 2 * self.nx2, self.nx3 // 2)) | |
| self.assertEqual(rho.shape, (2 * self.Nsamples, self.nx1, self.nx2)) | |
| test_size = 0.3 | |
| (params_train, bf_train, rho_train, | |
| params_test, bf_test, rho_test) = load_training_data(files, self.nx2, self.nx3, | |
| ell1, ell2, a1, a2, test_size = test_size, | |
| Nrange_list = Nrange_list) | |
| Ntest = int(test_size * 2 * self.Nsamples) | |
| Ntrain = 2 * self.Nsamples - Ntest | |
| self.assertEqual(params_train.shape, (Ntrain, 3)) | |
| self.assertEqual(bf_train.shape, (Ntrain, 2 * self.nx2, self.nx3 // 2)) | |
| self.assertEqual(rho_train.shape, (Ntrain, self.nx1, self.nx2)) | |
| self.assertEqual(params_test.shape, (Ntest, 3)) | |
| self.assertEqual(bf_test.shape, (Ntest, 2 * self.nx2, self.nx3 // 2)) | |
| self.assertEqual(rho_test.shape, (Ntest, self.nx1, self.nx2)) | |
| def test_bad_data_load(self): | |
| files = [self.temp_file.name, self.temp_file1.name] | |
| Nrange_list = (0, -1) | |
| ell1, ell2, a1, a2 = 0.1, 2.0, 1.0, 5.0 | |
| with self.assertRaises(TypeError): | |
| _ = load_data(files, self.nx2, self.nx3, ell1, ell2, a1, a2, Nrange_list = Nrange_list) | |
| with self.assertRaises(TypeError): | |
| _ = load_data(files, self.nx2, self.nx3, ell1, ell2, a1, a2, Nrange_list = list(Nrange_list)) | |
| Nrange_list = [(0, -1), (0, -1)] | |
| for test_size in [-1, 0.0, 1.5]: | |
| with self.assertRaises(ValueError): | |
| _ = load_training_data(files, self.nx2, self.nx3, | |
| ell1, ell2, a1, a2, test_size = test_size, Nrange_list = Nrange_list) | |
if __name__ == '__main__':
    # Entry point when the module is executed directly rather than discovered.
    unittest.main()