# Page residue captured with this file (not code): "Spaces: Sleeping Sleeping"
import unittest

import numpy as np

from stnn.data.preprocessing import train_test_split
class TestTrainTestSplit(unittest.TestCase):
    """Unit tests for ``train_test_split`` from ``stnn.data.preprocessing``.

    Covers two input forms -- a single NumPy array per side, and a list (or
    tuple) of arrays per side -- plus return types, reproducibility with a
    fixed ``random_state``, and the ``ValueError`` cases for an out-of-range
    ``test_size``, inconsistent lengths, and empty inputs.
    """

    def setUp(self):
        # Four samples. Each X row's first column equals 2 * label - 1
        # ([1, 2] -> 1, [3, 4] -> 2, ...), which lets the tests verify that
        # X rows stay paired with their Y labels after the split shuffles them.
        self.X_array = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
        self.Y_array = np.array([1, 2, 3, 4])
        self.X_list = [self.X_array, self.X_array]
        self.Y_list = [self.Y_array, self.Y_array]
        # Lists that mix a 1-D and a 2-D array; test_inconsistent_length expects
        # train_test_split to reject them. NOTE(review): every array here has 4
        # rows, so the rejection presumably comes from the mixed dimensionality --
        # confirm against train_test_split.
        self.X_list_bad = [self.Y_array, self.X_array]
        self.Y_list_bad = [self.Y_array, self.X_array]

    def _assert_aligned_partition(self, Y, X_train, X_test, Y_train, Y_test):
        """Assert that one split (X, Y) pair partitions Y and keeps X rows paired.

        Relies on the setUp fixture invariant ``X[:, 0] == 2 * Y - 1``.
        """
        # Train and test together must hold every label exactly once.
        np.testing.assert_array_equal(np.sort(np.concatenate([Y_train, Y_test])), np.sort(Y))
        # X rows must have been permuted together with their labels.
        np.testing.assert_array_equal(X_train[:, 0], 2 * Y_train - 1)
        np.testing.assert_array_equal(X_test[:, 0], 2 * Y_test - 1)

    def test_basic_functionality_array(self):
        X_train, X_test, Y_train, Y_test = train_test_split(self.X_array, self.Y_array, test_size = 0.25)
        # test_size = 0.25 of 4 samples -> 1 test sample, 3 train samples.
        self.assertEqual(len(X_train), 3)
        self.assertEqual(len(X_test), 1)
        self.assertEqual(len(Y_train), 3)
        self.assertEqual(len(Y_test), 1)
        self._assert_aligned_partition(self.Y_array, X_train, X_test, Y_train, Y_test)

    def test_basic_functionality_list(self):
        X_train, X_test, Y_train, Y_test = train_test_split(self.X_list, self.Y_list, test_size = 0.25)
        # One output array per input array.
        for part in (X_train, X_test, Y_train, Y_test):
            self.assertEqual(len(part), len(self.X_list))
        # Check every array in the lists, not only the first one.
        for k, (x_tr, x_te, y_tr, y_te) in enumerate(zip(X_train, X_test, Y_train, Y_test)):
            with self.subTest(index = k):
                self.assertEqual(len(x_tr), 3)
                self.assertEqual(len(x_te), 1)
                self.assertEqual(len(y_tr), 3)
                self.assertEqual(len(y_te), 1)
                self._assert_aligned_partition(self.Y_list[k], x_tr, x_te, y_tr, y_te)

    def test_return_type_consistency_array(self):
        X_train, X_test, Y_train, Y_test = train_test_split(self.X_array, self.Y_array, test_size = 0.25)
        self.assertIsInstance(X_train, np.ndarray)
        self.assertIsInstance(X_test, np.ndarray)
        self.assertIsInstance(Y_train, np.ndarray)
        self.assertIsInstance(Y_test, np.ndarray)
        # Wrapping a single array in a list must yield list outputs, not arrays.
        X_train, X_test, Y_train, Y_test = train_test_split([self.X_array], [self.Y_array], test_size = 0.25)
        self.assertIsInstance(X_train, list)
        self.assertIsInstance(X_test, list)
        self.assertIsInstance(Y_train, list)
        self.assertIsInstance(Y_test, list)

    def test_return_type_consistency_list(self):
        X_train, X_test, Y_train, Y_test = train_test_split(self.X_list, self.Y_list, test_size = 0.25)
        self.assertIsInstance(X_train, list)
        self.assertIsInstance(X_test, list)
        self.assertIsInstance(Y_train, list)
        self.assertIsInstance(Y_test, list)
        # Tuple inputs are also expected to come back as lists.
        # noinspection PyTypeChecker
        X_train, X_test, Y_train, Y_test = train_test_split(tuple(self.X_list), tuple(self.Y_list), test_size = 0.25)
        self.assertIsInstance(X_train, list)
        self.assertIsInstance(X_test, list)
        self.assertIsInstance(Y_train, list)
        self.assertIsInstance(Y_test, list)

    def test_random_state(self):
        # The same random_state must reproduce the same split.
        X_train1, X_test1, Y_train1, Y_test1 = train_test_split(self.X_array, self.Y_array, test_size = 0.25,
                                                                random_state = 42)
        X_train2, X_test2, Y_train2, Y_test2 = train_test_split(self.X_array, self.Y_array, test_size = 0.25,
                                                                random_state = 42)
        np.testing.assert_array_equal(X_train1, X_train2)
        np.testing.assert_array_equal(X_test1, X_test2)
        np.testing.assert_array_equal(Y_train1, Y_train2)
        np.testing.assert_array_equal(Y_test1, Y_test2)

    def test_invalid_test_size(self):
        # test_size outside the valid fraction range must be rejected.
        with self.assertRaises(ValueError):
            train_test_split(self.X_array, self.Y_array, test_size = -0.1)
        with self.assertRaises(ValueError):
            train_test_split(self.X_array, self.Y_array, test_size = 1.5)

    def test_inconsistent_length(self):
        # X and Y with different numbers of samples.
        X = np.array([[1, 2], [3, 4]])
        Y = np.array([1, 2, 3])
        with self.assertRaises(ValueError):
            train_test_split(X, Y)
        # Lists built from the "bad" fixtures, on either or both sides.
        with self.assertRaises(ValueError):
            train_test_split(self.X_list_bad, self.Y_list_bad)
        with self.assertRaises(ValueError):
            train_test_split(self.X_list_bad, self.Y_list)
        with self.assertRaises(ValueError):
            train_test_split(self.X_list, self.Y_list_bad)

    def test_empty(self):
        # Empty arrays, empty lists, and lists of empty arrays are all rejected.
        X_empty = np.zeros(0)
        Y_empty = np.zeros(0)
        with self.assertRaises(ValueError):
            train_test_split(X_empty, Y_empty)
        with self.assertRaises(ValueError):
            train_test_split([X_empty], [])
        with self.assertRaises(ValueError):
            train_test_split([], [Y_empty])
        with self.assertRaises(ValueError):
            train_test_split([X_empty], [Y_empty])


if __name__ == '__main__':
    # Run every TestCase in this module when it is executed as a script.
    unittest.main()