Skip to content

Commit

Permalink
commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Oct 10, 2022
1 parent bb80cfa commit c8f977a
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 12 deletions.
2 changes: 1 addition & 1 deletion examples/cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def collect_res(seed: int) -> Dict[str, float]:
lmodel = MyLModule(model, optimizer, loss_fn, lr_s, hparams)
trainer = ml.Trainer(lmodel, device_ids, runs_dir=RUNS_DIR, **hparams["trainer_hparams"])
res = trainer.fit(ldm.train_dataloader, ldm.val_dataloader)
res2 = trainer.test(ldm.test_dataloader)
res2 = trainer.test(ldm.test_dataloader, True, True)
res.update(res2)
return res
res = ml.multi_runs(collect_res, 3, seed=42)
Expand Down
2 changes: 1 addition & 1 deletion examples/cv_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def collect_res(seed: int) -> Dict[str, float]:
lmodel = MyLModule(model, optimizer, loss_fn, lr_s, hparams)
trainer = ml.Trainer(lmodel, device_ids, runs_dir=RUNS_DIR, **hparams["trainer_hparams"])
res = trainer.fit(ldm.train_dataloader, ldm.val_dataloader)
res2 = trainer.test(ldm.test_dataloader)
res2 = trainer.test(ldm.test_dataloader, True, True)
res.update(res2)
return res
res = ml.multi_runs(collect_res, 3, seed=42)
Expand Down
2 changes: 1 addition & 1 deletion examples/cv_ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def collect_res(seed: int) -> Dict[str, float]:
lmodel = MyLModule(model, optimizer, loss_fn, lr_s, hparams)
trainer = ml.Trainer(lmodel, device_ids, runs_dir=RUNS_DIR, **hparams["trainer_hparams"])
res = trainer.fit(ldm.train_dataloader, ldm.val_dataloader)
res2 = trainer.test(ldm.test_dataloader)
res2 = trainer.test(ldm.test_dataloader, True, True)
res.update(res2)
return res
res = ml.multi_runs(collect_res, 3, seed=42)
Expand Down
2 changes: 1 addition & 1 deletion examples/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,4 @@ def tokenize_function(example):
logger.info("KeyboardInterrupt Detected...")
raise
finally:
logger.info(trainer.test(ldm.test_dataloader))
logger.info(trainer.test(ldm.test_dataloader, True, True))
6 changes: 3 additions & 3 deletions examples/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def training_epoch_end(self) -> Dict[str, float]:
lmodel = MyLModule(model, optimizer, loss_fn, metrics, "acc")
ldm = ml.LDataModule(train_dataset, val_dataset, test_dataset, 64)
trainer = ml.Trainer(lmodel, [], 40, RUNS_DIR, gradient_clip_norm=10, val_every_n_epoch=10, verbose=True)
logger.info(trainer.test(ldm.val_dataloader, False, True))
logger.info(trainer.test(ldm.val_dataloader, True, True))
logger.info(trainer.fit(ldm.train_dataloader, ldm.val_dataloader))
logger.info(trainer.test(ldm.test_dataloader, True, True))

Expand All @@ -113,7 +113,7 @@ def training_epoch_end(self) -> Dict[str, float]:
ldm = ml.LDataModule(train_dataset, val_dataset, test_dataset, 64)
trainer = ml.Trainer(lmodel, [0], 100, RUNS_DIR, gradient_clip_norm=10,
val_every_n_epoch=10, verbose=True, resume_from_ckpt=ckpt_path)
logger.info(trainer.test(ldm.val_dataloader, False, True))
logger.info(trainer.test(ldm.val_dataloader, True, True))
logger.info(trainer.fit(ldm.train_dataloader, ldm.val_dataloader))
logger.info(trainer.test(ldm.test_dataloader, True, True))

Expand All @@ -124,4 +124,4 @@ def training_epoch_end(self) -> Dict[str, float]:
lmodel = MyLModule(None, None, loss_fn, metrics, "loss")
ldm = ml.LDataModule(train_dataset, val_dataset, test_dataset, 64)
trainer = ml.Trainer(lmodel, [], None, RUNS_DIR, resume_from_ckpt=ckpt_path)
logger.info(trainer.test(ldm.test_dataloader, False, True))
logger.info(trainer.test(ldm.test_dataloader, True, True))
31 changes: 26 additions & 5 deletions mini_lightning/mini_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,15 +846,36 @@ def fit(self, train_dataloader: DataLoader, val_dataloader: Optional[DataLoader]
cuda.empty_cache()
return best_mes if self.rank in {-1, 0} else {} # core_metrics is best

def test(self, dataloader: Optional[DataLoader], test_best: bool = True, test_last: bool = False) -> Dict[str, float]:
def _best_ckpt_is_last(self) -> bool:
if self.best_ckpt_path is None or self.last_ckpt_path is None:
return False

best_ckpt_fname = os.path.basename(self.best_ckpt_path)
m = re.match(r"best-epoch=(\d+)", best_ckpt_fname)
assert m is not None
best_epoch_idx = m.group(1)
last_ckpt_fname = os.path.basename(self.last_ckpt_path)
m = re.match(r"last-epoch=(\d+)", last_ckpt_fname)
assert m is not None
last_epoch_idx = m.group(1)
return best_epoch_idx == last_epoch_idx

def test(self, dataloader: Optional[DataLoader], test_best: bool = False, test_last: bool = True) -> Dict[str, float]:
    """Evaluate `dataloader` with the best and/or the last (current) model state.

    Parameters:
        dataloader: the dataloader to evaluate; may be None (passed through to `_test`).
        test_best: if True, load and evaluate the best checkpoint (skipped with a
            warning when `self.best_ckpt_path` is not set).
        test_last: if True, evaluate the last/current model. Skipped when the best
            checkpoint was just evaluated and is the same epoch as the last one,
            to avoid a redundant duplicate run.
    Returns:
        The merged metrics dict on rank -1/0; an empty dict on other ranks.
    """
    res_mes: Dict[str, float] = {}
    if test_best:
        # If "last" ran first, "best" would override its tensorboard entries. So best first.
        if self.best_ckpt_path is None:
            logger.warning("Ignore test best: self.best_ckpt_path is None")
            test_best = False  # best was not tested, so do not skip "last" below
        else:
            m = self._test(dataloader, "best")
            res_mes.update(m)
    #
    if test_last:  # just current model
        # Truthiness instead of `test_best is True`: identity comparison on a bool is fragile.
        if test_best and self._best_ckpt_is_last():
            logger.info("Ignore test last: the best ckpt is the last ckpt")
        else:
            m = self._test(dataloader, "last")
            res_mes.update(m)
    cuda.empty_cache()
    return res_mes if self.rank in {-1, 0} else {}

0 comments on commit c8f977a

Please sign in to comment.