Feature/improve dataloader memory #76

Open. Wants to merge 10 commits into base: develop.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -38,6 +38,7 @@ Keep it human-readable, your future self will thank you!
- Add anemoi-transform link to documentation
- Codeowners file (#56)
- Changelog merge strategy (#56)
- Feature: Add reader groups to reduce CPU memory usage [#76](https://github.com/ecmwf/anemoi-training/pull/76)

#### Miscellaneous

11 changes: 11 additions & 0 deletions src/anemoi/training/config/dataloader/native_grid.yaml
@@ -1,6 +1,17 @@
prefetch_factor: 2
pin_memory: True

# ============
# read_group_size:
# Form subgroups of model comm groups that read data together.
# Each reader in the group only reads 1/read_group_size of the data
# which is then all-gathered between the group.
# This can reduce CPU memory usage as well as increase dataloader throughput.
# The number of GPUs per model must be divisible by read_group_size.
# Good values are num_gpus_per_model or num_gpus_per_node.
# ============
read_group_size: 1

num_workers:
training: 8
validation: 8
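
For illustration, a minimal sketch (with hypothetical hardware values, not taken from this PR) of the divisibility constraint and the per-reader fraction described in the comment block above:

```python
# Minimal sketch with hypothetical values: the divisibility constraint on
# read_group_size and the fraction of the grid each reader loads.
num_gpus_per_model = 4  # hypothetical hardware.num_gpus_per_model
read_group_size = 4     # e.g. set equal to num_gpus_per_model

assert num_gpus_per_model % read_group_size == 0, (
    "num_gpus_per_model must be divisible by read_group_size"
)

# each reader loads only this fraction of the grid; the shards are then
# all-gathered within the reader group
fraction_per_reader = 1 / read_group_size
print(f"each reader loads {fraction_per_reader:.0%} of the grid points")  # 25%
```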
29 changes: 0 additions & 29 deletions src/anemoi/training/data/datamodule.py
@@ -9,7 +9,6 @@


import logging
import os
from functools import cached_property
from typing import Callable

@@ -59,31 +58,6 @@ def __init__(self, config: DictConfig) -> None:
timestep,
)

self.global_rank = int(os.environ.get("SLURM_PROCID", "0")) # global rank
self.model_comm_group_id = (
self.global_rank // self.config.hardware.num_gpus_per_model
) # id of the model communication group the rank is participating in
self.model_comm_group_rank = (
self.global_rank % self.config.hardware.num_gpus_per_model
) # rank within one model communication group
total_gpus = self.config.hardware.num_gpus_per_node * self.config.hardware.num_nodes
assert (
total_gpus
) % self.config.hardware.num_gpus_per_model == 0, (
f"GPUs per model {self.config.hardware.num_gpus_per_model} does not divide total GPUs {total_gpus}"
)
self.model_comm_num_groups = (
self.config.hardware.num_gpus_per_node
* self.config.hardware.num_nodes
// self.config.hardware.num_gpus_per_model
) # number of model communication groups
LOGGER.debug(
"Rank %d model communication group number %d, with local model communication group rank %d",
self.global_rank,
self.model_comm_group_id,
self.model_comm_group_rank,
)

# Set the maximum rollout to be expected
self.rollout = (
self.config.training.rollout.max
@@ -173,9 +147,6 @@ def _get_dataset(
rollout=r,
multistep=self.config.training.multistep_input,
timeincrement=self.timeincrement,
model_comm_group_rank=self.model_comm_group_rank,
model_comm_group_id=self.model_comm_group_id,
model_comm_num_groups=self.model_comm_num_groups,
shuffle=shuffle,
label=label,
)
81 changes: 66 additions & 15 deletions src/anemoi/training/data/dataset.py
@@ -36,9 +36,6 @@ def __init__(
rollout: int = 1,
multistep: int = 1,
timeincrement: int = 1,
model_comm_group_rank: int = 0,
model_comm_group_id: int = 0,
model_comm_num_groups: int = 1,
shuffle: bool = True,
label: str = "generic",
) -> None:
@@ -54,12 +51,6 @@ def __init__(
time increment between samples, by default 1
multistep : int, optional
collate (t-1, ... t - multistep) into the input state vector, by default 1
model_comm_group_rank : int, optional
process rank in the torch.distributed group (important when running on multiple GPUs), by default 0
model_comm_group_id: int, optional
device group ID, default 0
model_comm_num_groups : int, optional
total number of device groups, by default 1
shuffle : bool, optional
Shuffle batches, by default True
label : str, optional
@@ -77,11 +68,14 @@ def __init__(
self.n_samples_per_epoch_total: int = 0
self.n_samples_per_epoch_per_worker: int = 0

# DDP-relevant info
self.model_comm_group_rank = model_comm_group_rank
self.model_comm_num_groups = model_comm_num_groups
self.model_comm_group_id = model_comm_group_id
self.global_rank = int(os.environ.get("SLURM_PROCID", "0"))
# lazy init model and reader group info, will be set by the DDPGroupStrategy:
self.model_comm_group_rank = 0
self.model_comm_num_groups = 1
self.model_comm_group_id = 0
self.global_rank = 0

self.reader_group_rank = 0
self.reader_group_size = 1

# additional state vars (lazy init)
self.n_samples_per_worker = 0
@@ -93,6 +87,7 @@ def __init__(
assert self.multi_step > 0, "Multistep value must be greater than zero."
self.ensemble_dim: int = 2
self.ensemble_size = self.data.shape[self.ensemble_dim]
self.grid_size = self.data.shape[-1]
Member: may be useful (easier to understand) if we defined a self.grid_dim = -1 and used that instead? (like we do for the ensemble dim just above)

@cached_property
def statistics(self) -> dict:
@@ -128,6 +123,58 @@ def valid_date_indices(self) -> np.ndarray:
"""
return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement)

def set_comm_group_info(
self,
global_rank: int,
model_comm_group_id: int,
model_comm_group_rank: int,
model_comm_num_groups: int,
reader_group_rank: int,
reader_group_size: int,
) -> None:
"""Set model and reader communication group information (called by DDPGroupStrategy).

Parameters
----------
global_rank : int
Global rank
model_comm_group_id : int
Model communication group ID
model_comm_group_rank : int
Model communication group rank
model_comm_num_groups : int
Number of model communication groups
reader_group_rank : int
Reader group rank
reader_group_size : int
Reader group size
"""
self.global_rank = global_rank
self.model_comm_group_id = model_comm_group_id
self.model_comm_group_rank = model_comm_group_rank
self.model_comm_num_groups = model_comm_num_groups
self.reader_group_rank = reader_group_rank
self.reader_group_size = reader_group_size

if self.reader_group_size > 1:
# get the grid shard size and start/end indices
grid_shard_size = self.grid_size // self.reader_group_size
self.grid_start = self.reader_group_rank * grid_shard_size
if self.reader_group_rank == self.reader_group_size - 1:
Member: can this be shortened to self.grid_end = min(self.grid_size, (self.reader_group_rank + 1) * grid_shard_size)?
self.grid_end = self.grid_size
else:
self.grid_end = (self.reader_group_rank + 1) * grid_shard_size

LOGGER.debug(
"NativeGridDataset.set_group_info(): global_rank %d, model_comm_group_id %d, "
"model_comm_group_rank %d, model_comm_num_groups %d, reader_group_rank %d",
global_rank,
model_comm_group_id,
model_comm_group_rank,
model_comm_num_groups,
reader_group_rank,
)
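
A small worked example of the shard bounds computed above (hypothetical sizes, chosen only to show that the last reader absorbs the remainder when the grid does not divide evenly):

```python
# Worked example of the grid shard arithmetic in set_comm_group_info,
# with hypothetical sizes.
grid_size = 10          # not divisible by the reader group size
reader_group_size = 4

grid_shard_size = grid_size // reader_group_size  # 2
for reader_group_rank in range(reader_group_size):
    grid_start = reader_group_rank * grid_shard_size
    if reader_group_rank == reader_group_size - 1:
        grid_end = grid_size  # last reader takes the remainder
    else:
        grid_end = (reader_group_rank + 1) * grid_shard_size
    print(reader_group_rank, grid_start, grid_end)
# 0 0 2
# 1 2 4
# 2 4 6
# 3 6 10
```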

def per_worker_init(self, n_workers: int, worker_id: int) -> None:
"""Called by worker_init_func on each copy of dataset.

@@ -233,7 +280,11 @@ def __iter__(self) -> torch.Tensor:
start = i - (self.multi_step - 1) * self.timeincrement
end = i + (self.rollout + 1) * self.timeincrement

x = self.data[start : end : self.timeincrement]
if self.reader_group_size > 1: # read only a subset of the grid
x = self.data[start : end : self.timeincrement, :, :, self.grid_start : self.grid_end]
@mishooax (Member), Oct 25, 2024: @japols i'm puzzled a bit by this: i get what you're doing here, but given the way the zarr is chunked on disk (chunk i == self.data[i]), wouldn't this imply that each worker is still reading a full chunk (a time slice, i.e. all latlons) and then discards the points that are not in its shard?
tagging @floriankrb in case i misunderstood how the zarr chunking is done on disk (or how the slice index is implemented in anemoi-datasets)

else: # read the full grid
x = self.data[start : end : self.timeincrement, :, :, :]

x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables")
self.ensemble_dim = 1
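
For intuition, a sketch with made-up shapes showing what the sharded read changes per worker (assuming multistep=2, rollout=1, one ensemble member, 10 variables and a 40320-point grid split across 4 readers):

```python
# Hypothetical shapes only: effect of the sharded read on the per-worker sample,
# after the rearrange to (dates, ensemble, gridpoints, variables).
dates = 3               # multistep + rollout time slices for multistep=2, rollout=1
ensemble = 1
variables = 10
grid_size = 40320
reader_group_size = 4

full_shape = (dates, ensemble, grid_size, variables)                        # read_group_size == 1
shard_shape = (dates, ensemble, grid_size // reader_group_size, variables)  # per reader
print(full_shape)   # (3, 1, 40320, 10)
print(shard_shape)  # (3, 1, 10080, 10)
```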

130 changes: 110 additions & 20 deletions src/anemoi/training/distributed/strategy.py
@@ -9,7 +9,6 @@


import logging
import os

import numpy as np
import pytorch_lightning as pl
@@ -27,19 +26,22 @@
class DDPGroupStrategy(DDPStrategy):
"""Distributed Data Parallel strategy with group communication."""

def __init__(self, num_gpus_per_model: int, **kwargs: dict) -> None:
def __init__(self, num_gpus_per_model: int, read_group_size: int, **kwargs: dict) -> None:
"""Initialize the distributed strategy.

Parameters
----------
num_gpus_per_model : int
Number of GPUs per model to shard over.
read_group_size : int
Number of GPUs per reader group.
**kwargs : dict
Additional keyword arguments.

"""
super().__init__(**kwargs)
self.model_comm_group_size = num_gpus_per_model
self.read_group_size = read_group_size

def setup(self, trainer: pl.Trainer) -> None:
assert self.accelerator is not None, "Accelerator is not initialized for distributed strategy"
@@ -60,18 +62,56 @@ def setup(self, trainer: pl.Trainer) -> None:
torch.distributed.new_group(x) for x in model_comm_group_ranks
] # every rank has to create all of these

model_comm_group_id, model_comm_group_nr, model_comm_group_rank = self.get_my_model_comm_group(
model_comm_group_id, model_comm_group_rank, model_comm_num_groups = self.get_my_model_comm_group(
self.model_comm_group_size,
)
model_comm_group = model_comm_groups[model_comm_group_id]
self.model.set_model_comm_group(model_comm_group)
self.model.set_model_comm_group(
model_comm_group,
model_comm_group_id,
model_comm_group_rank,
model_comm_num_groups,
self.model_comm_group_size,
)

# set up reader groups by further splitting model_comm_group_ranks with read_group_size:

assert self.model_comm_group_size % self.read_group_size == 0, (
f"Number of GPUs per model ({self.model_comm_group_size}) must be divisible by read_group_size "
f"({self.read_group_size})."
)

reader_group_ranks = np.array(
[
np.split(group_ranks, int(self.model_comm_group_size / self.read_group_size))
for group_ranks in model_comm_group_ranks
],
) # Shape: (num_model_comm_groups, model_comm_grp_size/read_group_size, read_group_size)
reader_groups = [[torch.distributed.new_group(x) for x in group_ranks] for group_ranks in reader_group_ranks]
reader_group_id, reader_group_rank, reader_group_size, reader_group_root = self.get_my_reader_group(
model_comm_group_rank,
self.read_group_size,
)
# get all reader groups of the current model group
model_reader_groups = reader_groups[model_comm_group_id]
self.model.set_reader_groups(
model_reader_groups,
reader_group_id,
reader_group_rank,
reader_group_size,
)

LOGGER.debug(
"Rank %d model_comm_group is %s, group number %d, with local group rank %d and comms_group_ranks %s",
"Rank %d model_comm_group_id: %d model_comm_group: %s model_comm_group_rank: %d "
"reader_group_id: %d reader_group: %s reader_group_rank: %d reader_group_root (global): %d",
self.global_rank,
str(model_comm_group_nr),
model_comm_group_id,
model_comm_group_rank,
str(model_comm_group_ranks[model_comm_group_id]),
model_comm_group_rank,
reader_group_id,
reader_group_ranks[model_comm_group_id, reader_group_id],
reader_group_rank,
reader_group_root,
)
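
To make the group layout concrete, a sketch with hypothetical sizes mirroring the np.split calls above (and assuming model_comm_group_ranks is the usual contiguous split of all ranks): 8 GPUs, num_gpus_per_model=4 and read_group_size=2 give two model comm groups, each containing two reader groups:

```python
import numpy as np

# Hypothetical sizes mirroring the group construction in DDPGroupStrategy.setup().
world_size = 8
model_comm_group_size = 4   # num_gpus_per_model
read_group_size = 2

model_comm_group_ranks = np.split(
    np.arange(world_size, dtype=np.int32), world_size // model_comm_group_size
)
# [array([0, 1, 2, 3]), array([4, 5, 6, 7])]

reader_group_ranks = np.array(
    [np.split(group, model_comm_group_size // read_group_size) for group in model_comm_group_ranks],
)
print(reader_group_ranks.shape)     # (2, 2, 2): (num_model_comm_groups, groups_per_model, read_group_size)
print(reader_group_ranks.tolist())  # [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
```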

# register hooks for correct gradient reduction
@@ -109,7 +149,7 @@ def setup(self, trainer: pl.Trainer) -> None:
# seed ranks
self.seed_rnd(model_comm_group_id)

def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, np.ndarray, int]:
def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, int, int]:
"""Determine tasks that work together and from a model group.

Parameters
@@ -119,19 +159,69 @@ def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, np.ndar

Returns
-------
tuple[int, np.ndarray, int]
Model_comm_group id, Model_comm_group Nr, Model_comm_group rank
tuple[int, int, int]
Model_comm_group id, Model_comm_group rank, Number of model_comm_groups
"""
model_comm_group_id = self.global_rank // num_gpus_per_model
model_comm_group_rank = self.global_rank % num_gpus_per_model
model_comm_num_groups = self.world_size // num_gpus_per_model

return model_comm_group_id, model_comm_group_rank, model_comm_num_groups

def get_my_reader_group(self, model_comm_group_rank: int, read_group_size: int) -> tuple[int, int, int, int]:
"""Determine tasks that work together and form a reader group.

Parameters
----------
model_comm_group_rank : int
Rank within the model communication group.
read_group_size : int
Number of dataloader readers per model group.

Returns
-------
tuple[int, int, int, int]
Reader_group id, Reader_group rank, Reader_group size, Reader_group root (global rank)
"""
model_comm_groups = np.arange(0, self.world_size, dtype=np.int32)
model_comm_groups = np.split(model_comm_groups, self.world_size / num_gpus_per_model)
reader_group_id = model_comm_group_rank // read_group_size
reader_group_rank = model_comm_group_rank % read_group_size
reader_group_size = read_group_size
reader_group_root = (self.global_rank // read_group_size) * read_group_size

return reader_group_id, reader_group_rank, reader_group_size, reader_group_root

model_comm_group_id = None
for i, model_comm_group in enumerate(model_comm_groups):
if self.global_rank in model_comm_group:
model_comm_group_id = i
model_comm_group_nr = model_comm_group
model_comm_group_rank = np.ravel(np.asarray(model_comm_group == self.global_rank).nonzero())[0]
return model_comm_group_id, model_comm_group_nr, model_comm_group_rank
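
As a quick check of the arithmetic in get_my_model_comm_group and get_my_reader_group, a worked example with hypothetical values (global rank 5 of 8, num_gpus_per_model=4, read_group_size=2):

```python
# Worked example of the rank-to-group arithmetic, hypothetical values.
global_rank, world_size = 5, 8
num_gpus_per_model, read_group_size = 4, 2

model_comm_group_id = global_rank // num_gpus_per_model        # 1
model_comm_group_rank = global_rank % num_gpus_per_model       # 1
model_comm_num_groups = world_size // num_gpus_per_model       # 2

reader_group_id = model_comm_group_rank // read_group_size     # 0
reader_group_rank = model_comm_group_rank % read_group_size    # 1
reader_group_root = (global_rank // read_group_size) * read_group_size  # 4 (global rank of the shard's root)
```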
def process_dataloader(self, dataloader: torch.utils.data.DataLoader) -> torch.utils.data.DataLoader:
"""Pass communication group information to the dataloader for distributed training.

Parameters
----------
dataloader : torch.utils.data.DataLoader
Dataloader to process.

Returns
-------
torch.utils.data.DataLoader
Processed dataloader.

"""
dataloader = super().process_dataloader(dataloader)

# pass model and reader group information to the dataloaders dataset
model_comm_group_id, model_comm_group_rank, model_comm_num_groups = self.get_my_model_comm_group(
self.model_comm_group_size,
)
_, reader_group_rank, _, _ = self.get_my_reader_group(model_comm_group_rank, self.read_group_size)

dataloader.dataset.set_comm_group_info(
self.global_rank,
model_comm_group_id,
model_comm_group_rank,
model_comm_num_groups,
reader_group_rank,
self.read_group_size,
)

return dataloader
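
For context, a conceptual usage sketch (anemoi-training normally builds the strategy and its arguments from the Hydra config, so the literal values below are only illustrative):

```python
import pytorch_lightning as pl

from anemoi.training.distributed.strategy import DDPGroupStrategy

# Illustrative values only; in practice these come from the training config.
strategy = DDPGroupStrategy(num_gpus_per_model=4, read_group_size=2)
trainer = pl.Trainer(strategy=strategy, devices=4, num_nodes=2)
# During setup, Lightning calls strategy.process_dataloader(...), which forwards the
# model and reader group information to NativeGridDataset.set_comm_group_info(...).
```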

def seed_rnd(self, model_comm_group_id: int) -> None:
"""Seed the random number generators for the rank."""
Expand All @@ -145,7 +235,7 @@ def seed_rnd(self, model_comm_group_id: int) -> None:
"Strategy: Rank %d, model comm group id %d, base seed %d, seeded with %d, "
"running with random seed: %d, sanity rnd: %s"
),
int(os.environ.get("SLURM_PROCID", "0")),
self.global_rank,
model_comm_group_id,
base_seed,
initial_seed,