Reconfigure limit_train_batches in terms of micro batches (#8738)
* Reconfigure limit_train_batches in micro batches

Signed-off-by: Abhishree <[email protected]>

* Replace _reconfigure_val_batches with _reconfigure_limit_batches for BERT & T5

Signed-off-by: Abhishree <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Abhishree <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
athitten and pre-commit-ci[bot] authored Apr 3, 2024
1 parent 0eb6302 commit cf8d882
Showing 4 changed files with 36 additions and 27 deletions.
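
Background for the diffs below (not part of the commit): PyTorch Lightning counts limit_train_batches / limit_val_batches in dataloader batches, but with Megatron-style microbatching every optimizer step consumes get_num_microbatches() of those batches, so the limits have to be rescaled and rounded to whole global batches. A minimal Python sketch of the intended rescaling, with scale_limit as a hypothetical stand-in for the _reconfigure_limit_batches logic:

    # Sketch only; `num_microbatches` stands in for Megatron's get_num_microbatches().
    def scale_limit(limit_batches, dataloader_len, num_microbatches):
        if isinstance(limit_batches, int):
            # An int limit counts global batches; express it in microbatches.
            return limit_batches * num_microbatches
        # A float limit is a fraction of the dataloader (already in microbatches).
        limit_micro = int(dataloader_len * limit_batches)
        if limit_micro < num_microbatches:
            return num_microbatches  # keep at least one full global batch
        # Round down to a multiple of num_microbatches so no step is cut short.
        return limit_micro - limit_micro % num_microbatches

    print(scale_limit(10, 800, 8))    # 80 microbatches == 10 global batches
    print(scale_limit(0.25, 100, 8))  # 24 microbatches == 3 global batches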
@@ -324,40 +324,43 @@ def get_model_module_list(self):
         else:
             return [self.model]
 
-    def _reconfigure_val_batches(self):
+    def _reconfigure_limit_batches(self, limit_batches, dataloader, mode):
         """
         Reconfigure trainer.limit_val_batches for pretraining
         """
-        # Override limit_val_batches to be a multiple of num microbatches and so there are limit_val_batches//num_micro_batches num of global batches
-        if isinstance(self.trainer.limit_val_batches, int):
-            self.trainer.limit_val_batches *= get_num_microbatches()
+        # Override limit_batches in terms of num microbatches and so there are limit_batches//num_micro_batches num of global batches
+        if isinstance(limit_batches, int):
+            limit_batches *= get_num_microbatches()
         else:
-            assert isinstance(self.trainer.limit_val_batches, float)
-            # Don't reconfigure if limit_val_batches is 0.0 or if there's no val dataloader
-            if self.trainer.limit_val_batches == 0.0 or self._validation_dl is None:
+            assert isinstance(limit_batches, float)
+            # Don't reconfigure if limit_batches is 0.0 or if there's no dataloader
+            if limit_batches == 0.0 or dataloader is None:
                 return
-            # len(self._validation_dl) returns len as num of microbatches
-            val_len_in_micro_batches = len(self._validation_dl)
-            if self._validation_ds is not None and len(self._validation_dl) != float("inf"):
-                if self.trainer.limit_val_batches == 1.0:
-                    self.trainer.limit_val_batches = val_len_in_micro_batches
+            # len(dataloader) returns len as num of microbatches
+            dl_len_in_micro_batches = len(dataloader)
+            if len(dataloader) != float("inf"):
+                if limit_batches == 1.0:
+                    limit_batches = dl_len_in_micro_batches
                 else:
-                    limit_val_micro_batches = int(val_len_in_micro_batches * self.trainer.limit_val_batches)
-                    if limit_val_micro_batches == 0 and self.trainer.limit_val_batches > 0.0:
-                        min_percentage = 1.0 / len(self._validation_dl)
+                    limit_micro_batches = int(dl_len_in_micro_batches * limit_batches)
+                    if limit_micro_batches == 0 and limit_batches > 0.0:
+                        min_percentage = 1.0 / len(dataloader)
                         raise MisconfigurationException(
-                            f"You requested to check {self.trainer.limit_val_batches} of the val_dataloader but"
-                            f" {self.trainer.limit_val_batches} * {len(self._validation_dl)} < 1. Please increase the"
+                            f"You requested to check {limit_batches} of the val_dataloader but"
+                            f" {limit_batches} * {len(dataloader)} < 1. Please increase the"
                             f" `limit_val_batches` argument. Try at least"
                             f" `limit_val_batches={min_percentage}`"
                         )
                     # Make sure trainer.limit_val_batches is a multiple of num of microbatches
-                    if limit_val_micro_batches < get_num_microbatches():
-                        self.trainer.limit_val_batches = get_num_microbatches()
+                    if limit_micro_batches < get_num_microbatches():
+                        limit_batches = get_num_microbatches()
                     else:
-                        self.trainer.limit_val_batches = (
-                            limit_val_micro_batches - limit_val_micro_batches % get_num_microbatches()
-                        )
+                        limit_batches = limit_batches - limit_batches % get_num_microbatches()
+
+        if mode == 'train':
+            self.trainer.limit_train_batches = limit_batches
+        else:
+            self.trainer.limit_val_batches = limit_batches
 
         # Override num sanity steps to be a multiple of num of microbatches
         self.trainer.num_sanity_val_steps *= get_num_microbatches()
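
A worked example of the MisconfigurationException path above (values assumed, not from the commit): requesting 0.5% of a 100-microbatch validation loader resolves to zero whole microbatches, so the method raises and suggests the smallest usable fraction.

    # Hypothetical values mirroring the guard above.
    dl_len, requested = 100, 0.005
    if int(dl_len * requested) == 0 and requested > 0.0:
        # int(0.005 * 100) == 0, i.e. 0.005 * 100 < 1
        print(f"Try at least limit_val_batches={1.0 / dl_len}")  # 0.01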
@@ -683,8 +683,6 @@ def build_LDDL_data(self, cfg):
         logging.info(f'Finished building LDDL Dataloaders')
 
     def build_train_valid_test_datasets(self):
-        # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
-        self._reconfigure_val_batches()
         logging.info('Building Bert datasets.')
         if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float):
             raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.")
@@ -730,6 +728,10 @@ def build_train_valid_test_datasets(self):
         if self._test_ds is not None:
             logging.info(f'Length of test dataset: {len(self._test_ds)}')
         logging.info(f'Finished building Bert datasets.')
+
+        # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
+        self._reconfigure_limit_batches(self.trainer.limit_val_batches, self._validation_dl, 'val')
+
         return self._train_ds, self._validation_ds, self._test_ds
 
     def backward(self, *args, **kwargs):
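
An aside on the limit_val_batches check in the first Bert hunk above (illustration, not NeMo code): Lightning accepts either an int, read as an absolute batch count, or a float fraction no greater than 1.0.

    # Sketch of the validation rule enforced above.
    def check_limit(limit):
        if limit > 1.0 and isinstance(limit, float):
            raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.")
        return limit

    check_limit(0.25)  # ok: a quarter of the validation loader
    check_limit(10)    # ok: ten batches (later rescaled to microbatches)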
@@ -1453,8 +1453,10 @@ def setup(self, stage=None):
             self.setup_training_data(self.cfg.data)
             self.setup_validation_data(self.cfg.data)
             self.setup_test_data(self.cfg.data)
+            # Override limit_train_batches in terms of num of microbatches
+            self._reconfigure_limit_batches(self.trainer.limit_train_batches, self._train_dl, 'train')
             # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
-            self._reconfigure_val_batches()
+            self._reconfigure_limit_batches(self.trainer.limit_val_batches, self._validation_dl, 'val')
 
         if stage == 'fit':
             self.initialize_last_rank_embeddings()
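
For orientation (assumed values following the usual Megatron convention, not taken from this diff): the number of microbatches per global step comes from the global batch size, micro batch size, and data-parallel size, which is what makes the rescaling above necessary.

    # Assumed Megatron convention: microbatches per rank per global step.
    global_batch_size, micro_batch_size, data_parallel_size = 256, 4, 8
    num_microbatches = global_batch_size // (micro_batch_size * data_parallel_size)
    print(num_microbatches)       # 8
    print(10 * num_microbatches)  # limit_train_batches=10 -> 80 microbatches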
@@ -197,8 +197,6 @@ def add_special_tokens_to_tokenizer(
                 tokenizer.add_special_tokens([f'<extra_id_{mask_type}>'])
 
     def build_train_valid_test_datasets(self):
-        # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
-        self._reconfigure_val_batches()
         logging.info(f'Building {self.model_name} datasets.')
         if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float):
             raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.")
@@ -245,6 +243,10 @@ def build_train_valid_test_datasets(self):
         logging.info(f'Length of val dataset: {len(self._validation_ds)}')
         logging.info(f'Length of test dataset: {len(self._test_ds)}')
         logging.info(f'Finished building {self.model_name} datasets.')
+
+        # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
+        self._reconfigure_limit_batches(self.trainer.limit_val_batches, self._validation_dl, 'val')
+
         return self._train_ds, self._validation_ds, self._test_ds
 
     def list_available_models(self):
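
Taken together (a hedged summary with assumed numbers): with 8 microbatches per global step, the net effect of the commit is that every trainer limit ends up expressed in microbatches and lands on a whole-global-batch boundary.

    # Assumed: num_microbatches == 8 (see the earlier sketches).
    num_microbatches = 8
    print(10 * num_microbatches)                         # limit_train_batches=10 -> 80
    limit_micro = int(100 * 0.25)                        # limit_val_batches=0.25 of a 100-microbatch loader
    print(limit_micro - limit_micro % num_microbatches)  # -> 24
    print(2 * num_microbatches)                          # num_sanity_val_steps=2 -> 16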

0 comments on commit cf8d882