Feature/improve dataloader memory #76
base: develop
Conversation
Great work! Would be nice to clean up the group creation a bit more and consolidate everything to the strategy :-).
@@ -124,6 +125,23 @@ def __init__(
            config.hardware.num_gpus_per_node * config.hardware.num_nodes / config.hardware.num_gpus_per_model,
        )

Do we actually need any of these here? Same for the model_comm_group etc. ... above. I think these are properly initialised by the strategy, which uses model.set_model_comm_group and set_reader_group. So it might be enough to initialise these to sensible default values when the model is not sharded.
I changed it to pass everything from the DDPGroupStrategy via model.set_model_comm_group() and set_reader_group().
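A rough sketch of that flow (only set_model_comm_group and set_reader_group come from the discussion above; the group-creation arithmetic and attribute names are assumptions for illustration):

import torch.distributed as dist


def setup_communication_groups(self) -> None:
    """Sketch: build model-comm and reader groups in the strategy and hand them to the model."""
    world_size = dist.get_world_size()
    global_rank = dist.get_rank()

    # dist.new_group() must be called on every rank for every group, so create all
    # groups and keep only the ones this rank belongs to.
    model_comm_group = reader_group = None
    for model_start in range(0, world_size, self.num_gpus_per_model):
        model_ranks = list(range(model_start, model_start + self.num_gpus_per_model))
        group = dist.new_group(model_ranks)
        if global_rank in model_ranks:
            model_comm_group = group

        # each model comm group is subdivided into reader groups of size read_frequency
        for reader_start in range(model_start, model_start + self.num_gpus_per_model, self.read_frequency):
            reader_ranks = list(range(reader_start, reader_start + self.read_frequency))
            subgroup = dist.new_group(reader_ranks)
            if global_rank in reader_ranks:
                reader_group = subgroup

    # pass the groups down instead of re-deriving them inside the model or datamodule
    self.model.set_model_comm_group(model_comm_group)
    self.model.set_reader_group(reader_group)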
else:
    # init batch tensor with correct shape on non-root ranks
    shape = (batch.shape[0],) + tuple(batch[0].tolist())
    batch = torch.zeros(shape, device=self.device)
Why not torch.empty? It probably does not matter ...
You're right, there's no reason to use zeros instead of empty; changed it.
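For context, a minimal sketch of the surrounding pattern (not the exact PR code): non-root ranks of a reader group only allocate a receive buffer that is immediately overwritten by the broadcast, which is why torch.empty is sufficient:

import torch
import torch.distributed as dist


def broadcast_batch(batch: torch.Tensor | None, shape: tuple, device: torch.device, group, src: int) -> torch.Tensor:
    """Sketch: the reader-group root passes the real batch; non-root ranks pass batch=None."""
    if batch is None:
        # the buffer is fully overwritten by the broadcast, so zeroing it would be wasted work
        batch = torch.empty(shape, device=device)
    dist.broadcast(batch, src=src, group=group)
    return batch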
@@ -74,11 +74,19 @@ def __init__(self, config: DictConfig) -> None:
    * self.config.hardware.num_nodes
    // self.config.hardware.num_gpus_per_model
)  # number of model communication groups

It would be nice to get the model comm groups and model reader groups from the strategy / use the routines in the strategy to compute them, instead of having code here to re-compute the groups. This should be possible because we initialise the strategy before loading the datamodule.
The datamodule only needed these to pass them down to the dataloader.dataset; I removed them entirely from the datamodule now.
src/anemoi/training/train/train.py
@@ -308,6 +308,7 @@ def strategy(self) -> DDPGroupStrategy:
    """Training strategy."""
    return DDPGroupStrategy(
        self.config.hardware.num_gpus_per_model,
        self.config.dataloader.read_frequency,
Could we think of a way to make use of the routines / groups computed by the strategy in self.datamodule?
I managed to pass them to the dataloader directly via DDPGroupStrategy.process_dataloader() which is called by pytorch_lightning in trainer.fit(model, datamodule).
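A sketch of that hook (Strategy.process_dataloader is the Lightning hook used; the set_comm_group_info setter and the group attributes on the strategy are hypothetical names for illustration):

from pytorch_lightning.strategies import DDPStrategy
from torch.utils.data import DataLoader


class DDPGroupStrategy(DDPStrategy):
    """Sketch: hand reader-group info to the dataset when Lightning processes the dataloader."""

    def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
        # called by Lightning inside trainer.fit() for every train/val dataloader
        dataloader.dataset.set_comm_group_info(  # hypothetical setter on the dataset
            global_rank=self.global_rank,
            reader_group_rank=self.reader_group_rank,
            reader_group_size=self.reader_group_size,
        )
        return dataloader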
Very nice feature, Jan. I tested this on a rollout run on n320. These runs are painful because we need to reduce the number of workers to avoid out-of-memory issues, and training speed drops drastically. But I did a test with your branch and the develop branch, and the results are quite good! Here is a comparison in terms of memory usage for num_workers = 6. The very good thing is that the job on the develop branch actually crashes at the end of rollout=2, while with your new feature the rollout fine-tuning keeps going. This will considerably speed up rollout fine-tuning.
Very nice work, Jan! 👍
# get the grid shard size and start/end indices
grid_shard_size = self.grid_size // self.reader_group_size
self.grid_start = self.reader_group_rank * grid_shard_size
if self.reader_group_rank == self.reader_group_size - 1:
Can this be shortened to self.grid_end = min(self.grid_size, (self.reader_group_rank + 1) * grid_shard_size)?
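For comparison, a sketch of both variants (the body of the conditional is reconstructed from the diff and the suggestion above): the one-liner only matches the conditional when grid_size is divisible by reader_group_size, unless the shard size is rounded up with ceiling division.

# conditional form: the last reader rank picks up any remainder
grid_shard_size = self.grid_size // self.reader_group_size
self.grid_start = self.reader_group_rank * grid_shard_size
if self.reader_group_rank == self.reader_group_size - 1:
    self.grid_end = self.grid_size
else:
    self.grid_end = (self.reader_group_rank + 1) * grid_shard_size

# one-liner with a ceiling-divided shard size: the last rank is clipped to the grid size
grid_shard_size = -(-self.grid_size // self.reader_group_size)  # ceil division
self.grid_start = self.reader_group_rank * grid_shard_size
self.grid_end = min(self.grid_size, (self.reader_group_rank + 1) * grid_shard_size)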
@@ -233,7 +280,11 @@ def __iter__(self) -> torch.Tensor:
    start = i - (self.multi_step - 1) * self.timeincrement
    end = i + (self.rollout + 1) * self.timeincrement

    x = self.data[start : end : self.timeincrement]
    if self.reader_group_size > 1:  # read only a subset of the grid
        x = self.data[start : end : self.timeincrement, :, :, self.grid_start : self.grid_end]
@japols I'm puzzled a bit by this: I get what you're doing here, but given the way the zarr is chunked on disk (chunk i == self.data[i]), wouldn't this imply that each worker is still reading a full chunk (a time slice, i.e. all latlons) and then discards the points that are not in its shard?
Tagging @floriankrb in case I misunderstood how the zarr chunking is done on disk (or how the slice index is implemented in anemoi-datasets).
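To make the concern concrete, a toy example (not the actual anemoi-datasets layout) of a zarr array chunked only along the time axis:

import zarr

# toy layout mimicking "chunk i == data[i]": one chunk per time step, grid axis unchunked
z = zarr.zeros((8, 4, 1, 40320), chunks=(1, 4, 1, 40320), dtype="f4")
print(z.chunks)  # (1, 4, 1, 40320)

# A grid-shard slice touches the same whole chunks as a full read of those time steps;
# the store has to read/decompress each (1, 4, 1, 40320) chunk before the shard is cut out.
shard = z[0:2, :, :, 0:10080]
print(shard.shape)  # (2, 4, 1, 10080)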
@@ -106,7 +105,7 @@ def initial_seed(self) -> int:
    (torch.rand(1), np_rng.random())
    LOGGER.debug(
        "Initial seed: Rank %d, initial seed %d, running with random seed: %d",
        int(os.environ.get("SLURM_PROCID", "0")),
        self.strategy.global_rank,
Is self.strategy guaranteed to be correctly initialized at this point?
@@ -93,6 +87,7 @@ def __init__(
    assert self.multi_step > 0, "Multistep value must be greater than zero."
    self.ensemble_dim: int = 2
    self.ensemble_size = self.data.shape[self.ensemble_dim]
    self.grid_size = self.data.shape[-1]
It may be useful (easier to understand) if we defined a self.grid_dim = -1 and used that instead (like we do for the ensemble dim just above)?
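I.e. something along these lines (a sketch mirroring the diff above):

self.ensemble_dim: int = 2
self.ensemble_size = self.data.shape[self.ensemble_dim]
self.grid_dim: int = -1  # last axis of the dataset holds the grid points
self.grid_size = self.data.shape[self.grid_dim]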
Describe your changes
This PR adds a configurable read_frequency (config.dataloader.read_frequency) that defines how many GPUs per model communication group share the same data, read once by a single reader. Increasing the read_frequency heavily reduces CPU memory usage because the dataloaders no longer load duplicate copies of the same data.
The model communication group is further subdivided into reader groups of size read_frequency. For each reader group, only rank 0 reads data from the dataloader and communicates it to the rest via broadcast.
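As an illustration of the rank bookkeeping implied by that description (a standalone sketch; the actual implementation lives in DDPGroupStrategy and the dataset):

def reader_group_info(global_rank: int, num_gpus_per_model: int, read_frequency: int) -> tuple[int, int, bool]:
    """Return (reader_group_id, reader_group_rank, is_reader) for a given global rank."""
    rank_in_model_group = global_rank % num_gpus_per_model
    reader_group_id = rank_in_model_group // read_frequency
    reader_group_rank = rank_in_model_group % read_frequency
    return reader_group_id, reader_group_rank, reader_group_rank == 0


# e.g. num_gpus_per_model=4, read_frequency=2: ranks {0, 1} and {2, 3} form reader groups,
# and only ranks 0 and 2 actually read from the dataloader
for rank in range(4):
    print(rank, reader_group_info(rank, num_gpus_per_model=4, read_frequency=2))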
The following experiments on n320 (tracked in MLflow) show that CPU memory usage goes down as we increase the read_frequency.
This additional broadcasting step doesn't affect runtime (time spent waiting for the broadcast would otherwise be spent loading data); measured over 10 epochs @ 100 steps.
Type of change
Checklist before requesting a review
Tag possible reviewers
@ssmmnn11 @mishooax @theissenhelen @JesperDramsch @sahahner @mchantry