Feature/improve dataloader memory #76
base: develop
Changes from all commits: cdaf082, fcc7c93, 5d171c7, 8c16e54, ee94593, 3c6b5c9, 57a13c5, bcd0fe6, 9a22225, 6615c97
```diff
@@ -36,9 +36,6 @@ def __init__(
         rollout: int = 1,
         multistep: int = 1,
         timeincrement: int = 1,
-        model_comm_group_rank: int = 0,
-        model_comm_group_id: int = 0,
-        model_comm_num_groups: int = 1,
         shuffle: bool = True,
         label: str = "generic",
     ) -> None:
```
```diff
@@ -54,12 +51,6 @@ def __init__(
             time increment between samples, by default 1
         multistep : int, optional
             collate (t-1, ... t - multistep) into the input state vector, by default 1
-        model_comm_group_rank : int, optional
-            process rank in the torch.distributed group (important when running on multiple GPUs), by default 0
-        model_comm_group_id: int, optional
-            device group ID, default 0
-        model_comm_num_groups : int, optional
-            total number of device groups, by default 1
         shuffle : bool, optional
             Shuffle batches, by default True
         label : str, optional
```
```diff
@@ -77,11 +68,14 @@ def __init__(
         self.n_samples_per_epoch_total: int = 0
         self.n_samples_per_epoch_per_worker: int = 0

-        # DDP-relevant info
-        self.model_comm_group_rank = model_comm_group_rank
-        self.model_comm_num_groups = model_comm_num_groups
-        self.model_comm_group_id = model_comm_group_id
-        self.global_rank = int(os.environ.get("SLURM_PROCID", "0"))
+        # lazy init model and reader group info, will be set by the DDPGroupStrategy:
+        self.model_comm_group_rank = 0
+        self.model_comm_num_groups = 1
+        self.model_comm_group_id = 0
+        self.global_rank = 0
+
+        self.reader_group_rank = 0
+        self.reader_group_size = 1

         # additional state vars (lazy init)
         self.n_samples_per_worker = 0
```
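The practical effect of this hunk: the dataset no longer reads `SLURM_PROCID` in `__init__`, so it can be constructed without any process-topology knowledge, using defaults that are valid for a single-process run; the `DDPGroupStrategy` injects the real values later through the `set_comm_group_info` method added below. A minimal sketch of that two-phase flow (the stub class and all values are illustrative only, not the real `NativeGridDataset`):

```python
# minimal sketch of the two-phase initialization; _DatasetStub is a stand-in
# for NativeGridDataset, with the same lazy defaults as in the diff above
class _DatasetStub:
    def __init__(self) -> None:
        # safe single-process defaults; no SLURM_PROCID lookup at construction
        self.global_rank = 0
        self.model_comm_group_rank = 0
        self.model_comm_group_id = 0
        self.model_comm_num_groups = 1
        self.reader_group_rank = 0
        self.reader_group_size = 1

    def set_comm_group_info(self, **info: int) -> None:
        # the real method also derives the grid shard bounds; see the later hunk
        for name, value in info.items():
            setattr(self, name, value)

ds = _DatasetStub()          # constructible anywhere, e.g. in unit tests
ds.set_comm_group_info(      # later, called by the DDPGroupStrategy
    global_rank=3, model_comm_group_id=1, model_comm_group_rank=1,
    model_comm_num_groups=2, reader_group_rank=1, reader_group_size=2,
)
assert ds.model_comm_group_id == 1
```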
```diff
@@ -93,6 +87,7 @@ def __init__(
         assert self.multi_step > 0, "Multistep value must be greater than zero."
         self.ensemble_dim: int = 2
         self.ensemble_size = self.data.shape[self.ensemble_dim]
+        self.grid_size = self.data.shape[-1]

     @cached_property
     def statistics(self) -> dict:
```
@@ -128,6 +123,58 @@ def valid_date_indices(self) -> np.ndarray: | |
""" | ||
return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement) | ||
|
||
def set_comm_group_info( | ||
self, | ||
global_rank: int, | ||
model_comm_group_id: int, | ||
model_comm_group_rank: int, | ||
model_comm_num_groups: int, | ||
reader_group_rank: int, | ||
reader_group_size: int, | ||
) -> None: | ||
"""Set model and reader communication group information (called by DDPGroupStrategy). | ||
|
||
Parameters | ||
---------- | ||
global_rank : int | ||
Global rank | ||
model_comm_group_id : int | ||
Model communication group ID | ||
model_comm_group_rank : int | ||
Model communication group rank | ||
model_comm_num_groups : int | ||
Number of model communication groups | ||
reader_group_rank : int | ||
Reader group rank | ||
reader_group_size : int | ||
Reader group size | ||
""" | ||
self.global_rank = global_rank | ||
self.model_comm_group_id = model_comm_group_id | ||
self.model_comm_group_rank = model_comm_group_rank | ||
self.model_comm_num_groups = model_comm_num_groups | ||
self.reader_group_rank = reader_group_rank | ||
self.reader_group_size = reader_group_size | ||
|
||
if self.reader_group_size > 1: | ||
# get the grid shard size and start/end indices | ||
grid_shard_size = self.grid_size // self.reader_group_size | ||
self.grid_start = self.reader_group_rank * grid_shard_size | ||
if self.reader_group_rank == self.reader_group_size - 1: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can this be shortened to self.grid_end = min(self.grid_size, (self.reader_group_rank + 1) * grid_shard_size) ? |
||
self.grid_end = self.grid_size | ||
else: | ||
self.grid_end = (self.reader_group_rank + 1) * grid_shard_size | ||
|
||
LOGGER.debug( | ||
"NativeGridDataset.set_group_info(): global_rank %d, model_comm_group_id %d, " | ||
"model_comm_group_rank %d, model_comm_num_groups %d, reader_group_rank %d", | ||
global_rank, | ||
model_comm_group_id, | ||
model_comm_group_rank, | ||
model_comm_num_groups, | ||
reader_group_rank, | ||
) | ||
|
||
def per_worker_init(self, n_workers: int, worker_id: int) -> None: | ||
"""Called by worker_init_func on each copy of dataset. | ||
|
||
|
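To sanity-check the shard arithmetic against the `min()` one-liner suggested above, here is a small self-contained comparison (hypothetical grid sizes, not from the PR). It suggests the two forms only agree when the grid divides evenly across readers: because `grid_shard_size` comes from floor division, `(rank + 1) * grid_shard_size` never exceeds `grid_size`, so the `min()` never triggers and the last rank would lose the remainder points:

```python
def shard_bounds(grid_size: int, group_size: int, rank: int) -> tuple[int, int]:
    # mirrors set_comm_group_info: the last rank absorbs the remainder
    # when grid_size is not divisible by the reader group size
    shard = grid_size // group_size
    start = rank * shard
    end = grid_size if rank == group_size - 1 else (rank + 1) * shard
    return start, end


def shard_bounds_min(grid_size: int, group_size: int, rank: int) -> tuple[int, int]:
    # the one-liner suggested in the review above
    shard = grid_size // group_size
    return rank * shard, min(grid_size, (rank + 1) * shard)


# agree when the grid divides evenly across readers ...
assert [shard_bounds(8, 4, r) for r in range(4)] == [shard_bounds_min(8, 4, r) for r in range(4)]
# ... but differ on the last rank otherwise: floor division guarantees
# (rank + 1) * shard <= grid_size, so min() never rounds the end up
assert shard_bounds(10, 4, 3) == (6, 10)
assert shard_bounds_min(10, 4, 3) == (6, 8)  # would drop grid points 8 and 9
```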
```diff
@@ -233,7 +280,11 @@ def __iter__(self) -> torch.Tensor:
             start = i - (self.multi_step - 1) * self.timeincrement
             end = i + (self.rollout + 1) * self.timeincrement

-            x = self.data[start : end : self.timeincrement]
+            if self.reader_group_size > 1:  # read only a subset of the grid
+                x = self.data[start : end : self.timeincrement, :, :, self.grid_start : self.grid_end]
+            else:  # read the full grid
+                x = self.data[start : end : self.timeincrement, :, :, :]

             x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables")
             self.ensemble_dim = 1
```

Review comment: @japols i'm puzzled a bit by this: i get what you're doing here, but given the way the zarr is chunked on disk (chunk […]) — tagging @floriankrb in case i misunderstood how the zarr chunking is done on disk (or how the slice index is implemented in anemoi-datasets)
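On the chunking question above: whether the sliced read saves I/O (and not just memory) depends on the on-disk chunk layout. A small zarr sketch with synthetic shapes and an assumed chunk layout (the real anemoi-datasets layout may differ) illustrates the two cases:

```python
import numpy as np
import zarr

# synthetic stand-in for the dataset, shaped (dates, variables, ensemble, gridpoints);
# the chunk layout here is an assumption for illustration, not the actual
# anemoi-datasets layout on disk
data = zarr.zeros((8, 4, 1, 1024), chunks=(1, 4, 1, 256), dtype="f4")
data[:] = np.random.rand(8, 4, 1, 1024).astype("f4")

grid_start, grid_end = 256, 512  # this reader's shard
x_shard = data[0:4, :, :, grid_start:grid_end]  # reads only the 2nd of 4 grid chunks per date
x_full = data[0:4, :, :, :]                     # reads all 4 grid chunks per date
assert x_shard.shape == (4, 4, 1, 256)

# with chunks=(1, 4, 1, 1024) every chunk would span the whole grid axis:
# the sharded read would still decode full chunks, and only the in-memory
# copy held by each worker shrinks -- which is still the memory win this PR targets
```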
Review comment: may be useful (easier to understand) if we defined a `self.grid_dim = -1` and used that instead? (like we do for the ensemble dim just above)
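A minimal sketch of what that suggestion could look like (names and shapes are illustrative; whether it also simplifies the positional slicing in `__iter__` is a separate question):

```python
import numpy as np

# hypothetical sketch of the suggested refactor: name the grid axis once,
# mirroring how ensemble_dim is already defined a few lines above in __init__
class _GridDimSketch:
    def __init__(self, data: np.ndarray) -> None:
        self.data = data
        self.ensemble_dim: int = 2
        self.grid_dim: int = -1
        self.ensemble_size = self.data.shape[self.ensemble_dim]
        self.grid_size = self.data.shape[self.grid_dim]  # instead of shape[-1]

ds = _GridDimSketch(np.zeros((8, 4, 1, 1024), dtype="float32"))
assert ds.grid_size == 1024
```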