From ec1afe1cbad4749a7965ca277033749bb492b015 Mon Sep 17 00:00:00 2001 From: tjkessler Date: Wed, 30 Jun 2021 13:31:42 -0400 Subject: [PATCH] Update to unit test arguments --- tests/test_all.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/test_all.py b/tests/test_all.py index d138d0a..99d18cb 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -257,7 +257,7 @@ def test_tune_batch_size(self): targets = [[5.0]] ds_eval = QSPRDataset(smiles, targets, backend=_BACKEND) model = ECNet(_N_DESC, 1, 5, 1) - res = tune_batch_size(1, 1, _N_PROCESSES, model=model, train_ds=ds_train, eval_ds=ds_eval) + res = tune_batch_size(1, 1, ds_train, ds_eval, _N_PROCESSES) self.assertTrue(1 <= res['batch_size'] <= len(ds_train.target_vals)) def test_tune_model_architecture(self): @@ -270,8 +270,7 @@ def test_tune_model_architecture(self): targets = [[5.0]] ds_eval = QSPRDataset(smiles, targets, backend=_BACKEND) model = ECNet(_N_DESC, 1, 5, 1) - res = tune_model_architecture(1, 1, _N_PROCESSES, model=model, train_ds=ds_train, - eval_ds=ds_eval) + res = tune_model_architecture(1, 1, ds_train, ds_eval, _N_PROCESSES) for k in list(res.keys()): self.assertTrue(res[k] >= CONFIG['architecture_params_range'][k][0]) self.assertTrue(res[k] <= CONFIG['architecture_params_range'][k][1]) @@ -285,12 +284,8 @@ def test_tune_training_hps(self): smiles = ['CCCCC'] targets = [[5.0]] ds_eval = QSPRDataset(smiles, targets, backend=_BACKEND) - model = ECNet(_N_DESC, 1, 5, 1) - res = tune_training_parameters(1, 1, _N_PROCESSES, model=model, train_ds=ds_train, - eval_ds=ds_eval) + res = tune_training_parameters(1, 1, ds_train, ds_eval, _N_PROCESSES) for k in list(res.keys()): - if k == 'betas': - continue self.assertTrue(res[k] >= CONFIG['training_params_range'][k][0]) self.assertTrue(res[k] <= CONFIG['training_params_range'][k][1])