diff --git a/torchtrain/utils.py b/torchtrain/utils.py
index 9ae71cae..823e8843 100644
--- a/torchtrain/utils.py
+++ b/torchtrain/utils.py
@@ -1,6 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from dataclasses import dataclass
 from typing import Union
 
 import torch
@@ -17,3 +18,37 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
 def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float:
     tensor = torch.tensor(x).cuda()
     return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh)
+
+
+@dataclass
+class Color:
+    black = "\033[30m"
+    red = "\033[31m"
+    green = "\033[32m"
+    yellow = "\033[33m"
+    blue = "\033[34m"
+    magenta = "\033[35m"
+    cyan = "\033[36m"
+    white = "\033[37m"
+    reset = "\033[39m"
+
+
+@dataclass
+class Background:
+    black = "\033[40m"
+    red = "\033[41m"
+    green = "\033[42m"
+    yellow = "\033[43m"
+    blue = "\033[44m"
+    magenta = "\033[45m"
+    cyan = "\033[46m"
+    white = "\033[47m"
+    reset = "\033[49m"
+
+
+@dataclass
+class Style:
+    bright = "\033[1m"
+    dim = "\033[2m"
+    normal = "\033[22m"
+    reset = "\033[0m"
diff --git a/train.py b/train.py
index 5ce5de37..95d42226 100644
--- a/train.py
+++ b/train.py
@@ -2,6 +2,7 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import os
+
 from dataclasses import dataclass, field
 from timeit import default_timer as timer
 from typing import Any, Dict, List
@@ -27,7 +28,11 @@
 from torchtrain.parallelisms import models_parallelize_fns, ParallelDims
 from torchtrain.profiling import maybe_run_profiler
 
-from torchtrain.utils import dist_max, dist_mean
+from torchtrain.utils import Color, dist_max, dist_mean
+
+_is_local_logging = True
+if "SLURM_JOB_ID" in os.environ:
+    _is_local_logging = False
 
 
 @dataclass
@@ -119,9 +124,16 @@ def main(job_config: JobConfig):
 
     # log model size
     model_param_count = get_num_params(model)
-    rank0_log(
-        f"Model {model_name} {job_config.model.flavor} size: {model_param_count:,} total parameters"
-    )
+    if _is_local_logging:
+        rank0_log(
+            f"{Color.blue}Model {model_name} {job_config.model.flavor} {Color.red}size: {model_param_count:,}"
+            f" total parameters{Color.reset}"
+        )
+    else:
+        rank0_log(
+            f"{model_name} {job_config.model.flavor} size: {model_param_count:,} total parameters"
+        )
+
     gpu_metrics = GPUMemoryMonitor("cuda")
     rank0_log(f"GPU memory usage: {gpu_metrics}")
 
@@ -268,10 +280,21 @@ def main(job_config: JobConfig):
 
                 nwords_since_last_log = 0
                 time_last_log = timer()
-                rank0_log(
-                    f"step: {train_state.step:>2} loss: {round(train_state.current_loss,4):>7}"
-                    f" iter: {curr_iter_time:>7} data: {data_load_time:>5} lr: {round(float(scheduler.get_last_lr()[0]), 8):<6}"
-                )
+                if _is_local_logging:
+                    rank0_log(
+                        f"{Color.cyan}step: {train_state.step:>2} {Color.green}loss: {round(train_state.current_loss,4):>7}"
+                        f" {Color.reset}iter: {Color.blue}{curr_iter_time:>7}{Color.reset}"
+                        f" data: {Color.blue}{data_load_time:>5} {Color.reset}"
+                        f"lr: {Color.yellow}{round(float(scheduler.get_last_lr()[0]), 8):<6}{Color.reset}"
+                    )
+                else:
+                    rank0_log(
+                        f"step: {train_state.step:>2} loss: {round(train_state.current_loss,4):>7}"
+                        f" iter: {curr_iter_time:>7}"
+                        f" data: {data_load_time:>5} "
+                        f"lr: {round(float(scheduler.get_last_lr()[0]), 8):<6}"
+                    )
+
             scheduler.step()
 
             checkpoint.save(
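
Note: because the attributes in `Color`, `Background`, and `Style` carry no type annotations, `@dataclass` treats them as plain class attributes rather than fields, which is why the diff reads them directly off the class (e.g. `Color.blue`). Below is a minimal standalone sketch of the SLURM-gated coloring pattern introduced here; the `log_step` helper and the `print` target are illustrative only and not part of this change.

```python
import os
from dataclasses import dataclass


@dataclass
class Color:
    # Mirrors a subset of the ANSI foreground codes added in torchtrain/utils.py.
    green = "\033[32m"
    yellow = "\033[33m"
    cyan = "\033[36m"
    reset = "\033[39m"


# Same gating as train.py: colorize only when not running under SLURM,
# where stdout is usually captured to plain log files.
_is_local_logging = "SLURM_JOB_ID" not in os.environ


def log_step(step: int, loss: float, lr: float) -> None:
    """Illustrative stand-in for the rank0_log calls in train.py."""
    if _is_local_logging:
        print(
            f"{Color.cyan}step: {step:>2} {Color.green}loss: {round(loss, 4):>7}"
            f" {Color.reset}lr: {Color.yellow}{round(lr, 8):<6}{Color.reset}"
        )
    else:
        print(f"step: {step:>2} loss: {round(loss, 4):>7} lr: {round(lr, 8):<6}")


if __name__ == "__main__":
    log_step(1, 2.71828, 8e-4)
```

Using `\033[39m` / `\033[49m` as `reset` restores only the default foreground/background color, while `Style.reset` (`\033[0m`) clears all attributes including brightness.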