Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
lessw2020 committed Feb 27, 2024
1 parent d1110f0 commit 34a25d8
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from torchtrain.parallelisms import models_parallelize_fns, ParallelDims

from torchtrain.profiling import maybe_run_profiler
from torchtrain.utils import dist_max, dist_mean
from torchtrain.utils import Color, dist_max, dist_mean


@dataclass
Expand Down Expand Up @@ -120,7 +120,8 @@ def main(job_config: JobConfig):
# log model size
model_param_count = get_num_params(model)
rank0_log(
f"Model {model_name} {job_config.model.flavor} size: {model_param_count:,} total parameters"
f"{Color.blue}Model {model_name} {job_config.model.flavor} {Color.red}size: {model_param_count:,}"
f" total parameters{Color.reset}"
)
gpu_metrics = GPUMemoryMonitor("cuda")
rank0_log(f"GPU memory usage: {gpu_metrics}")
Expand Down Expand Up @@ -269,8 +270,10 @@ def main(job_config: JobConfig):
time_last_log = timer()

rank0_log(
f"step: {train_state.step:>2} loss: {round(train_state.current_loss,4):>7}"
f" iter: {curr_iter_time:>7} data: {data_load_time:>5} lr: {round(float(scheduler.get_last_lr()[0]), 8):<6}"
f"{Color.cyan}step: {train_state.step:>2} {Color.green}loss: {round(train_state.current_loss,4):>7}"
f" {Color.reset}iter: {Color.blue}{curr_iter_time:>7}{Color.reset}"
f" data: {Color.blue}{data_load_time:>5} {Color.reset}"
f"lr: {Color.yellow}{round(float(scheduler.get_last_lr()[0]), 8):<6}{Color.reset}"
)
scheduler.step()

Expand Down

0 comments on commit 34a25d8

Please sign in to comment.