Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Oct 9, 2024
1 parent ee94593 commit 4cc8a78
Showing 4 changed files with 25 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/anemoi/training/config/dataloader/native_grid.yaml
@@ -3,7 +3,7 @@ prefetch_factor: 2
# ============
# read_frequency:
# Only every read_frequency-th GPU of each model comm group reads data
# to reduce CPU memory usage.
# to reduce CPU memory usage.
# The number of GPUs per model must be divisible by read_frequency.
# Default: 1 (all GPUs read data)
# ============
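To make the divisibility rule in the comment above concrete, here is a rough sketch of the rank arithmetic it implies; the variable names and numbers are invented for illustration and do not come from the commit.

# Hypothetical illustration: only every read_frequency-th rank of a model
# comm group reads data, so the group size must divide evenly.
num_gpus_per_model = 4
read_frequency = 2

assert num_gpus_per_model % read_frequency == 0, "GPUs per model must be divisible by read_frequency"

for model_comm_group_rank in range(num_gpus_per_model):
    reads_data = model_comm_group_rank % read_frequency == 0
    print(model_comm_group_rank, "reads data" if reads_data else "receives it via broadcast")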
8 changes: 4 additions & 4 deletions src/anemoi/training/data/dataset.py
@@ -207,13 +207,13 @@ def __iter__(self) -> torch.Tensor:
Currently it receives data with an ensemble dimension, which is discarded for
now. (Until the code is "ensemble native".)
"""
if self.reader_group_rank != 0:
if self.reader_group_rank != 0:
# yield dummy data only with shape information for non-root ranks (shape used for broadcast)
shape = (self.rollout + self.multi_step, self.data.shape[2], self.data.shape[3], self.data.shape[1])
for _ in self.chunk_index_range:
for _ in self.chunk_index_range:
yield torch.tensor(shape, dtype=torch.long)
return
return

if self.shuffle:
shuffled_chunk_indices = self.rng.choice(
self.chunk_index_range,
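In the hunk above, non-root reader ranks yield only a shape tensor that the real batch is later broadcast against. A simplified, single-process sketch of that idea follows; all sizes are made up and this is not the project's actual dataloader logic.

import torch

# Invented sizes for illustration only.
rollout, multi_step = 1, 2
ensemble, grid, variables = 1, 16, 8

# A non-root reader rank yields just the target shape as a tensor ...
shape = (rollout + multi_step, ensemble, grid, variables)
dummy = torch.tensor(shape, dtype=torch.long)

# ... which can later be turned back into an empty buffer that receives the
# broadcast batch (compare the _step() changes further down).
batch = torch.zeros(tuple(dummy.tolist()))
print(batch.shape)  # torch.Size([3, 1, 16, 8])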
23 changes: 9 additions & 14 deletions src/anemoi/training/distributed/strategy.py
@@ -24,11 +24,7 @@
class DDPGroupStrategy(DDPStrategy):
"""Distributed Data Parallel strategy with group communication."""

def __init__(
self,
num_gpus_per_model: int,
read_frequency: int,
**kwargs: dict) -> None:
def __init__(self, num_gpus_per_model: int, read_frequency: int, **kwargs: dict) -> None:
"""Initialize the distributed strategy.
Parameters
@@ -85,14 +81,13 @@ def setup(self, trainer: pl.Trainer) -> None:
f"({self.read_frequency})."
)

reader_group_ranks = np.array([
np.split(group_ranks, int(self.model_comm_group_size / self.read_frequency))
for group_ranks in model_comm_group_ranks
]) # Shape: (num_model_comm_groups, model_comm_grp_size/read_freq, read_freq)
reader_groups = [
[torch.distributed.new_group(x) for x in group_ranks]
for group_ranks in reader_group_ranks
]
reader_group_ranks = np.array(
[
np.split(group_ranks, int(self.model_comm_group_size / self.read_frequency))
for group_ranks in model_comm_group_ranks
],
) # Shape: (num_model_comm_groups, model_comm_grp_size/read_freq, read_freq)
reader_groups = [[torch.distributed.new_group(x) for x in group_ranks] for group_ranks in reader_group_ranks]
reader_group_id = model_comm_group_rank // self.read_frequency
reader_group_rank = model_comm_group_rank % self.read_frequency
# get all reader groups of the current model group
@@ -105,7 +100,7 @@ def setup(self, trainer: pl.Trainer) -> None:
str(reader_group_ranks[model_comm_group_id, reader_group_id]),
model_comm_group_id,
reader_group_id,
reader_group_rank,
reader_group_rank,
)

# register hooks for correct gradient reduction
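The reformatted hunk above splits each model communication group's ranks into reader groups of size read_frequency. A standalone sketch of just that partitioning arithmetic follows; no process groups are created here, and the group sizes and rank numbers are made up.

import numpy as np

# Made-up setup: 2 model comm groups of 4 GPUs each, read_frequency of 2.
model_comm_group_size = 4
read_frequency = 2
model_comm_group_ranks = np.array([[0, 1, 2, 3], [4, 5, 6, 7]])

# Same split as in setup(): each model group is cut into
# model_comm_group_size / read_frequency reader groups of read_frequency ranks.
reader_group_ranks = np.array(
    [
        np.split(group_ranks, int(model_comm_group_size / read_frequency))
        for group_ranks in model_comm_group_ranks
    ],
)  # shape: (2, 2, 2)

# A rank's reader group id and its rank within that group, as in setup().
model_comm_group_rank = 3  # a rank inside model comm group 0
reader_group_id = model_comm_group_rank // read_frequency   # -> 1
reader_group_rank = model_comm_group_rank % read_frequency  # -> 1
print(reader_group_ranks[0, reader_group_id], reader_group_id, reader_group_rank)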
18 changes: 11 additions & 7 deletions src/anemoi/training/train/forecaster.py
@@ -129,14 +129,16 @@ def __init__(
self.reader_group_id = self.model_comm_group_rank // self.reader_group_size
self.reader_group_rank = self.model_comm_group_rank % self.reader_group_size
# global rank of the root of the current reader group (required for broadcasting):
self.reader_group_root = (int(os.environ.get("SLURM_PROCID", "0")) // self.reader_group_size) * self.reader_group_size
self.reader_group_root = (
int(os.environ.get("SLURM_PROCID", "0")) // self.reader_group_size
) * self.reader_group_size

LOGGER.debug(
f"GraphForecaster: "
f"Rank {os.environ.get('SLURM_PROCID', '0')} model_comm_group_id: {self.model_comm_group_id}"
f" model_comm_group_rank: {self.model_comm_group_rank}"
f" model_comm_group_rank: {self.model_comm_group_rank}"
f" reader_group_id: {self.reader_group_id}"
f" reader_group_rank: {self.reader_group_rank}"
f" reader_group_rank: {self.reader_group_rank}"
f" reader_group_root: {self.reader_group_root}",
)

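The reader_group_root expression above rounds the global rank down to the first rank of its reader group, which is the broadcast source. A worked example with invented ranks:

import os

# Invented example: global rank 5 (as SLURM would report it), reader groups of size 2.
os.environ["SLURM_PROCID"] = "5"
reader_group_size = 2

# Same expression as in __init__ above: round the global rank down to the
# start of its reader group.
reader_group_root = (
    int(os.environ.get("SLURM_PROCID", "0")) // reader_group_size
) * reader_group_size
print(reader_group_root)  # 4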
@@ -240,17 +242,19 @@ def _step(
) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]:
del batch_idx

# preprocess batch and broadcast from reader_group rank 0 to reader_group
if self.reader_group_rank == 0:
# preprocess batch and broadcast from reader_group rank 0 to reader_group
if self.reader_group_rank == 0:
# for validation not normalized in-place because remappers cannot be applied in-place
batch = self.model.pre_processors(batch, in_place=not validation_mode)
else:
else:
# init batch tensor with correct shape on non-root ranks
shape = (batch.shape[0],) + tuple(batch[0].tolist())
batch = torch.zeros(shape, device=self.device)

if self.reader_groups is not None and self.reader_group_size > 1:
torch.distributed.broadcast(batch, src=self.reader_group_root, group=self.reader_groups[self.reader_group_id])
torch.distributed.broadcast(
batch, src=self.reader_group_root, group=self.reader_groups[self.reader_group_id],
)

# Synchronize after the broadcast to ensure that model_comm_group and reader_group don't overlap
# see https://pytorch.org/docs/stable/distributed.html#torch.distributed.new_group WARNING
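The _step() changes above follow a broadcast-from-reader-root pattern: the root rank preprocesses the real batch while the other ranks allocate an empty buffer from the shape information they received, and then all ranks join the broadcast. A single-process sketch of that control flow follows; the attribute names are stand-ins, and the distributed call is left as a comment because it needs an initialized process group.

import torch

# Invented stand-ins for the real attributes used in _step().
reader_group_rank = 1          # pretend this is a non-root reader rank
reader_group_root = 0
rollout_plus_multistep, ensemble, grid, variables = 3, 1, 16, 8

if reader_group_rank == 0:
    # Root rank: batch already holds real data and would be preprocessed here.
    batch = torch.randn(2, rollout_plus_multistep, ensemble, grid, variables)
else:
    # Non-root rank: batch holds only shape tensors (one per sample), so build
    # an empty buffer of the right shape to receive the broadcast into.
    shape_info = torch.tensor(
        [[rollout_plus_multistep, ensemble, grid, variables]] * 2, dtype=torch.long,
    )
    shape = (shape_info.shape[0],) + tuple(shape_info[0].tolist())
    batch = torch.zeros(shape)

# In the real code the broadcast fills the non-root buffers:
# torch.distributed.broadcast(batch, src=reader_group_root, group=reader_group)
print(batch.shape)  # torch.Size([2, 3, 1, 16, 8])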
