From 34a25d80db1d440a582e9823ab5ad17b8dff340b Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Mon, 26 Feb 2024 22:20:31 -0800
Subject: [PATCH] linting

---
 train.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/train.py b/train.py
index 5ce5de37..29eee795 100644
--- a/train.py
+++ b/train.py
@@ -27,7 +27,7 @@
 from torchtrain.parallelisms import models_parallelize_fns, ParallelDims
 from torchtrain.profiling import maybe_run_profiler
-from torchtrain.utils import dist_max, dist_mean
+from torchtrain.utils import Color, dist_max, dist_mean


 @dataclass
@@ -120,7 +120,8 @@ def main(job_config: JobConfig):
     # log model size
     model_param_count = get_num_params(model)
     rank0_log(
-        f"Model {model_name} {job_config.model.flavor} size: {model_param_count:,} total parameters"
+        f"{Color.blue}Model {model_name} {job_config.model.flavor} {Color.red}size: {model_param_count:,}"
+        f" total parameters{Color.reset}"
     )
     gpu_metrics = GPUMemoryMonitor("cuda")
     rank0_log(f"GPU memory usage: {gpu_metrics}")
@@ -269,8 +270,10 @@ def main(job_config: JobConfig):
             time_last_log = timer()
             rank0_log(
-                f"step: {train_state.step:>2} loss: {round(train_state.current_loss,4):>7}"
-                f" iter: {curr_iter_time:>7} data: {data_load_time:>5} lr: {round(float(scheduler.get_last_lr()[0]), 8):<6}"
+                f"{Color.cyan}step: {train_state.step:>2} {Color.green}loss: {round(train_state.current_loss,4):>7}"
+                f" {Color.reset}iter: {Color.blue}{curr_iter_time:>7}{Color.reset}"
+                f" data: {Color.blue}{data_load_time:>5} {Color.reset}"
+                f"lr: {Color.yellow}{round(float(scheduler.get_last_lr()[0]), 8):<6}{Color.reset}"
             )
             scheduler.step()
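
Note: the Color helper imported above from torchtrain.utils is not defined in this patch. A minimal sketch of what such a helper could look like, assuming plain ANSI escape sequences and attribute names (blue, red, green, cyan, yellow, reset) matching those used in the f-strings; the actual torchtrain.utils.Color may differ.

class Color:
    # ANSI escape sequences for terminal colors.
    # Hypothetical stand-in for torchtrain.utils.Color, whose real
    # definition is not part of this diff.
    black = "\033[30m"
    red = "\033[31m"
    green = "\033[32m"
    yellow = "\033[33m"
    blue = "\033[34m"
    magenta = "\033[35m"
    cyan = "\033[36m"
    white = "\033[37m"
    reset = "\033[0m"  # restores the terminal's default color

# Example usage mirroring the logging pattern in this patch:
#   rank0_log(f"{Color.green}loss: {current_loss:>7}{Color.reset}")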