
Change DDPPlugin to DDPStrategy (#244)
* use ddp strategy

* silence sync_dist and ckpt warnings

* remove sync_dist due to overhead
ejm714 authored Nov 1, 2022
1 parent bb89011 commit 94a2297
Showing 2 changed files with 12 additions and 5 deletions.
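For context: PyTorch Lightning 1.6 renamed DDPPlugin to DDPStrategy and moved it from pytorch_lightning.plugins to pytorch_lightning.strategies; the object is passed to the Trainer via the strategy argument. A minimal sketch of the post-rename call this commit adopts — the accelerator, devices, and max_epochs values are illustrative placeholders, not zamba's configuration:

# Sketch of the post-rename API. Trainer arguments other than `strategy`
# are illustrative assumptions, not taken from this commit.
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy  # formerly pytorch_lightning.plugins.DDPPlugin

trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    max_epochs=10,
    # find_unused_parameters=False skips DDP's per-step search for parameters
    # that received no gradient, which is safe and faster when every parameter
    # participates in each forward pass.
    strategy=DDPStrategy(find_unused_parameters=False),
)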
12 changes: 8 additions & 4 deletions zamba/models/model_manager.py
@@ -12,7 +12,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
-from pytorch_lightning.plugins import DDPPlugin
+from pytorch_lightning.strategies import DDPStrategy

 from zamba.data.video import VideoLoaderConfig
 from zamba.models.config import (
@@ -281,7 +281,7 @@ def train_model(
         logger=tensorboard_logger,
         callbacks=callbacks,
         fast_dev_run=train_config.dry_run,
-        strategy=DDPPlugin(find_unused_parameters=False)
+        strategy=DDPStrategy(find_unused_parameters=False)
         if data_module.multiprocessing_context is not None
         else None,
     )
@@ -323,13 +323,17 @@ def train_model(
     if not train_config.dry_run:
         if trainer.datamodule.test_dataloader() is not None:
             logger.info("Calculating metrics on holdout set.")
-            test_metrics = trainer.test(dataloaders=trainer.datamodule.test_dataloader())[0]
+            test_metrics = trainer.test(
+                dataloaders=trainer.datamodule.test_dataloader(), ckpt_path="best"
+            )[0]
             with (Path(logging_and_save_dir) / "test_metrics.json").open("w") as fp:
                 json.dump(test_metrics, fp, indent=2)

         if trainer.datamodule.val_dataloader() is not None:
             logger.info("Calculating metrics on validation set.")
-            val_metrics = trainer.validate(dataloaders=trainer.datamodule.val_dataloader())[0]
+            val_metrics = trainer.validate(
+                dataloaders=trainer.datamodule.val_dataloader(), ckpt_path="best"
+            )[0]
             with (Path(logging_and_save_dir) / "val_metrics.json").open("w") as fp:
                 json.dump(val_metrics, fp, indent=2)

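The ckpt_path="best" arguments added above make trainer.test and trainer.validate restore the best checkpoint tracked by the ModelCheckpoint callback before computing metrics, rather than evaluating whatever weights are in memory after the final epoch; per the commit message, this also silences the warning Lightning emits when ckpt_path is left unspecified. A sketch of the pattern — the monitor key and the model/data_module objects are illustrative placeholders:

# Sketch only: `model` and `data_module` stand in for a LightningModule and
# LightningDataModule; the monitor key is an assumed example.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint = ModelCheckpoint(monitor="val_loss", mode="min")
trainer = pl.Trainer(callbacks=[checkpoint], max_epochs=10)
trainer.fit(model, datamodule=data_module)

# ckpt_path="best" reloads checkpoint.best_model_path before evaluating, so
# reported metrics come from the best-scoring epoch rather than the last one.
test_metrics = trainer.test(datamodule=data_module, ckpt_path="best")[0]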
5 changes: 4 additions & 1 deletion zamba/pytorch_lightning/utils.py
@@ -200,7 +200,10 @@ def aggregate_step_outputs(
     def compute_and_log_metrics(
         self, y_true: np.ndarray, y_pred: np.ndarray, y_proba: np.ndarray, subset: str
     ):
-        self.log(f"{subset}_macro_f1", f1_score(y_true, y_pred, average="macro", zero_division=0))
+        self.log(
+            f"{subset}_macro_f1",
+            f1_score(y_true, y_pred, average="macro", zero_division=0),
+        )

         # if only two classes, skip top_k accuracy since not enough classes
         if self.num_classes > 2:
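On the "remove sync_dist due to overhead" note in the commit message: under DDP, self.log(..., sync_dist=True) all-reduces each logged value across ranks so every process records the synchronized mean, at the cost of one collective communication per metric. The final diff leaves the flag off, so each rank logs its locally computed value. A sketch of where the flag would sit, with a hypothetical module and placeholder metric:

import pytorch_lightning as pl
import torch

class SketchModule(pl.LightningModule):  # hypothetical, for illustration only
    def validation_step(self, batch, batch_idx):
        val_loss = torch.tensor(0.0)  # placeholder metric computation
        # sync_dist=True would all-reduce this value across DDP ranks (an
        # extra collective per logged metric); the commit avoids that cost
        # and logs rank-local values instead.
        self.log("val_loss", val_loss, sync_dist=False)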

