Skip to content

Commit

Permalink
Avoid using colored console output when running under SLURM
Browse files Browse the repository at this point in the history
  • Loading branch information
lessw2020 committed Feb 27, 2024
1 parent 34a25d8 commit d31fe31
Showing 1 changed file with 30 additions and 10 deletions.
40 changes: 30 additions & 10 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os

from dataclasses import dataclass, field
from timeit import default_timer as timer
from typing import Any, Dict, List
Expand Down Expand Up @@ -29,6 +30,10 @@
from torchtrain.profiling import maybe_run_profiler
from torchtrain.utils import Color, dist_max, dist_mean

# True when the SLURM_JOB_ID environment variable is absent, i.e. the process
# does not appear to have been launched as a SLURM job. The log calls in this
# module use the flag to pick between Color-decorated and plain messages.
# Evaluated once at import time, so later changes to the environment are not
# picked up.
# NOTE(review): presumably SLURM job output goes to log files where ANSI color
# codes are just noise -- confirm that this is the motivation.
_is_local_logging = "SLURM_JOB_ID" not in os.environ


@dataclass
class TrainState:
Expand Down Expand Up @@ -119,10 +124,16 @@ def main(job_config: JobConfig):

# log model size
model_param_count = get_num_params(model)
rank0_log(
f"{Color.blue}Model {model_name} {job_config.model.flavor} {Color.red}size: {model_param_count:,}"
f" total parameters{Color.reset}"
)
if _is_local_logging:
rank0_log(
f"{Color.blue}Model {model_name} {job_config.model.flavor} {Color.red}size: {model_param_count:,}"
f" total parameters{Color.reset}"
)
else:
rank0_log(
f"{model_name} {job_config.model.flavor} size: {model_param_count:,} total parameters"
)

gpu_metrics = GPUMemoryMonitor("cuda")
rank0_log(f"GPU memory usage: {gpu_metrics}")

Expand Down Expand Up @@ -269,12 +280,21 @@ def main(job_config: JobConfig):
nwords_since_last_log = 0
time_last_log = timer()

rank0_log(
f"{Color.cyan}step: {train_state.step:>2} {Color.green}loss: {round(train_state.current_loss,4):>7}"
f" {Color.reset}iter: {Color.blue}{curr_iter_time:>7}{Color.reset}"
f" data: {Color.blue}{data_load_time:>5} {Color.reset}"
f"lr: {Color.yellow}{round(float(scheduler.get_last_lr()[0]), 8):<6}{Color.reset}"
)
if _is_local_logging:
rank0_log(
f"{Color.cyan}step: {train_state.step:>2} {Color.green}loss: {round(train_state.current_loss,4):>7}"
f" {Color.reset}iter: {Color.blue}{curr_iter_time:>7}{Color.reset}"
f" data: {Color.blue}{data_load_time:>5} {Color.reset}"
f"lr: {Color.yellow}{round(float(scheduler.get_last_lr()[0]), 8):<6}{Color.reset}"
)
else:
rank0_log(
f"step: {train_state.step:>2} loss: {round(train_state.current_loss,4):>7}"
f" iter: {curr_iter_time:>7}"
f" data: {data_load_time:>5} "
f"lr: {round(float(scheduler.get_last_lr()[0]), 8):<6}"
)

scheduler.step()

checkpoint.save(
Expand Down

0 comments on commit d31fe31

Please sign in to comment.