add TensorBoard logging with loss and wps
ghstack-source-id: d0828f16c06747a5af2586630e5205bf786de1c4
Pull Request resolved: pytorch#57
tianyu-l committed Feb 15, 2024
1 parent cde8b43 commit d61e403
Showing 7 changed files with 135 additions and 2 deletions.
18 changes: 18 additions & 0 deletions README.md
@@ -22,3 +22,21 @@ run the llama debug model locally to verify the setup is correct:
```
./run_llama_train.sh
```

# TensorBoard

To visualize training metrics on TensorBoard:

1. make sure `enable_tensorboard = true` is set in `torchtrain/train_configs/train_config.toml` (it is enabled by default)

2. set up SSH tunneling from your local machine to the training host
```
ssh -L 6006:127.0.0.1:6006 [username]@[hostname]
```

3. then, from the torchtrain repo root, launch TensorBoard
```
tensorboard --logdir=./torchtrain/outputs/tb
```

4. open the URL TensorBoard prints, or go to http://localhost:6006/
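
Event files land in a per-run, per-rank subdirectory of the folder passed to `--logdir`; the layout is built in `torchtrain/metrics.py` from the dump folder, the `save_tb_folder` setting, a `YYYYmmdd-HHMM` timestamp, and the rank. A quick way to sanity-check that something was written, assuming the `./torchtrain/outputs` dump folder implied by the command above:

```python
# Quick sanity check (assumption: dump folder is ./torchtrain/outputs, as implied
# by the --logdir in step 3; the concrete path below is just an illustration).
import glob

print(glob.glob("./torchtrain/outputs/tb/*/rank_*/events.*"))
```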
1 change: 1 addition & 0 deletions requirements.txt
@@ -2,3 +2,4 @@ torch >= 2.2.0.dev
sentencepiece
datasets
tomli >= 1.1.0 ; python_version < "3.11"
tensorboard
44 changes: 44 additions & 0 deletions torchtrain/metrics.py
@@ -4,10 +4,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved

import os
from collections import namedtuple
from datetime import datetime
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from torchtrain.logging_utils import rank0_log
from torchtrain.profiling import get_config_from_toml

_gb_in_bytes = 1024 * 1024 * 1024
_mb_in_bytes = 1024 * 1024
@@ -187,3 +194,40 @@ def get_num_params(model: nn.Module, only_trainable: bool = False) -> int:
param_list = [p for p in param_list if p.requires_grad]
unique_params = {p.data_ptr(): p for p in param_list}.values()
return sum(p.numel() for p in unique_params)


class MetricLogger:
def __init__(self, log_dir, tag, enable_tb):
self.tag = tag
self.writer: Optional[SummaryWriter] = None
if enable_tb:
self.writer = SummaryWriter(log_dir, max_queue=1000)

def log(self, metrics: Dict[str, Any], step: int):
if self.writer is not None:
for k, v in metrics.items():
tag = k if self.tag is None else f"{self.tag}/{k}"
self.writer.add_scalar(tag, v, step)

def close(self):
if self.writer is not None:
self.writer.close()


def build_metric_logger(tag: Optional[str] = None):
config = get_config_from_toml()

dump_dir = config["global"]["dump_folder"]
save_tb_folder = config["metrics"]["save_tb_folder"]
# since we don't have run id yet, use current minute as identifier
datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)

enable_tb = config["metrics"].get("enable_tensorboard", False)
if enable_tb:
rank0_log(
f"Metrics logging active. Tensorboard logs will be saved at {log_dir}."
)

rank_str = f"rank_{torch.distributed.get_rank()}"
return MetricLogger(os.path.join(log_dir, rank_str), tag, enable_tb)
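
For orientation (not part of the commit itself), a minimal sketch of how the new logger is meant to be driven, mirroring the call sites added to train.py below. It assumes a distributed process group is already initialized and the TOML config is readable, since `build_metric_logger` calls `torch.distributed.get_rank()` and `get_config_from_toml()`.

```python
# Minimal usage sketch (assumes torch.distributed is initialized and that
# train_config.toml provides [global] dump_folder and the [metrics] keys).
from torchtrain.metrics import build_metric_logger

metric_logger = build_metric_logger()  # SummaryWriter is created only if enable_tensorboard is true

for step in range(1, 4):  # hypothetical loop, just to show the call pattern
    fake_loss = 1.0 / step  # placeholder scalar for illustration
    metric_logger.log({"global_avg_loss": fake_loss, "wps": 1000.0}, step=step)

metric_logger.close()  # flushes and closes the underlying writer, if any
```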
5 changes: 5 additions & 0 deletions torchtrain/parallelisms/__init__.py
@@ -3,6 +3,7 @@

import logging
from dataclasses import dataclass
from functools import cached_property

from torch.distributed.device_mesh import init_device_mesh

@@ -61,3 +62,7 @@ def sp_enabled(self):
@property
def pp_enabled(self):
return self.pp > 1

@cached_property
def model_parallel_size(self):
return self.sp * self.pp
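
`model_parallel_size` is the product of the sequence-parallel and pipeline-parallel degrees, and `@cached_property` computes it once per instance. A standalone sketch of the same pattern with made-up degrees (a hypothetical stand-in, not the real `ParallelDims`):

```python
# Hypothetical stand-in illustrating the cached_property pattern used above.
from dataclasses import dataclass
from functools import cached_property


@dataclass
class _Dims:  # not the real ParallelDims; fields chosen only for illustration
    sp: int
    pp: int

    @cached_property
    def model_parallel_size(self) -> int:
        return self.sp * self.pp  # computed once, then cached on the instance


print(_Dims(sp=2, pp=2).model_parallel_size)  # 4
```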
4 changes: 4 additions & 0 deletions torchtrain/train_configs/train_config.toml
@@ -7,3 +7,7 @@ run_profiler = true
save_traces_folder = "profiling/traces"
# profiling frequency - example: 10 means every 10th iter will be profiled
profile_every_x_iter = 10

[metrics]
enable_tensorboard = true
save_tb_folder = "tb"
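
The new `[metrics]` table is consumed by `build_metric_logger` via `get_config_from_toml`. A rough sketch of the parsing, assuming the same keys; the `[global] dump_folder` entry is referenced in metrics.py but not shown in this hunk:

```python
# Rough sketch (assumption: mirrors how get_config_from_toml exposes the config;
# the actual helper lives in torchtrain.profiling).
import os

try:
    import tomllib  # stdlib on Python 3.11+
except ModuleNotFoundError:
    import tomli as tomllib  # requirements.txt pins tomli for older Pythons

with open("torchtrain/train_configs/train_config.toml", "rb") as f:
    config = tomllib.load(f)

enable_tb = config["metrics"].get("enable_tensorboard", False)
tb_dir = os.path.join(config["global"]["dump_folder"], config["metrics"]["save_tb_folder"])
print(enable_tb, tb_dir)
```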
19 changes: 19 additions & 0 deletions torchtrain/utils.py
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from typing import Union

import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
from torch.distributed.device_mesh import DeviceMesh


def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
tensor = torch.tensor(x).cuda()
return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh)


def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float:
tensor = torch.tensor(x).cuda()
return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh)
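
These helpers wrap functional collectives over a `DeviceMesh`: each rank passes its local scalar and receives the reduced value. A minimal usage sketch, assuming a `torchrun` launch on CUDA GPUs; the 1-D mesh below is hypothetical, whereas train.py passes its own `world_mesh`:

```python
# Minimal sketch (assumes `torchrun --nproc_per_node=N` on CUDA GPUs).
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from torchtrain.utils import dist_max, dist_mean

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
world_mesh = init_device_mesh("cuda", (dist.get_world_size(),))

local_loss = 0.5 * (dist.get_rank() + 1)  # a different scalar on every rank
print(dist_mean(local_loss, world_mesh))  # average across all ranks
print(dist_max(local_loss, world_mesh))   # maximum across all ranks

dist.destroy_process_group()
```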
46 changes: 44 additions & 2 deletions train.py
@@ -4,8 +4,11 @@
import argparse
import os
from dataclasses import dataclass, field
from timeit import default_timer as timer
from typing import Any, Dict, List, Union

import numpy as np

# torch imports
import torch
import torch.nn.functional as F
@@ -18,12 +21,13 @@
from torchtrain.datasets import create_tokenizer, dataloader_fn
from torchtrain.logging_utils import init_logger, rank0_log
from torchtrain.lr_scheduling import get_lr_scheduler
from torchtrain.metrics import get_num_params, GPUMemoryMonitor
from torchtrain.metrics import build_metric_logger, get_num_params, GPUMemoryMonitor

from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtrain.parallelisms import models_parallelize_fns, ParallelDims

from torchtrain.profiling import maybe_run_profiler
from torchtrain.utils import dist_max, dist_mean


@dataclass
@@ -126,7 +130,7 @@ def main(args):

scaler = build_grad_scaler(model)

# TODO: add metrics
metric_logger = build_metric_logger()

# torch.compile model for improved performance
if args.compile:
@@ -156,13 +160,18 @@ def main(args):

with maybe_run_profiler() as torch_profiler:
checkpoint.reset()
# variables used to keep info for metrics logging
losses_since_last_log: List[float] = []
nwords_since_last_log = 0
time_last_log = timer()
while train_state.step < args.steps or args.steps == -1:
train_state.step += 1
# get batch
batch = next(iter(data_loader))
input_ids, labels = batch
input_ids = input_ids.cuda()
labels = labels.cuda()
nwords_since_last_log += labels.numel()

optimizer.zero_grad()

@@ -194,6 +203,32 @@ def main(args):

train_state.current_loss = loss.item()
train_state.losses.append(train_state.current_loss)
losses_since_last_log.append(train_state.current_loss)

# log metrics
if (train_state.step - 1) % args.log_freq == 0:
avg_loss, max_loss = np.mean(losses_since_last_log), np.max(
losses_since_last_log
)
global_avg_loss, global_max_loss = dist_mean(
avg_loss, world_mesh
), dist_max(max_loss, world_mesh)

time_delta = timer() - time_last_log
wps = nwords_since_last_log / (
time_delta * parallel_dims.model_parallel_size
)

metrics = {
"global_avg_loss": global_avg_loss,
"global_max_loss": global_max_loss,
"wps": wps,
}
metric_logger.log(metrics, step=train_state.step)

losses_since_last_log.clear()
nwords_since_last_log = 0
time_last_log = timer()

rank0_log(
f"step: {train_state.step}, current loss: {train_state.current_loss}, lr: {scheduler.get_last_lr()}"
@@ -202,6 +237,7 @@ def main(args):

checkpoint.save(train_state.step, force=(train_state.step == args.steps))

metric_logger.close()
rank0_log(f"{gpu_metrics.get_current_stats()}")


@@ -294,6 +330,12 @@ def main(args):
"is an empty string, checkpointing is disabled."
),
)
parser.add_argument(
"--log_freq",
type=int,
default=10,
help="how often to log metrics to TensorBoard",
)

args = parser.parse_args()
main(args)
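
Note the cadence implied by the new flag: because the check is `(train_state.step - 1) % args.log_freq == 0` and `step` starts at 1, metrics are written on steps 1, 11, 21, ... with the default `--log_freq 10`. A quick check of that arithmetic:

```python
# Quick check of the logging cadence (default --log_freq 10).
log_freq = 10
logged_steps = [step for step in range(1, 31) if (step - 1) % log_freq == 0]
print(logged_steps)  # [1, 11, 21]
```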
