Skip to content

Commit

Permalink
Update to unit test arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
tjkessler committed Jun 30, 2021
1 parent 97ff934 commit ec1afe1
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def test_tune_batch_size(self):
targets = [[5.0]]
ds_eval = QSPRDataset(smiles, targets, backend=_BACKEND)
model = ECNet(_N_DESC, 1, 5, 1)
res = tune_batch_size(1, 1, _N_PROCESSES, model=model, train_ds=ds_train, eval_ds=ds_eval)
res = tune_batch_size(1, 1, ds_train, ds_eval, _N_PROCESSES)
self.assertTrue(1 <= res['batch_size'] <= len(ds_train.target_vals))

def test_tune_model_architecture(self):
Expand All @@ -270,8 +270,7 @@ def test_tune_model_architecture(self):
targets = [[5.0]]
ds_eval = QSPRDataset(smiles, targets, backend=_BACKEND)
model = ECNet(_N_DESC, 1, 5, 1)
res = tune_model_architecture(1, 1, _N_PROCESSES, model=model, train_ds=ds_train,
eval_ds=ds_eval)
res = tune_model_architecture(1, 1, ds_train, ds_eval, _N_PROCESSES)
for k in list(res.keys()):
self.assertTrue(res[k] >= CONFIG['architecture_params_range'][k][0])
self.assertTrue(res[k] <= CONFIG['architecture_params_range'][k][1])
Expand All @@ -285,12 +284,8 @@ def test_tune_training_hps(self):
smiles = ['CCCCC']
targets = [[5.0]]
ds_eval = QSPRDataset(smiles, targets, backend=_BACKEND)
model = ECNet(_N_DESC, 1, 5, 1)
res = tune_training_parameters(1, 1, _N_PROCESSES, model=model, train_ds=ds_train,
eval_ds=ds_eval)
res = tune_training_parameters(1, 1, ds_train, ds_eval, _N_PROCESSES)
for k in list(res.keys()):
if k == 'betas':
continue
self.assertTrue(res[k] >= CONFIG['training_params_range'][k][0])
self.assertTrue(res[k] <= CONFIG['training_params_range'][k][1])

Expand Down

0 comments on commit ec1afe1

Please sign in to comment.