diff --git a/tests/server/test_server.py b/tests/server/test_server.py index dfd7ee6..f36124b 100644 --- a/tests/server/test_server.py +++ b/tests/server/test_server.py @@ -49,6 +49,7 @@ def test_train_project(self): sv = Server() sv.load_data('cn_model_v1.0.csv', random=True, split=[0.7, 0.2, 0.1]) sv.create_project('test_project', 2, 2) + sv._vars['epochs'] = 100 sv.train() for pool in range(2): self.assertTrue(exists(join( @@ -72,6 +73,7 @@ def test_use_project(self): sv = Server() sv.load_data('cn_model_v1.0.csv', random=True, split=[0.7, 0.2, 0.1]) sv.create_project('test_project', 2, 2) + sv._vars['epochs'] = 100 sv.train() results = sv.use() self.assertEqual(len(results), len(sv._df)) @@ -84,6 +86,7 @@ def test_save_project(self): sv = Server() sv.load_data('cn_model_v1.0.csv', random=True, split=[0.7, 0.2, 0.1]) sv.create_project('test_project', 2, 2) + sv._vars['epochs'] = 100 sv.train() sv.save_project() self.assertTrue(exists('test_project.prj')) @@ -96,6 +99,7 @@ def test_multiprocessing_train(self): sv = Server(num_processes=8) sv.load_data('cn_model_v1.0.csv') sv.create_project('test_project', 2, 4) + sv._vars['epochs'] = 100 sv.train() for pool in range(2): self.assertTrue(exists(join( diff --git a/tests/tools/test_project.py b/tests/tools/test_project.py index 8037111..f6a0dbb 100644 --- a/tests/tools/test_project.py +++ b/tests/tools/test_project.py @@ -13,6 +13,7 @@ def test_predict(self): sv = Server() sv.load_data('cn_model_v2.0.csv') sv.create_project('test_project', 1, 1) + sv._vars['epochs'] = 100 sv.train() sv.save_project() diff --git a/tests/utils/test_server_utils.py b/tests/utils/test_server_utils.py index a4776de..8cafe3b 100644 --- a/tests/utils/test_server_utils.py +++ b/tests/utils/test_server_utils.py @@ -146,12 +146,11 @@ def test_train_model(self): df.create_sets(random=True) pd = df.package_sets() config = server_utils.default_config() + config['epochs'] = 100 r_squared = server_utils.train_model( pd, config, 'test', 'r2', filename='test_train.h5' ) self.assertTrue(exists('test_train.h5')) - self.assertGreaterEqual(r_squared, 0) - self.assertLessEqual(r_squared, 1) remove('test_train.h5') def test_use_model(self): @@ -161,6 +160,7 @@ def test_use_model(self): df.create_sets(random=True) pd = df.package_sets() config = server_utils.default_config() + config['epochs'] = 100 _ = server_utils.train_model( pd, config, 'test', 'rmse', filename='test_use.h5' )