From 80207e6edf0894e0cacd0f45e8384ccff10e381a Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Wed, 14 Feb 2024 18:16:16 -0800
Subject: [PATCH] add TensorBoard logging with loss and wps

ghstack-source-id: d0828f16c06747a5af2586630e5205bf786de1c4
Pull Request resolved: https://github.com/pytorch-labs/torchtrain/pull/57
---
 README.md                                  | 18 ++++++++++
 requirements.txt                           |  1 +
 torchtrain/metrics.py                      | 44 ++++++++++++++++++++++
 torchtrain/parallelisms/__init__.py        |  5 +++
 torchtrain/train_configs/train_config.toml |  4 ++
 torchtrain/utils.py                        | 19 ++++++++++
 train.py                                   | 46 +++++++++++++++++++++-
 7 files changed, 135 insertions(+), 2 deletions(-)
 create mode 100644 torchtrain/utils.py

diff --git a/README.md b/README.md
index 9c9dbb8e7..4170bab46 100644
--- a/README.md
+++ b/README.md
@@ -22,3 +22,21 @@ run the llama debug model locally to verify the setup is correct:
 ```
 ./run_llama_train.sh
 ```
+
+# TensorBoard
+
+To visualize training metrics on TensorBoard:
+
+1. make sure `enable_tensorboard = true` is set in `torchtrain/train_configs/train_config.toml` (it is enabled by default)
+
+2. set up SSH tunneling from your local machine to the training host
+```
+ssh -L 6006:127.0.0.1:6006 [username]@[hostname]
+```
+
+3. then, inside the torchtrain repo, run
+```
+tensorboard --logdir=./torchtrain/outputs/tb
+```
+
+4. go to the URL it prints, or to http://localhost:6006/
diff --git a/requirements.txt b/requirements.txt
index 9bc33ca39..8e089a3e2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,3 +2,4 @@ torch >= 2.2.0.dev
 sentencepiece
 datasets
 tomli >= 1.1.0 ; python_version < "3.11"
+tensorboard
diff --git a/torchtrain/metrics.py b/torchtrain/metrics.py
index 64888b373..092cafae0 100644
--- a/torchtrain/metrics.py
+++ b/torchtrain/metrics.py
@@ -4,10 +4,17 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved
 
+import os
 from collections import namedtuple
+from datetime import datetime
+from typing import Any, Dict, Optional
 
 import torch
 import torch.nn as nn
+from torch.utils.tensorboard import SummaryWriter
+
+from torchtrain.logging_utils import rank0_log
+from torchtrain.profiling import get_config_from_toml
 
 _gb_in_bytes = 1024 * 1024 * 1024
 _mb_in_bytes = 1024 * 1024
@@ -187,3 +194,40 @@ def get_num_params(model: nn.Module, only_trainable: bool = False) -> int:
         param_list = [p for p in param_list if p.requires_grad]
     unique_params = {p.data_ptr(): p for p in param_list}.values()
     return sum(p.numel() for p in unique_params)
+
+
+class MetricLogger:
+    def __init__(self, log_dir, tag, enable_tb):
+        self.tag = tag
+        self.writer: Optional[SummaryWriter] = None
+        if enable_tb:
+            self.writer = SummaryWriter(log_dir, max_queue=1000)
+
+    def log(self, metrics: Dict[str, Any], step: int):
+        if self.writer is not None:
+            for k, v in metrics.items():
+                tag = k if self.tag is None else f"{self.tag}/{k}"
+                self.writer.add_scalar(tag, v, step)
+
+    def close(self):
+        if self.writer is not None:
+            self.writer.close()
+
+
+def build_metric_logger(tag: Optional[str] = None):
+    config = get_config_from_toml()
+
+    dump_dir = config["global"]["dump_folder"]
+    save_tb_folder = config["metrics"]["save_tb_folder"]
+    # since we don't have a run ID yet, use the current minute as the identifier
+    datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
+    log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)
+
+    enable_tb = config["metrics"].get("enable_tensorboard", False)
+    if enable_tb:
+        rank0_log(
+            f"Metrics logging active. TensorBoard logs will be saved at {log_dir}."
+        )
+
+    rank_str = f"rank_{torch.distributed.get_rank()}"
+    return MetricLogger(os.path.join(log_dir, rank_str), tag, enable_tb)
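Note: `MetricLogger` above is a thin wrapper over `torch.utils.tensorboard.SummaryWriter`, writing one scalar per metric per step into a per-rank directory. A minimal, self-contained sketch of that write path (the log directory, steps, and values below are made up for illustration):
```
# minimal sketch of the MetricLogger write path; requires the `tensorboard` package
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("outputs/tb/example/rank_0", max_queue=1000)
for step in range(1, 4):
    # one add_scalar call per metric, exactly as MetricLogger.log does
    writer.add_scalar("global_avg_loss", 10.0 / step, step)
    writer.add_scalar("wps", 1200.0 * step, step)
writer.close()  # flush pending events to disk, as MetricLogger.close does
```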
diff --git a/torchtrain/parallelisms/__init__.py b/torchtrain/parallelisms/__init__.py
index 464397fac..1c9a56412 100644
--- a/torchtrain/parallelisms/__init__.py
+++ b/torchtrain/parallelisms/__init__.py
@@ -3,6 +3,7 @@
 
 import logging
 from dataclasses import dataclass
+from functools import cached_property
 
 from torch.distributed.device_mesh import init_device_mesh
 
@@ -61,3 +62,7 @@ def sp_enabled(self):
     @property
     def pp_enabled(self):
         return self.pp > 1
+
+    @cached_property
+    def model_parallel_size(self):
+        return self.sp * self.pp
diff --git a/torchtrain/train_configs/train_config.toml b/torchtrain/train_configs/train_config.toml
index a3b02917e..da0161e05 100644
--- a/torchtrain/train_configs/train_config.toml
+++ b/torchtrain/train_configs/train_config.toml
@@ -7,3 +7,7 @@ run_profiler = true
 save_traces_folder = "profiling/traces"
 # profiling frequency - example: 10 means every 10th iter will be profiled
 profile_every_x_iter = 10
+
+[metrics]
+enable_tensorboard = true
+save_tb_folder = "tb"
diff --git a/torchtrain/utils.py b/torchtrain/utils.py
new file mode 100644
index 000000000..9ae71caef
--- /dev/null
+++ b/torchtrain/utils.py
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from typing import Union
+
+import torch
+import torch.distributed._functional_collectives as funcol
+import torch.distributed.distributed_c10d as c10d
+from torch.distributed.device_mesh import DeviceMesh
+
+
+def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
+    tensor = torch.tensor(x).cuda()
+    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh).item()
+
+
+def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float:
+    tensor = torch.tensor(x).cuda()
+    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh).item()
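The two helpers above all-reduce a Python scalar across every rank in the mesh; the trailing `.item()` turns the reduced tensor back into the plain float the signatures promise. A rough semantic equivalent using only the stable `torch.distributed` API, runnable on CPU (assumptions: gloo backend, launched with `torchrun --nproc_per_node=2`; SUM divided by world size stands in for AVG, which is NCCL-only):
```
# rough equivalent of dist_mean/dist_max with the stable c10d API
# assumptions: run via `torchrun --nproc_per_node=2 demo.py`, gloo backend
# so it works on CPU; SUM / world_size replaces ReduceOp.AVG (NCCL-only)
import torch
import torch.distributed as dist


def dist_mean(x: float) -> float:
    t = torch.tensor(x, dtype=torch.float32)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)  # in-place sum across ranks
    return (t / dist.get_world_size()).item()


def dist_max(x: float) -> float:
    t = torch.tensor(x, dtype=torch.float32)
    dist.all_reduce(t, op=dist.ReduceOp.MAX)  # in-place max across ranks
    return t.item()


if __name__ == "__main__":
    dist.init_process_group("gloo")
    r = float(dist.get_rank())
    print(f"rank {int(r)}: mean={dist_mean(r)}, max={dist_max(r)}")
    dist.destroy_process_group()
```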
diff --git a/train.py b/train.py
index a0a0f891e..e922acf3b 100644
--- a/train.py
+++ b/train.py
@@ -4,8 +4,11 @@
 import argparse
 import os
 from dataclasses import dataclass, field
+from timeit import default_timer as timer
 from typing import Any, Dict, List, Union
 
+import numpy as np
+
 # torch imports
 import torch
 import torch.nn.functional as F
@@ -18,12 +21,13 @@
 from torchtrain.datasets import create_tokenizer, dataloader_fn
 from torchtrain.logging_utils import init_logger, rank0_log
 from torchtrain.lr_scheduling import get_lr_scheduler
-from torchtrain.metrics import get_num_params, GPUMemoryMonitor
+from torchtrain.metrics import build_metric_logger, get_num_params, GPUMemoryMonitor
 
 from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtrain.parallelisms import models_parallelize_fns, ParallelDims
 
 from torchtrain.profiling import maybe_run_profiler
+from torchtrain.utils import dist_max, dist_mean
 
 
 @dataclass
@@ -126,7 +130,7 @@ def main(args):
 
     scaler = build_grad_scaler(model)
 
-    # TODO: add metrics
+    metric_logger = build_metric_logger()
 
     # torch.compile model for improved performance
     if args.compile:
@@ -156,6 +160,10 @@ def main(args):
 
     with maybe_run_profiler() as torch_profiler:
         checkpoint.reset()
+        # variables used to keep info for metrics logging
+        losses_since_last_log: List[float] = []
+        nwords_since_last_log = 0
+        time_last_log = timer()
         while train_state.step < args.steps or args.steps == -1:
             train_state.step += 1
             # get batch
@@ -163,6 +171,7 @@ def main(args):
             input_ids, labels = batch
             input_ids = input_ids.cuda()
             labels = labels.cuda()
+            nwords_since_last_log += labels.numel()
 
             optimizer.zero_grad()
@@ -194,6 +203,32 @@
 
             train_state.current_loss = loss.item()
             train_state.losses.append(train_state.current_loss)
+            losses_since_last_log.append(train_state.current_loss)
+
+            # log metrics
+            if (train_state.step - 1) % args.log_freq == 0:
+                avg_loss, max_loss = np.mean(losses_since_last_log), np.max(
+                    losses_since_last_log
+                )
+                global_avg_loss, global_max_loss = dist_mean(
+                    avg_loss, world_mesh
+                ), dist_max(max_loss, world_mesh)
+
+                time_delta = timer() - time_last_log
+                wps = nwords_since_last_log / (
+                    time_delta * parallel_dims.model_parallel_size
+                )
+
+                metrics = {
+                    "global_avg_loss": global_avg_loss,
+                    "global_max_loss": global_max_loss,
+                    "wps": wps,
+                }
+                metric_logger.log(metrics, step=train_state.step)
+
+                losses_since_last_log.clear()
+                nwords_since_last_log = 0
+                time_last_log = timer()
 
             rank0_log(
                 f"step: {train_state.step}, current loss: {train_state.current_loss}, lr: {scheduler.get_last_lr()}"
@@ -202,6 +237,7 @@
 
         checkpoint.save(train_state.step, force=(train_state.step == args.steps))
 
+    metric_logger.close()
     rank0_log(f"{gpu_metrics.get_current_stats()}")
@@ -294,6 +330,12 @@
             "is an empty string, checkpointing is disabled."
         ),
     )
+    parser.add_argument(
+        "--log_freq",
+        type=int,
+        default=10,
+        help="how often to log metrics to TensorBoard",
+    )
 
     args = parser.parse_args()
     main(args)
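As a sanity check on what the patch writes, the event files can be read back with TensorBoard's event-processing API. A small sketch, assuming the default `outputs`/`tb` layout from `train_config.toml` (the timestamp and rank directory below are hypothetical):
```
# sketch: read back the scalars MetricLogger wrote; needs the `tensorboard` package
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# hypothetical run directory: <dump_folder>/<save_tb_folder>/<YYYYmmdd-HHMM>/rank_<r>
ea = EventAccumulator("torchtrain/outputs/tb/20240214-1816/rank_0")
ea.Reload()  # scan event files on disk
for event in ea.Scalars("wps"):  # each entry carries (wall_time, step, value)
    print(f"step={event.step} wps={event.value:.0f}")
```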