Log mean seq length processed
manuelburger committed Oct 9, 2024
1 parent fcd1d10 commit 45328b2
Showing 2 changed files with 19 additions and 6 deletions.
src/nanotron/data/petagraph_dataset.py (2 changes: 1 addition & 1 deletion)
@@ -181,7 +181,7 @@ def __init__(self,
        for _ in range(warmup_sample_size):
            _ = next(self.iterable_dataset)

-       self.consumed_seq_len_queue = deque(maxlen=1000)
+       self.consumed_seq_len_queue = deque(maxlen=5000)
        if self.log_directory is not None:
            self.logging_func(f"[PetaGraphStreamDataset] Logging to {self.log_directory} on rank {self.rank}")

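Note: the queue is bounded, so any statistic computed from it is a rolling figure over the most recently consumed sequences rather than an exact average over everything consumed; raising maxlen from 1000 to 5000 simply widens that window. A minimal standalone sketch of the rolling-mean behavior (the helper names below are illustrative and not part of the commit):

from collections import deque

import numpy as np

# Bounded queue: once 5000 lengths are stored, each append evicts the oldest
# entry, so the mean below is an approximation over a sliding window (sketch only).
consumed_seq_len_queue = deque(maxlen=5000)

def record_sequence_length(seq_len: int) -> None:
    consumed_seq_len_queue.append(seq_len)

def approx_mean_seq_len() -> float:
    # Returns 0.0 before anything has been consumed, mirroring the trainer's fallback.
    if not consumed_seq_len_queue:
        return 0.0
    return float(np.mean(np.array(list(consumed_seq_len_queue), dtype=np.int64)))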
src/nanotron/trainer.py (23 changes: 18 additions & 5 deletions)
@@ -584,10 +584,8 @@ def train_step_logs(

        if current_dataset is not None and hasattr(current_dataset, "consumed_seq_len_queue"):
            consumed_seq_lens = np.array(list(current_dataset.consumed_seq_len_queue), dtype=np.int64)
-           median_seq_len = np.median(consumed_seq_lens)
            mean_seq_len = np.mean(consumed_seq_lens)
        else:
-           median_seq_len = 0.0
            mean_seq_len = 0.0

        if current_dataset is not None and hasattr(current_dataset, "consumed_files"):
@@ -640,12 +638,24 @@ def train_step_logs(
            num_consumed_seq_ranks = num_consumed_seq_t_all.cpu().numpy()
            num_consumed_seq_all = num_consumed_seq_ranks.sum()
            self.metadata.consumed_num_sequences += int(num_consumed_seq_all)
+           num_consumed_seq_log = self.metadata.consumed_num_sequences

+           mean_consumed_seq_len_t = torch.tensor(mean_seq_len, device="cuda", dtype=torch.float32)
+           mean_consumed_seq_len_t_all = torch.zeros(world_size_dp_pg, device="cuda", dtype=torch.float32)
+           dist.all_gather_into_tensor(
+               output_tensor=mean_consumed_seq_len_t_all,
+               input_tensor=mean_consumed_seq_len_t,
+               group=self.parallel_context.dp_pg
+           )
+           mean_consumed_seq_len_ranks = mean_consumed_seq_len_t_all.cpu().numpy()
+           mean_consumed_seq_len_all = mean_consumed_seq_len_ranks.mean()


        else:
            num_consumed_files_all = None
            current_epoch_all = None
            num_consumed_seq_all = None
+           num_consumed_seq_log = None
+           mean_consumed_seq_len_all = None

        # Logging on logger ranks
        if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks:
@@ -678,8 +688,11 @@ def train_step_logs(
            if current_epoch_all is not None:
                log_entries.append(LogItem("rank_avg_epoch", current_epoch_all, "human_format"))

-           if num_consumed_seq_all is not None:
-               log_entries.append(LogItem("num_consumed_seq_all", num_consumed_seq_all, "human_format"))
+           if num_consumed_seq_log is not None:
+               log_entries.append(LogItem("num_consumed_seq_all", num_consumed_seq_log, "human_format"))
+
+           if mean_consumed_seq_len_all is not None:
+               log_entries.append(LogItem("approx_mean_consumed_seq_len", mean_consumed_seq_len_all, "human_format"))

            if self.config.optimizer.clip_grad is not None:
                log_entries.append(LogItem("grad_norm", self.grad_norm_unclipped.item(), "human_format")) # , ".3f"))
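Note: the logging path above first takes a per-rank mean over the local length queue and then averages those per-rank means across the data-parallel group, so the reported approx_mean_consumed_seq_len is a mean of means, which only approximates the true global mean when ranks have consumed different numbers of sequences. A minimal sketch of that aggregation step, assuming torch.distributed is already initialized and that dp_group is the data-parallel process group (the function name is illustrative):

import torch
import torch.distributed as dist

def gather_approx_mean_seq_len(local_mean: float, dp_group) -> float:
    # Each data-parallel rank contributes its local rolling mean as a scalar tensor.
    world_size = dist.get_world_size(group=dp_group)
    local_t = torch.tensor(local_mean, device="cuda", dtype=torch.float32)
    all_means = torch.zeros(world_size, device="cuda", dtype=torch.float32)
    dist.all_gather_into_tensor(output_tensor=all_means, input_tensor=local_t, group=dp_group)
    # Average of the per-rank means, i.e. the value logged as approx_mean_consumed_seq_len.
    return float(all_means.mean().item())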
