Skip to content

Commit

Permalink
Add color to console output if local logging, auto avoid color loggin…
Browse files Browse the repository at this point in the history
…g on slurm (#93)

This PR adds the ability to do colored console outputs in order to
highlight the training data outputs.
It also adds a check to not use this color formatting on slurm, where it
will add 33= instead of the color if not avoided.

Note that I've just added some color to highlight the main training
data. Users that fork/clone can use it to enhance their outputs as
desired.

<img width="1372" alt="Screenshot 2024-02-26 at 10 20 15 PM"
src="https://github.com/pytorch/torchtrain/assets/46302957/44849821-1677-40bf-896c-39344cd661d6">


Note that on slurm it remains plain:
<img width="847" alt="Screenshot 2024-02-26 at 10 46 24 PM"
src="https://github.com/pytorch/torchtrain/assets/46302957/172eaa58-4f5c-48f5-8ec1-bc349e3e82f2">

if you dont' check this, then it would otherwise look like this (this
does not happen with this PR, just showing if we didn't check and credit
to Yifu for noting this would be an issue):
<img width="847" alt="Screenshot 2024-02-26 at 10 39 23 PM"
src="https://github.com/pytorch/torchtrain/assets/46302957/4a87fb9a-dd3a-417c-a29e-286ded069358">
  • Loading branch information
lessw2020 authored Feb 27, 2024
1 parent 5dec536 commit 8671c91
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 8 deletions.
35 changes: 35 additions & 0 deletions torchtrain/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass
from typing import Union

import torch
Expand All @@ -17,3 +18,37 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float:
tensor = torch.tensor(x).cuda()
return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh)


@dataclass
class Color:
black = "\033[30m"
red = "\033[31m"
green = "\033[32m"
yellow = "\033[33m"
blue = "\033[34m"
magenta = "\033[35m"
cyan = "\033[36m"
white = "\033[37m"
reset = "\033[39m"


@dataclass
class Background:
black = "\033[40m"
red = "\033[41m"
green = "\033[42m"
yellow = "\033[43m"
blue = "\033[44m"
magenta = "\033[45m"
cyan = "\033[46m"
white = "\033[47m"
reset = "\033[49m"


@dataclass
class Style:
bright = "\033[1m"
dim = "\033[2m"
normal = "\033[22m"
reset = "\033[0m"
39 changes: 31 additions & 8 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os

from dataclasses import dataclass, field
from timeit import default_timer as timer
from typing import Any, Dict, List
Expand All @@ -27,7 +28,11 @@
from torchtrain.parallelisms import models_parallelize_fns, ParallelDims

from torchtrain.profiling import maybe_run_profiler
from torchtrain.utils import dist_max, dist_mean
from torchtrain.utils import Color, dist_max, dist_mean

_is_local_logging = True
if "SLURM_JOB_ID" in os.environ:
_is_local_logging = False


@dataclass
Expand Down Expand Up @@ -119,9 +124,16 @@ def main(job_config: JobConfig):

# log model size
model_param_count = get_num_params(model)
rank0_log(
f"Model {model_name} {job_config.model.flavor} size: {model_param_count:,} total parameters"
)
if _is_local_logging:
rank0_log(
f"{Color.blue}Model {model_name} {job_config.model.flavor} {Color.red}size: {model_param_count:,}"
f" total parameters{Color.reset}"
)
else:
rank0_log(
f"{model_name} {job_config.model.flavor} size: {model_param_count:,} total parameters"
)

gpu_metrics = GPUMemoryMonitor("cuda")
rank0_log(f"GPU memory usage: {gpu_metrics}")

Expand Down Expand Up @@ -268,10 +280,21 @@ def main(job_config: JobConfig):
nwords_since_last_log = 0
time_last_log = timer()

rank0_log(
f"step: {train_state.step:>2} loss: {round(train_state.current_loss,4):>7}"
f" iter: {curr_iter_time:>7} data: {data_load_time:>5} lr: {round(float(scheduler.get_last_lr()[0]), 8):<6}"
)
if _is_local_logging:
rank0_log(
f"{Color.cyan}step: {train_state.step:>2} {Color.green}loss: {round(train_state.current_loss,4):>7}"
f" {Color.reset}iter: {Color.blue}{curr_iter_time:>7}{Color.reset}"
f" data: {Color.blue}{data_load_time:>5} {Color.reset}"
f"lr: {Color.yellow}{round(float(scheduler.get_last_lr()[0]), 8):<6}{Color.reset}"
)
else:
rank0_log(
f"step: {train_state.step:>2} loss: {round(train_state.current_loss,4):>7}"
f" iter: {curr_iter_time:>7}"
f" data: {data_load_time:>5} "
f"lr: {round(float(scheduler.get_last_lr()[0]), 8):<6}"
)

scheduler.step()

checkpoint.save(
Expand Down

0 comments on commit 8671c91

Please sign in to comment.