
Feature/improve dataloader memory #76

Open · wants to merge 10 commits into develop from feature/improve-dataloader-memory
Conversation

@japols (Member) commented Oct 9, 2024

Describe your changes

This PR adds a configurable read_frequency (config.dataloader.read_frequency) that defines how many GPUs per model communication group share a single read of the (same) data. Increasing the read_frequency substantially reduces CPU memory usage, because the dataloaders no longer redundantly load identical data.

The model communication group is further subdivided into reader groups of size read_frequency. For each reader group, only rank 0 reads data from the dataloader and communicates it to the rest via broadcast.
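
To make the mechanism concrete, here is a minimal sketch of the read-and-broadcast pattern (illustrative only, not the code in this PR; the helper name, the shape/dtype exchange, and the way the reader group is passed in are assumptions):

import torch
import torch.distributed as dist


def read_and_broadcast_batch(dataloader_iter, reader_group, reader_group_rank, device):
    """Only rank 0 of each reader group reads from the dataloader; the others receive a broadcast."""
    # global rank of this reader group's root (rank 0 within the group)
    src = dist.get_global_rank(reader_group, 0)

    if reader_group_rank == 0:
        batch = next(dataloader_iter).to(device)
        meta = [(tuple(batch.shape), batch.dtype)]
    else:
        batch, meta = None, [None]

    # share shape/dtype first so non-root ranks can allocate a receive buffer
    dist.broadcast_object_list(meta, src=src, group=reader_group)
    if reader_group_rank != 0:
        shape, dtype = meta[0]
        batch = torch.empty(shape, dtype=dtype, device=device)

    dist.broadcast(batch, src=src, group=reader_group)
    return batch

With reader groups of size read_frequency, each batch is held in CPU memory once per group rather than once per GPU, which is where the memory saving comes from.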

The following experiments on n320 show that CPU memory usage goes down as we increase the read_frequency:

[Screenshot: CPU memory usage on n320 decreasing with increasing read_frequency]

MLFlow

This additional broadcasting step doesn't affect runtime (time spent waiting for broadcast would otherwise be spent loading data):

read_frequency    avg_epoch_time (s)
1                 130.43
2                 129.15
4                 128.78

Measured over 10 epochs at 100 steps each.

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist before requesting a review

  • I have performed a self-review of my code
  • My code follows the style guidelines of this project
  • I have commented my code, particularly in hard-to-understand areas
  • I have updated the documentation and docstrings to reflect the changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have ensured that the code is still pip-installable after the changes and runs
  • I have not introduced new dependencies in the inference portion of the model
  • I have run this on single GPU
  • I have run this on multi-GPU or multi-node
  • I have run this to work on LUMI (or made sure the changes work independently)
  • I have run the Benchmark Profiler against the old version of the code

Tag possible reviewers

@ssmmnn11 @mishooax @theissenhelen @JesperDramsch @sahahner @mchantry

@FussyDuck commented Oct 9, 2024

CLA assistant check: all committers have signed the CLA.

@ssmmnn11 (Member) left a comment

Great work! Would be nice to clean up the group creation a bit more and consolidate everything to the strategy :-).

@@ -124,6 +125,23 @@ def __init__(
config.hardware.num_gpus_per_node * config.hardware.num_nodes / config.hardware.num_gpus_per_model,
)

Member:

Do we actually need any of these here? Same for the model_comm_group etc. above. I think these are properly initialised by the strategy, which uses model.set_model_comm_group and set_reader_group. So it might be enough to initialise these to sensible default values when the model is not sharded.

Member Author (@japols):

I changed it to pass everything from the DDPGroupStrategy via model.set_model_comm_group() and set_reader_group().
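
For reference, a rough sketch of how the strategy could build the groups and register them on the model (the setter names come from the discussion above; the group layout, function name, and loop structure are assumptions, not the PR's actual code):

import torch.distributed as dist


def setup_comm_groups(model, global_rank, num_gpus_per_model, read_frequency):
    """Build model comm groups and reader sub-groups, then register them on the model."""
    world_size = dist.get_world_size()

    for start in range(0, world_size, num_gpus_per_model):
        model_ranks = list(range(start, start + num_gpus_per_model))
        # new_group is collective: every rank must create every group, in the same order
        model_group = dist.new_group(model_ranks)
        if global_rank in model_ranks:
            model.set_model_comm_group(model_group)

        # subdivide each model comm group into reader groups of size read_frequency
        for r_start in range(start, start + num_gpus_per_model, read_frequency):
            reader_ranks = list(range(r_start, r_start + read_frequency))
            reader_group = dist.new_group(reader_ranks)
            if global_rank in reader_ranks:
                model.set_reader_group(reader_group)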

else:
    # init batch tensor with correct shape on non-root ranks
    shape = (batch.shape[0],) + tuple(batch[0].tolist())
    batch = torch.zeros(shape, device=self.device)
Member:

Why not torch.empty? It probably doesn't matter...

Member Author (@japols):

You're right, there's no reason to use zeros instead of empty; changed it.

@@ -74,11 +74,19 @@ def __init__(self, config: DictConfig) -> None:
* self.config.hardware.num_nodes
// self.config.hardware.num_gpus_per_model
) # number of model communication groups

Member:

It would be nice to get model comm groups and model reader groups from the strategy / use the routines in the strategy to compute them, instead of having code here to re-compute the groups. This should be possible because we initialise the strategy before loading the datamodule.

Member Author (@japols):

The datamodule only needed these to pass them down to dataloader.dataset; I have now removed them from the datamodule entirely.

@@ -308,6 +308,7 @@ def strategy(self) -> DDPGroupStrategy:
"""Training strategy."""
return DDPGroupStrategy(
self.config.hardware.num_gpus_per_model,
self.config.dataloader.read_frequency,
Member:

Could we think of a way to make use of the routines/groups computed by the strategy in self.datamodule?

Member Author (@japols):

I managed to pass them to the dataloader directly via DDPGroupStrategy.process_dataloader(), which is called by pytorch_lightning in trainer.fit(model, datamodule).
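
For context, a sketch of what that can look like (process_dataloader is a standard PyTorch Lightning strategy hook; the dataset setter and strategy attributes below are illustrative, not necessarily what this PR implements):

from pytorch_lightning.strategies import DDPStrategy


class DDPGroupStrategy(DDPStrategy):
    """Simplified sketch; the real class carries more state and logic."""

    def process_dataloader(self, dataloader):
        # Lightning calls this hook on each dataloader before iteration starts,
        # so the strategy can hand its precomputed group layout to the dataset.
        dataset = dataloader.dataset
        dataset.set_comm_group_info(                   # hypothetical setter on the dataset
            reader_group_rank=self.reader_group_rank,  # hypothetical strategy attributes
            reader_group_size=self.reader_group_size,
        )
        return dataloader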

@japols force-pushed the feature/improve-dataloader-memory branch from 4cc8a78 to 3c6b5c9 on October 9, 2024 at 15:50
@japols added the enhancement (New feature or request) label on Oct 9, 2024
@gabrieloks (Contributor) commented Oct 10, 2024

Very nice feature, Jan. I tested this on a rollout run on n320. These runs are painful because we need to reduce the number of workers to avoid out-of-memory issues, which drastically reduces training speed.

But I did a test with your branch and the develop branch and the results are quite good!

Here is a comparison in terms of memory usage for num_workers = 6:

[Plot: CPU memory usage, develop branch vs. this branch, num_workers = 6]

Even better, the job on the develop branch actually crashes at the end of rollout=2, while with your new feature the rollout fine-tuning keeps going. This will considerably speed up rollout fine-tuning.

@mchantry

@mishooax (Member) left a comment

Very nice work, Jan! 👍

# get the grid shard size and start/end indices
grid_shard_size = self.grid_size // self.reader_group_size
self.grid_start = self.reader_group_rank * grid_shard_size
if self.reader_group_rank == self.reader_group_size - 1:
Member:

can this be shortened to

self.grid_end = min(self.grid_size, (self.reader_group_rank + 1) * grid_shard_size)

?

@@ -233,7 +280,11 @@ def __iter__(self) -> torch.Tensor:
start = i - (self.multi_step - 1) * self.timeincrement
end = i + (self.rollout + 1) * self.timeincrement

x = self.data[start : end : self.timeincrement]
if self.reader_group_size > 1:  # read only a subset of the grid
    x = self.data[start : end : self.timeincrement, :, :, self.grid_start : self.grid_end]
@mishooax (Member) commented Oct 25, 2024:

@japols I'm a bit puzzled by this: I get what you're doing here, but given the way the zarr is chunked on disk (chunk i == self.data[i]), wouldn't this imply that each worker still reads a full chunk (a time slice, i.e. all latlons) and then discards the points that are not in its shard?

Tagging @floriankrb in case I misunderstood how the zarr chunking is done on disk (or how the slice index is implemented in anemoi-datasets).

@@ -106,7 +105,7 @@ def initial_seed(self) -> int:
(torch.rand(1), np_rng.random())
LOGGER.debug(
"Initial seed: Rank %d, initial seed %d, running with random seed: %d",
int(os.environ.get("SLURM_PROCID", "0")),
self.strategy.global_rank,
Member:

Is self.strategy guaranteed to be correctly initialized at this point?

@@ -93,6 +87,7 @@ def __init__(
assert self.multi_step > 0, "Multistep value must be greater than zero."
self.ensemble_dim: int = 2
self.ensemble_size = self.data.shape[self.ensemble_dim]
self.grid_size = self.data.shape[-1]
Member:

It may be useful (easier to understand) if we defined a self.grid_dim = -1 and used that instead, like we do for the ensemble dim just above?
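
A minimal sketch of that suggestion, mirroring the ensemble-dim pattern just above (illustrative):

self.ensemble_dim: int = 2
self.ensemble_size = self.data.shape[self.ensemble_dim]
self.grid_dim: int = -1
self.grid_size = self.data.shape[self.grid_dim]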
