Skip to content

Commit

Permalink
add data loading times
Browse files (browse the repository at this point in the history)
  • Loading branch information
lessw2020 committed Feb 26, 2024
1 parent 23b7739 commit 02c1fdf
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
8 changes: 7 additions & 1 deletion train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -36,6 +36,7 @@ class TrainState:
current_loss: float = -1
losses: List[float] = field(default_factory=list)
iter_times: List[float] = field(default_factory=list)
+ data_load_times: List[float] = field(default_factory=list)

def state_dict(self) -> Dict[str, Any]:
return {
Expand Down Expand Up @@ -178,10 +179,13 @@ def main(job_config: JobConfig):
):
train_state.step += 1
# get batch
+ data_load_start = timer()
batch = next(iter(data_loader))
input_ids, labels = batch
input_ids = input_ids.cuda()
labels = labels.cuda()
+ data_load_time = round(timer() - data_load_start, 4)
+ train_state.data_load_times.append(data_load_time)
nwords_since_last_log += labels.numel()

optimizer.zero_grad()
Expand Down Expand Up @@ -264,7 +268,7 @@ def main(job_config: JobConfig):

rank0_log(
f"step: {train_state.step}, loss: {round(train_state.current_loss,4)},"
- f" time: {curr_iter_time}, lr: {round(float(scheduler.get_last_lr()[0]), 8)}"
+ f" iter: {curr_iter_time}, data: {data_load_time}, lr: {round(float(scheduler.get_last_lr()[0]), 8)}"
)
scheduler.step()

Expand All @@ -277,6 +281,8 @@ def main(job_config: JobConfig):
if len(train_state.iter_times) > 3:
avg_iter_time = np.mean(train_state.iter_times[3:])
rank0_log(f"Average iter time: {avg_iter_time:.4f} seconds")
+ avg_data_load_time = np.mean(train_state.data_load_times[3:])
+ rank0_log(f"Average data load time: {avg_data_load_time:.4f} seconds")

rank0_log(f"{gpu_metrics.get_current_stats()}")

Expand Down
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,7 +3,7 @@
dump_folder = "./outputs"

[profiling]
- run_profiler = false
+ run_profiler = true
save_traces_folder = "profiling/traces"
# profiling frequency - example: 10 means every 10th iter will be profiled
profile_every_x_iter = 10
Expand Down

0 comments on commit 02c1fdf

Please sign in to comment.