diff --git a/CHANGELOG.md b/CHANGELOG.md
index e42c7c69..2e0378c0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -38,6 +38,7 @@ Keep it human-readable, your future self will thank you!
 - Add anemoi-transform link to documentation
 - Codeowners file (#56)
 - Changelog merge strategy (#56)
+- Feature: Add reader groups to reduce CPU memory usage [#76](https://github.com/ecmwf/anemoi-training/pull/76)
 
 #### Miscellaneous
 
diff --git a/src/anemoi/training/config/dataloader/native_grid.yaml b/src/anemoi/training/config/dataloader/native_grid.yaml
index e6d50801..7262117e 100644
--- a/src/anemoi/training/config/dataloader/native_grid.yaml
+++ b/src/anemoi/training/config/dataloader/native_grid.yaml
@@ -1,6 +1,17 @@
 prefetch_factor: 2
 pin_memory: True
 
+# ============
+# read_group_size:
+# Form subgroups of model comm groups that read data together.
+# Each reader in the group only reads 1/read_group_size of the data,
+# which is then all-gathered across the group.
+# This can reduce CPU memory usage as well as increase dataloader throughput.
+# The number of GPUs per model must be divisible by read_group_size.
+# Good values are num_gpus_per_model or num_gpus_per_node.
+# ============
+read_group_size: 1
+
 num_workers:
   training: 8
   validation: 8
diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py
index f64a3091..062e0073 100644
--- a/src/anemoi/training/data/datamodule.py
+++ b/src/anemoi/training/data/datamodule.py
@@ -9,7 +9,6 @@
 
 import logging
-import os
 from functools import cached_property
 from typing import Callable
 
@@ -59,31 +58,6 @@ def __init__(self, config: DictConfig) -> None:
             timestep,
         )
 
-        self.global_rank = int(os.environ.get("SLURM_PROCID", "0"))  # global rank
-        self.model_comm_group_id = (
-            self.global_rank // self.config.hardware.num_gpus_per_model
-        )  # id of the model communication group the rank is participating in
-        self.model_comm_group_rank = (
-            self.global_rank % self.config.hardware.num_gpus_per_model
-        )  # rank within one model communication group
-        total_gpus = self.config.hardware.num_gpus_per_node * self.config.hardware.num_nodes
-        assert (
-            total_gpus
-        ) % self.config.hardware.num_gpus_per_model == 0, (
-            f"GPUs per model {self.config.hardware.num_gpus_per_model} does not divide total GPUs {total_gpus}"
-        )
-        self.model_comm_num_groups = (
-            self.config.hardware.num_gpus_per_node
-            * self.config.hardware.num_nodes
-            // self.config.hardware.num_gpus_per_model
-        )  # number of model communication groups
-        LOGGER.debug(
-            "Rank %d model communication group number %d, with local model communication group rank %d",
-            self.global_rank,
-            self.model_comm_group_id,
-            self.model_comm_group_rank,
-        )
-
         # Set the maximum rollout to be expected
         self.rollout = (
             self.config.training.rollout.max
@@ -173,9 +147,6 @@ def _get_dataset(
             rollout=r,
             multistep=self.config.training.multistep_input,
             timeincrement=self.timeincrement,
-            model_comm_group_rank=self.model_comm_group_rank,
-            model_comm_group_id=self.model_comm_group_id,
-            model_comm_num_groups=self.model_comm_num_groups,
             shuffle=shuffle,
             label=label,
         )
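Note: the `read_group_size` option above determines what fraction of each sample a single reader loads before the shards are all-gathered across the reader group. A minimal sketch of the arithmetic, using hypothetical values (the hardware and grid sizes below are illustrative, not taken from this PR):

```python
# Illustrative only: hypothetical hardware and grid sizes, not part of this PR.
num_gpus_per_model = 4   # GPUs sharding one model instance
grid_size = 542_080      # hypothetical number of grid points per sample

for read_group_size in (1, 2, 4):
    # constraint enforced by DDPGroupStrategy below
    assert num_gpus_per_model % read_group_size == 0, "read_group_size must divide num_gpus_per_model"
    points_per_reader = grid_size // read_group_size
    print(f"read_group_size={read_group_size}: each reader loads {points_per_reader} of {grid_size} grid points")
```

With `read_group_size: 1` (the default) every rank still reads the full sample, so existing configurations behave as before.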
diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py
index 9e368f9c..e6a10943 100644
--- a/src/anemoi/training/data/dataset.py
+++ b/src/anemoi/training/data/dataset.py
@@ -36,9 +36,6 @@ def __init__(
         rollout: int = 1,
         multistep: int = 1,
         timeincrement: int = 1,
-        model_comm_group_rank: int = 0,
-        model_comm_group_id: int = 0,
-        model_comm_num_groups: int = 1,
         shuffle: bool = True,
         label: str = "generic",
     ) -> None:
@@ -54,12 +51,6 @@ def __init__(
             time increment between samples, by default 1
         multistep : int, optional
             collate (t-1, ... t - multistep) into the input state vector, by default 1
-        model_comm_group_rank : int, optional
-            process rank in the torch.distributed group (important when running on multiple GPUs), by default 0
-        model_comm_group_id: int, optional
-            device group ID, default 0
-        model_comm_num_groups : int, optional
-            total number of device groups, by default 1
         shuffle : bool, optional
             Shuffle batches, by default True
         label : str, optional
@@ -77,11 +68,14 @@ def __init__(
         self.n_samples_per_epoch_total: int = 0
         self.n_samples_per_epoch_per_worker: int = 0
 
-        # DDP-relevant info
-        self.model_comm_group_rank = model_comm_group_rank
-        self.model_comm_num_groups = model_comm_num_groups
-        self.model_comm_group_id = model_comm_group_id
-        self.global_rank = int(os.environ.get("SLURM_PROCID", "0"))
+        # lazy init model and reader group info, will be set by the DDPGroupStrategy:
+        self.model_comm_group_rank = 0
+        self.model_comm_num_groups = 1
+        self.model_comm_group_id = 0
+        self.global_rank = 0
+
+        self.reader_group_rank = 0
+        self.reader_group_size = 1
 
         # additional state vars (lazy init)
         self.n_samples_per_worker = 0
@@ -93,6 +87,7 @@ def __init__(
         assert self.multi_step > 0, "Multistep value must be greater than zero."
         self.ensemble_dim: int = 2
         self.ensemble_size = self.data.shape[self.ensemble_dim]
+        self.grid_size = self.data.shape[-1]
 
     @cached_property
     def statistics(self) -> dict:
@@ -128,6 +123,58 @@ def valid_date_indices(self) -> np.ndarray:
         """
         return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement)
 
+    def set_comm_group_info(
+        self,
+        global_rank: int,
+        model_comm_group_id: int,
+        model_comm_group_rank: int,
+        model_comm_num_groups: int,
+        reader_group_rank: int,
+        reader_group_size: int,
+    ) -> None:
+        """Set model and reader communication group information (called by DDPGroupStrategy).
+
+        Parameters
+        ----------
+        global_rank : int
+            Global rank
+        model_comm_group_id : int
+            Model communication group ID
+        model_comm_group_rank : int
+            Model communication group rank
+        model_comm_num_groups : int
+            Number of model communication groups
+        reader_group_rank : int
+            Reader group rank
+        reader_group_size : int
+            Reader group size
+        """
+        self.global_rank = global_rank
+        self.model_comm_group_id = model_comm_group_id
+        self.model_comm_group_rank = model_comm_group_rank
+        self.model_comm_num_groups = model_comm_num_groups
+        self.reader_group_rank = reader_group_rank
+        self.reader_group_size = reader_group_size
+
+        if self.reader_group_size > 1:
+            # get the grid shard size and start/end indices
+            grid_shard_size = self.grid_size // self.reader_group_size
+            self.grid_start = self.reader_group_rank * grid_shard_size
+            if self.reader_group_rank == self.reader_group_size - 1:
+                self.grid_end = self.grid_size
+            else:
+                self.grid_end = (self.reader_group_rank + 1) * grid_shard_size
+
+        LOGGER.debug(
+            "NativeGridDataset.set_comm_group_info(): global_rank %d, model_comm_group_id %d, "
+            "model_comm_group_rank %d, model_comm_num_groups %d, reader_group_rank %d",
+            global_rank,
+            model_comm_group_id,
+            model_comm_group_rank,
+            model_comm_num_groups,
+            reader_group_rank,
+        )
+
     def per_worker_init(self, n_workers: int, worker_id: int) -> None:
         """Called by worker_init_func on each copy of dataset.
 
@@ -233,7 +280,11 @@ def __iter__(self) -> torch.Tensor:
             start = i - (self.multi_step - 1) * self.timeincrement
             end = i + (self.rollout + 1) * self.timeincrement
 
-            x = self.data[start : end : self.timeincrement]
+            if self.reader_group_size > 1:  # read only a subset of the grid
+                x = self.data[start : end : self.timeincrement, :, :, self.grid_start : self.grid_end]
+            else:  # read the full grid
+                x = self.data[start : end : self.timeincrement, :, :, :]
+
             x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables")
             self.ensemble_dim = 1
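Note: `set_comm_group_info` gives each reader a contiguous grid shard, with the remainder assigned to the last reader, and `__iter__` then slices only that shard from the data array. A small standalone sketch of this logic (hypothetical sizes, with a numpy array standing in for the zarr-backed `self.data`):

```python
import numpy as np
from einops import rearrange

grid_size, reader_group_size = 10, 4  # hypothetical, illustrative only

# shard boundaries as computed in set_comm_group_info
shards = []
grid_shard_size = grid_size // reader_group_size
for reader_group_rank in range(reader_group_size):
    grid_start = reader_group_rank * grid_shard_size
    # the last reader picks up the remainder of the grid
    grid_end = grid_size if reader_group_rank == reader_group_size - 1 else (reader_group_rank + 1) * grid_shard_size
    shards.append((grid_start, grid_end))
print(shards)  # [(0, 2), (2, 4), (4, 6), (6, 10)] -- the shards tile the grid without overlap

# sharded read as in __iter__: stand-in data with layout (dates, variables, ensemble, gridpoints)
data = np.zeros((8, 5, 1, grid_size), dtype=np.float32)
grid_start, grid_end = shards[0]
x = data[0:4:1, :, :, grid_start:grid_end]
x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables")
print(x.shape)  # (4, 1, 2, 5) -- only the local shard of grid points is materialised
```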
diff --git a/src/anemoi/training/distributed/strategy.py b/src/anemoi/training/distributed/strategy.py
index c6509795..32c96dc6 100644
--- a/src/anemoi/training/distributed/strategy.py
+++ b/src/anemoi/training/distributed/strategy.py
@@ -9,7 +9,6 @@
 
 import logging
-import os
 
 import numpy as np
 import pytorch_lightning as pl
@@ -27,19 +26,22 @@ class DDPGroupStrategy(DDPStrategy):
     """Distributed Data Parallel strategy with group communication."""
 
-    def __init__(self, num_gpus_per_model: int, **kwargs: dict) -> None:
+    def __init__(self, num_gpus_per_model: int, read_group_size: int, **kwargs: dict) -> None:
         """Initialize the distributed strategy.
 
         Parameters
         ----------
         num_gpus_per_model : int
             Number of GPUs per model to shard over.
+        read_group_size : int
+            Number of GPUs per reader group.
         **kwargs : dict
             Additional keyword arguments.
 
         """
         super().__init__(**kwargs)
         self.model_comm_group_size = num_gpus_per_model
+        self.read_group_size = read_group_size
 
     def setup(self, trainer: pl.Trainer) -> None:
         assert self.accelerator is not None, "Accelerator is not initialized for distributed strategy"
@@ -60,18 +62,56 @@ def setup(self, trainer: pl.Trainer) -> None:
             torch.distributed.new_group(x) for x in model_comm_group_ranks
         ]  # every rank has to create all of these
 
-        model_comm_group_id, model_comm_group_nr, model_comm_group_rank = self.get_my_model_comm_group(
+        model_comm_group_id, model_comm_group_rank, model_comm_num_groups = self.get_my_model_comm_group(
             self.model_comm_group_size,
         )
 
         model_comm_group = model_comm_groups[model_comm_group_id]
-        self.model.set_model_comm_group(model_comm_group)
+        self.model.set_model_comm_group(
+            model_comm_group,
+            model_comm_group_id,
+            model_comm_group_rank,
+            model_comm_num_groups,
+            self.model_comm_group_size,
+        )
+
+        # set up reader groups by further splitting model_comm_group_ranks with read_group_size:
+        assert self.model_comm_group_size % self.read_group_size == 0, (
+            f"Number of GPUs per model ({self.model_comm_group_size}) must be divisible by read_group_size "
+            f"({self.read_group_size})."
+        )
+
+        reader_group_ranks = np.array(
+            [
+                np.split(group_ranks, int(self.model_comm_group_size / self.read_group_size))
+                for group_ranks in model_comm_group_ranks
+            ],
+        )  # Shape: (num_model_comm_groups, model_comm_grp_size/read_group_size, read_group_size)
+        reader_groups = [[torch.distributed.new_group(x) for x in group_ranks] for group_ranks in reader_group_ranks]
+        reader_group_id, reader_group_rank, reader_group_size, reader_group_root = self.get_my_reader_group(
+            model_comm_group_rank,
+            self.read_group_size,
+        )
+        # get all reader groups of the current model group
+        model_reader_groups = reader_groups[model_comm_group_id]
+        self.model.set_reader_groups(
+            model_reader_groups,
+            reader_group_id,
+            reader_group_rank,
+            reader_group_size,
+        )
+
         LOGGER.debug(
-            "Rank %d model_comm_group is %s, group number %d, with local group rank %d and comms_group_ranks %s",
+            "Rank %d model_comm_group_id: %d model_comm_group: %s model_comm_group_rank: %d "
+            "reader_group_id: %d reader_group: %s reader_group_rank: %d reader_group_root (global): %d",
             self.global_rank,
-            str(model_comm_group_nr),
             model_comm_group_id,
-            model_comm_group_rank,
             str(model_comm_group_ranks[model_comm_group_id]),
+            model_comm_group_rank,
+            reader_group_id,
+            reader_group_ranks[model_comm_group_id, reader_group_id],
+            reader_group_rank,
+            reader_group_root,
         )
 
         # register hooks for correct gradient reduction
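Note: the setup above splits every model communication group into reader subgroups, and each rank has to take part in creating all of them. The rank bookkeeping reduces to the nested `np.split` shown in the hunk; a sketch with hypothetical sizes (no `torch.distributed` calls, just the rank layout):

```python
import numpy as np

# Hypothetical: 8 global ranks, 4 GPUs per model, reader groups of 2.
world_size, model_comm_group_size, read_group_size = 8, 4, 2

model_comm_group_ranks = np.split(np.arange(world_size), world_size // model_comm_group_size)
reader_group_ranks = np.array(
    [np.split(group_ranks, model_comm_group_size // read_group_size) for group_ranks in model_comm_group_ranks],
)

print(model_comm_group_ranks)    # [array([0, 1, 2, 3]), array([4, 5, 6, 7])]
print(reader_group_ranks.shape)  # (2, 2, 2): (num_model_comm_groups, groups_per_model, read_group_size)
print(reader_group_ranks[0])     # [[0 1] [2 3]] -- the reader groups inside the first model group
```

In the real strategy, `torch.distributed.new_group` is then called for every one of these subgroups on every rank, mirroring how the model communication groups themselves are created.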
@@ -109,7 +149,7 @@ def setup(self, trainer: pl.Trainer) -> None:
         # seed ranks
         self.seed_rnd(model_comm_group_id)
 
-    def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, np.ndarray, int]:
+    def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, int, int]:
         """Determine tasks that work together and from a model group.
 
         Parameters
         ----------
@@ -119,19 +159,69 @@ def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, np.ndar
 
         Returns
         -------
-        tuple[int, np.ndarray, int]
-            Model_comm_group id, Model_comm_group Nr, Model_comm_group rank
+        tuple[int, int, int]
+            Model_comm_group id, Model_comm_group rank, Number of model_comm_groups
+        """
+        model_comm_group_id = self.global_rank // num_gpus_per_model
+        model_comm_group_rank = self.global_rank % num_gpus_per_model
+        model_comm_num_groups = self.world_size // num_gpus_per_model
+
+        return model_comm_group_id, model_comm_group_rank, model_comm_num_groups
+
+    def get_my_reader_group(self, model_comm_group_rank: int, read_group_size: int) -> tuple[int, int, int, int]:
+        """Determine tasks that work together and form a reader group.
+
+        Parameters
+        ----------
+        model_comm_group_rank : int
+            Rank within the model communication group.
+        read_group_size : int
+            Number of dataloader readers per model group.
+
+        Returns
+        -------
+        tuple[int, int, int, int]
+            Reader_group id, Reader_group rank, Reader_group size, Reader_group root (global rank)
         """
-        model_comm_groups = np.arange(0, self.world_size, dtype=np.int32)
-        model_comm_groups = np.split(model_comm_groups, self.world_size / num_gpus_per_model)
+        reader_group_id = model_comm_group_rank // read_group_size
+        reader_group_rank = model_comm_group_rank % read_group_size
+        reader_group_size = read_group_size
+        reader_group_root = (self.global_rank // read_group_size) * read_group_size
+
+        return reader_group_id, reader_group_rank, reader_group_size, reader_group_root
 
-        model_comm_group_id = None
-        for i, model_comm_group in enumerate(model_comm_groups):
-            if self.global_rank in model_comm_group:
-                model_comm_group_id = i
-                model_comm_group_nr = model_comm_group
-                model_comm_group_rank = np.ravel(np.asarray(model_comm_group == self.global_rank).nonzero())[0]
-        return model_comm_group_id, model_comm_group_nr, model_comm_group_rank
+    def process_dataloader(self, dataloader: torch.utils.data.DataLoader) -> torch.utils.data.DataLoader:
+        """Pass communication group information to the dataloader for distributed training.
+
+        Parameters
+        ----------
+        dataloader : torch.utils.data.DataLoader
+            Dataloader to process.
+
+        Returns
+        -------
+        torch.utils.data.DataLoader
+            Processed dataloader.
+
+        """
+        dataloader = super().process_dataloader(dataloader)
+
+        # pass model and reader group information to the dataloader's dataset
+        model_comm_group_id, model_comm_group_rank, model_comm_num_groups = self.get_my_model_comm_group(
+            self.model_comm_group_size,
+        )
+        _, reader_group_rank, _, _ = self.get_my_reader_group(model_comm_group_rank, self.read_group_size)
+
+        dataloader.dataset.set_comm_group_info(
+            self.global_rank,
+            model_comm_group_id,
+            model_comm_group_rank,
+            model_comm_num_groups,
+            reader_group_rank,
+            self.read_group_size,
+        )
+
+        return dataloader
 
     def seed_rnd(self, model_comm_group_id: int) -> None:
         """Seed the random number generators for the rank."""
@@ -145,7 +235,7 @@ def seed_rnd(self, model_comm_group_id: int) -> None:
                 "Strategy: Rank %d, model comm group id %d, base seed %d, seeded with %d, "
                 "running with random seed: %d, sanity rnd: %s"
             ),
-            int(os.environ.get("SLURM_PROCID", "0")),
+            self.global_rank,
             model_comm_group_id,
             base_seed,
             initial_seed,
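Note: the two helpers above are pure integer arithmetic on ranks; `reader_group_root` is expressed in global-rank space, while the reader id/rank are derived from the rank within the model group. A standalone sketch evaluating them for every rank under a hypothetical layout:

```python
# Hypothetical layout: 8 global ranks, 4 GPUs per model, read_group_size = 2.
world_size, num_gpus_per_model, read_group_size = 8, 4, 2

for global_rank in range(world_size):
    model_comm_group_id = global_rank // num_gpus_per_model
    model_comm_group_rank = global_rank % num_gpus_per_model
    model_comm_num_groups = world_size // num_gpus_per_model

    reader_group_id = model_comm_group_rank // read_group_size
    reader_group_rank = model_comm_group_rank % read_group_size
    reader_group_root = (global_rank // read_group_size) * read_group_size  # global rank of the shard-0 reader

    print(
        f"rank {global_rank}: model group {model_comm_group_id}/{model_comm_num_groups} "
        f"(rank {model_comm_group_rank}), reader group {reader_group_id} "
        f"(rank {reader_group_rank}, root {reader_group_root})",
    )
```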
diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py
index 36f6fa9a..80f6d38e 100644
--- a/src/anemoi/training/train/forecaster.py
+++ b/src/anemoi/training/train/forecaster.py
@@ -9,8 +9,6 @@
 
 import logging
-import math
-import os
 from collections import defaultdict
 from collections.abc import Mapping
 
@@ -119,6 +117,7 @@ def __init__(
         self.use_zero_optimizer = config.training.zero_optimizer
 
         self.model_comm_group = None
+        self.reader_groups = None
 
         LOGGER.debug("Rollout window length: %d", self.rollout)
         LOGGER.debug("Rollout increase every : %d epochs", self.rollout_epoch_increment)
@@ -127,11 +126,13 @@ def __init__(
 
         self.enable_plot = config.diagnostics.plot.enabled
 
-        self.model_comm_group_id = int(os.environ.get("SLURM_PROCID", "0")) // config.hardware.num_gpus_per_model
-        self.model_comm_group_rank = int(os.environ.get("SLURM_PROCID", "0")) % config.hardware.num_gpus_per_model
-        self.model_comm_num_groups = math.ceil(
-            config.hardware.num_gpus_per_node * config.hardware.num_nodes / config.hardware.num_gpus_per_model,
-        )
+        # lazy init model and reader group info, will be set by the DDPGroupStrategy:
+        self.model_comm_group_id = 0
+        self.model_comm_group_rank = 0
+        self.model_comm_num_groups = 1
+
+        self.reader_group_id = 0
+        self.reader_group_rank = 0
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.model(x, self.model_comm_group)
@@ -192,9 +193,31 @@ def metrics_loss_scaling(config: DictConfig, data_indices: IndexCollection) -> t
                 metric_ranges_validation[key] = [idx]
         return metric_ranges, metric_ranges_validation, loss_scaling
 
-    def set_model_comm_group(self, model_comm_group: ProcessGroup) -> None:
-        LOGGER.debug("set_model_comm_group: %s", model_comm_group)
+    def set_model_comm_group(
+        self,
+        model_comm_group: ProcessGroup,
+        model_comm_group_id: int,
+        model_comm_group_rank: int,
+        model_comm_num_groups: int,
+        model_comm_group_size: int,
+    ) -> None:
         self.model_comm_group = model_comm_group
+        self.model_comm_group_id = model_comm_group_id
+        self.model_comm_group_rank = model_comm_group_rank
+        self.model_comm_num_groups = model_comm_num_groups
+        self.model_comm_group_size = model_comm_group_size
+
+    def set_reader_groups(
+        self,
+        reader_groups: list[ProcessGroup],
+        reader_group_id: int,
+        reader_group_rank: int,
+        reader_group_size: int,
+    ) -> None:
+        self.reader_groups = reader_groups
+        self.reader_group_id = reader_group_id
+        self.reader_group_rank = reader_group_rank
+        self.reader_group_size = reader_group_size
 
     def advance_input(
         self,
@@ -230,9 +253,15 @@ def _step(
         validation_mode: bool = False,
     ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]:
         del batch_idx
-        loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False)
+
+        # all gather grid shards from reader group
+        if self.reader_group_size > 1:
+            batch = self.allgather_batch(batch)
+
         # for validation not normalized in-place because remappers cannot be applied in-place
         batch = self.model.pre_processors(batch, in_place=not validation_mode)
+
+        loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False)
         metrics = {}
 
         # start rollout of preprocessed batch
@@ -268,6 +297,28 @@ def _step(
             loss *= 1.0 / self.rollout
         return loss, metrics, y_preds
 
+    def allgather_batch(self, batch: torch.Tensor) -> torch.Tensor:
+        grid_size = self.model.metadata["dataset"]["shape"][-1]
+        grid_shard_size = grid_size // self.reader_group_size
+        last_grid_shard_size = grid_size - (grid_shard_size * (self.reader_group_size - 1))
+
+        # prepare tensor list with correct shapes for all_gather
+        shard_shape = list(batch.shape)
+        shard_shape[-2] = grid_shard_size
+        last_shard_shape = list(batch.shape)
+        last_shard_shape[-2] = last_grid_shard_size
+
+        tensor_list = [torch.empty(tuple(shard_shape), device=self.device) for _ in range(self.reader_group_size - 1)]
+        tensor_list.append(torch.empty(last_shard_shape, device=self.device))
+
+        torch.distributed.all_gather(
+            tensor_list,
+            batch,
+            group=self.reader_groups[self.reader_group_id],
+        )
+
+        return torch.cat(tensor_list, dim=-2)
+
     def calculate_val_metrics(
         self,
         y_pred: torch.Tensor,
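Note: `allgather_batch` reassembles the full-grid batch from the per-reader shards; because the grid may not divide evenly, the last shard can be larger, which is why `tensor_list` mixes two shapes. A single-process sketch of the shape logic (hypothetical sizes; `torch.cat` stands in for the collective here, whereas the real code first has `torch.distributed.all_gather` fill `tensor_list` across the reader group):

```python
import torch

# Hypothetical batch layout (bs, dates, ensemble, grid, vars) with an uneven grid split.
grid_size, reader_group_size = 10, 4
grid_shard_size = grid_size // reader_group_size                              # 2
last_grid_shard_size = grid_size - grid_shard_size * (reader_group_size - 1)  # 4

shard_shapes = [(1, 2, 1, grid_shard_size, 5)] * (reader_group_size - 1)
shard_shapes.append((1, 2, 1, last_grid_shard_size, 5))

# every reader contributes one shard; concatenating along the grid dimension (-2) restores the full batch
tensor_list = [torch.zeros(shape) for shape in shard_shapes]
full_batch = torch.cat(tensor_list, dim=-2)
print(full_batch.shape)  # torch.Size([1, 2, 1, 10, 5])
```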
diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py
index b772eb2a..fa3260b5 100644
--- a/src/anemoi/training/train/train.py
+++ b/src/anemoi/training/train/train.py
@@ -12,7 +12,6 @@
 
 import datetime
 import logging
-import os
 from functools import cached_property
 from pathlib import Path
 from typing import TYPE_CHECKING
@@ -106,7 +105,7 @@ def initial_seed(self) -> int:
         (torch.rand(1), np_rng.random())
         LOGGER.debug(
             "Initial seed: Rank %d, initial seed %d, running with random seed: %d",
-            int(os.environ.get("SLURM_PROCID", "0")),
+            self.strategy.global_rank,
             initial_seed,
             rnd_seed,
         )
@@ -335,6 +334,7 @@ def strategy(self) -> DDPGroupStrategy:
         """Training strategy."""
         return DDPGroupStrategy(
             self.config.hardware.num_gpus_per_model,
+            self.config.dataloader.get("read_group_size", 1),
             static_graph=not self.config.training.accum_grad_batches > 1,
         )
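Note: the final hunk reads the new option with a default of 1, so configs without `read_group_size` keep their current behaviour. A hedged sketch of how the value is looked up from an OmegaConf config (the config contents here are hypothetical):

```python
from omegaconf import OmegaConf

# Hypothetical config snippet; read_group_size may be absent in older configs.
config = OmegaConf.create(
    {
        "hardware": {"num_gpus_per_model": 4},
        "dataloader": {"read_group_size": 2},
    },
)

read_group_size = config.dataloader.get("read_group_size", 1)  # falls back to 1 -> no sharded reading
print(config.hardware.num_gpus_per_model, read_group_size)     # 4 2
```

`train.py` then passes this value as the second positional argument of `DDPGroupStrategy`, matching the updated `__init__` signature.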