Log number of sequences consumed
manuelburger committed Oct 9, 2024
1 parent d037562 commit fcd1d10
Showing 2 changed files with 27 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/nanotron/data/petagraph_dataset.py
@@ -82,6 +82,7 @@ def __init__(self,

self.rank = rank
self.log_directory = log_directory
self.num_consumed_sequences = 0
self.consumed_files_path = self.log_directory / f"consumed_files/consumed_files_rank_{self.rank}.txt"

# Take list of already consumed lists and remove them from the
@@ -272,6 +273,10 @@ def generate(self):
if text_raw is None or len(text_raw) == 0:
continue

# Log the consumed sequences
self.num_consumed_sequences += 1

# Log the consumed files
if self.log_directory is not None:
if source_path not in self.consumed_files:
with open(self.consumed_files_path, "a") as f:
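For reference, a minimal, self-contained sketch (not the actual PetagraphDataset, all names below are illustrative) of the counting pattern this hunk introduces: the dataset keeps a per-rank num_consumed_sequences counter, increments it for every non-empty sequence it yields, and leaves it to the consumer (the trainer, see the second file) to read and reset the counter at each logging step.

class CountingIterableDataset:
    # Toy stand-in for the dataset; only the counter logic mirrors the diff above.
    def __init__(self, sequences):
        self.sequences = sequences
        self.num_consumed_sequences = 0  # read and reset to 0 by the trainer

    def generate(self):
        for text_raw in self.sequences:
            # Skip empty payloads, mirroring the check in the original generate()
            if text_raw is None or len(text_raw) == 0:
                continue
            self.num_consumed_sequences += 1
            yield text_raw

ds = CountingIterableDataset(["ACGT", "", "TTGA"])
consumed = list(ds.generate())
print(ds.num_consumed_sequences)  # 2 -- the empty sequence is skipped, not counted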
22 changes: 22 additions & 0 deletions src/nanotron/trainer.py
@@ -600,6 +600,12 @@ def train_step_logs(
else:
current_epoch = -1

if current_dataset is not None and hasattr(current_dataset, "num_consumed_sequences"):
num_consumed_sequences = current_dataset.num_consumed_sequences
current_dataset.num_consumed_sequences = 0
else:
num_consumed_sequences = 0

# Gather the values across all ranks
world_size_dp_pg = self.parallel_context.dp_pg.size()

@@ -624,9 +630,22 @@
current_epoch_ranks = current_epoch_t_all.cpu().numpy()
current_epoch_all = current_epoch_ranks.mean()

num_consumed_seq_t = torch.tensor(num_consumed_sequences, device="cuda", dtype=torch.int64)
num_consumed_seq_t_all = torch.zeros(world_size_dp_pg, device="cuda", dtype=torch.int64)
dist.all_gather_into_tensor(
output_tensor=num_consumed_seq_t_all,
input_tensor=num_consumed_seq_t,
group=self.parallel_context.dp_pg
)
num_consumed_seq_ranks = num_consumed_seq_t_all.cpu().numpy()
num_consumed_seq_all = num_consumed_seq_ranks.sum()
self.metadata.consumed_num_sequences += int(num_consumed_seq_all)


else:
num_consumed_files_all = None
current_epoch_all = None
num_consumed_seq_all = None

# Logging on logger ranks
if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks:
@@ -659,6 +678,9 @@ def train_step_logs(
if current_epoch_all is not None:
log_entries.append(LogItem("rank_avg_epoch", current_epoch_all, "human_format"))

if num_consumed_seq_all is not None:
log_entries.append(LogItem("num_consumed_seq_all", num_consumed_seq_all, "human_format"))

if self.config.optimizer.clip_grad is not None:
log_entries.append(LogItem("grad_norm", self.grad_norm_unclipped.item(), "human_format")) # , ".3f"))

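As a rough, standalone illustration of the aggregation added to train_step_logs: each data-parallel rank contributes its local num_consumed_sequences, the per-rank values are gathered, summed, and added to a running total in the training metadata. The commit gathers CUDA int64 tensors with dist.all_gather_into_tensor over parallel_context.dp_pg; the sketch below swaps in the list-based dist.all_gather on CPU with a gloo backend and world_size=1 so it runs without GPUs (everything outside torch is a hypothetical stand-in).

import os
import torch
import torch.distributed as dist

# Single-process process group so the collective runs standalone
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

num_consumed_sequences = 128  # local counter read (and reset to 0) from the dataset
world_size = dist.get_world_size()

# Gather one int64 value per rank, then sum across ranks
local_t = torch.tensor([num_consumed_sequences], dtype=torch.int64)
gathered = [torch.zeros(1, dtype=torch.int64) for _ in range(world_size)]
dist.all_gather(gathered, local_t)
num_consumed_seq_all = int(torch.cat(gathered).sum())

running_total = 0  # stands in for the running metadata counter
running_total += num_consumed_seq_all
print(num_consumed_seq_all, running_total)

dist.destroy_process_group()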
