From 8671c913832a6ab351e8f0db60c749bb4d70f3b4 Mon Sep 17 00:00:00 2001 From: Less Wright Date: Tue, 27 Feb 2024 10:41:40 -0800 Subject: [PATCH] Add color to console output if local logging, auto avoid color logging on slurm (#93) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds the ability to do colored console outputs in order to highlight the training data outputs. It also adds a check to not use this color formatting on slurm, where it will add 33= instead of the color if not avoided. Note that I've just added some color to highlight the main training data. Users that fork/clone can use it to enhance their outputs as desired. Screenshot 2024-02-26 at 10 20 15 PM Note that on slurm it remains plain: Screenshot 2024-02-26 at 10 46 24 PM if you dont' check this, then it would otherwise look like this (this does not happen with this PR, just showing if we didn't check and credit to Yifu for noting this would be an issue): Screenshot 2024-02-26 at 10 39 23 PM --- torchtrain/utils.py | 35 +++++++++++++++++++++++++++++++++++ train.py | 39 +++++++++++++++++++++++++++++++-------- 2 files changed, 66 insertions(+), 8 deletions(-) diff --git a/torchtrain/utils.py b/torchtrain/utils.py index 9ae71cae..823e8843 100644 --- a/torchtrain/utils.py +++ b/torchtrain/utils.py @@ -1,6 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. +from dataclasses import dataclass from typing import Union import torch @@ -17,3 +18,37 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float: def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float: tensor = torch.tensor(x).cuda() return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh) + + +@dataclass +class Color: + black = "\033[30m" + red = "\033[31m" + green = "\033[32m" + yellow = "\033[33m" + blue = "\033[34m" + magenta = "\033[35m" + cyan = "\033[36m" + white = "\033[37m" + reset = "\033[39m" + + +@dataclass +class Background: + black = "\033[40m" + red = "\033[41m" + green = "\033[42m" + yellow = "\033[43m" + blue = "\033[44m" + magenta = "\033[45m" + cyan = "\033[46m" + white = "\033[47m" + reset = "\033[49m" + + +@dataclass +class Style: + bright = "\033[1m" + dim = "\033[2m" + normal = "\033[22m" + reset = "\033[0m" diff --git a/train.py b/train.py index 5ce5de37..95d42226 100644 --- a/train.py +++ b/train.py @@ -2,6 +2,7 @@ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. import os + from dataclasses import dataclass, field from timeit import default_timer as timer from typing import Any, Dict, List @@ -27,7 +28,11 @@ from torchtrain.parallelisms import models_parallelize_fns, ParallelDims from torchtrain.profiling import maybe_run_profiler -from torchtrain.utils import dist_max, dist_mean +from torchtrain.utils import Color, dist_max, dist_mean + +_is_local_logging = True +if "SLURM_JOB_ID" in os.environ: + _is_local_logging = False @dataclass @@ -119,9 +124,16 @@ def main(job_config: JobConfig): # log model size model_param_count = get_num_params(model) - rank0_log( - f"Model {model_name} {job_config.model.flavor} size: {model_param_count:,} total parameters" - ) + if _is_local_logging: + rank0_log( + f"{Color.blue}Model {model_name} {job_config.model.flavor} {Color.red}size: {model_param_count:,}" + f" total parameters{Color.reset}" + ) + else: + rank0_log( + f"{model_name} {job_config.model.flavor} size: {model_param_count:,} total parameters" + ) + gpu_metrics = GPUMemoryMonitor("cuda") rank0_log(f"GPU memory usage: {gpu_metrics}") @@ -268,10 +280,21 @@ def main(job_config: JobConfig): nwords_since_last_log = 0 time_last_log = timer() - rank0_log( - f"step: {train_state.step:>2} loss: {round(train_state.current_loss,4):>7}" - f" iter: {curr_iter_time:>7} data: {data_load_time:>5} lr: {round(float(scheduler.get_last_lr()[0]), 8):<6}" - ) + if _is_local_logging: + rank0_log( + f"{Color.cyan}step: {train_state.step:>2} {Color.green}loss: {round(train_state.current_loss,4):>7}" + f" {Color.reset}iter: {Color.blue}{curr_iter_time:>7}{Color.reset}" + f" data: {Color.blue}{data_load_time:>5} {Color.reset}" + f"lr: {Color.yellow}{round(float(scheduler.get_last_lr()[0]), 8):<6}{Color.reset}" + ) + else: + rank0_log( + f"step: {train_state.step:>2} loss: {round(train_state.current_loss,4):>7}" + f" iter: {curr_iter_time:>7}" + f" data: {data_load_time:>5} " + f"lr: {round(float(scheduler.get_last_lr()[0]), 8):<6}" + ) + scheduler.step() checkpoint.save(