add model num params display, gpu memory metrics (#56)
This PR is the start of adding perf-related metrics.
1 - Adds a function that counts the total number of unique model parameters, with an option to count only trainable parameters (for future PEFT/QLoRA-type work). A short usage sketch follows the screenshots below.
2 - Logs the count with comma formatting alongside the model name, like so:
<img width="716" alt="Screenshot 2024-02-12 at 4 12 22 PM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/8eb48870-ab1e-4b70-9159-92864ff6c0e5">

This also helps demystify, for example, the size of our debug model:
<img width="716" alt="Screenshot 2024-02-12 at 4 10 17 PM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/77475306-54bc-48a6-bf28-9c9a542577fd">

**Additional updates** - added GPU memory tracking. We want to show the user peak memory stats, and to monitor and warn about any CUDACachingAllocator retries, which are a performance hindrance.

To that end, this PR adds a GPUMemoryMonitor class. Usage (a short sketch follows these steps):
1 - instantiate it:
<img width="1329" alt="Screenshot 2024-02-13 at 9 32 11 AM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/95610386-6fde-47bb-bbdc-bb7c399c5895">

2 - at the start of training: start_monitoring()
3 - at the end of training: stop_monitoring()
4 - to show results: get_peak_stats_str() and rank0_log it.
<img width="1074" alt="Screenshot 2024-02-13 at 9 12 45 AM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/b6c7c854-7d83-436a-bea9-a67109422381">
lessw2020 authored Feb 15, 2024
1 parent 076edda commit 40c93e9
Showing 3 changed files with 202 additions and 2 deletions.
run_llama_train.sh (4 changes: 2 additions & 2 deletions)

```diff
@@ -23,10 +23,10 @@ CHECKPOINT_FOLDER=${CHECKPOINT_FOLDER:-""}
 # Please adjust this to a longer interval period. The unit of measurement is in steps.
 CHECKPOINT_INTERVAL=${CHECKPOINT_INTERVAL:-5}
 
-torchrun --nproc_per_node=${NGPU} \
+torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 train.py --steps 10 \
 --model ${MODEL} --model_conf ${MODEL_CONF} \
 --pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP} \
-    --compile
+    --compile \
 --checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}
```
torchtrain/metrics.py (new file: 189 additions)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved

from collections import namedtuple

import torch
import torch.nn as nn

_gb_in_bytes = 1024 * 1024 * 1024
_mb_in_bytes = 1024 * 1024


def format_to_gb(item, precision=4):
    """quick function to format numbers to gigabyte and round to (default) 4 digit precision"""
    metric_num = item / _gb_in_bytes
    metric_num = round(metric_num, ndigits=precision)
    return metric_num


def convert_to_gpu_pct(value, total_gpu_memory):
    return round(100 * (value / total_gpu_memory), 2)


# named tuple for passing memory stats (as % of device capacity) for Tensorboard logging
GPUMemStats = namedtuple(
    "GPUMemStats",
    [
        "allocated_curr",
        "allocated_peak",
        "reserved_curr",
        "reserved_peak",
        "active_curr",
        "active_peak",
        "num_retries",
    ],
)


class GPUMemoryMonitor:
    """
    Class to monitor GPU memory usage
    """

    def __init__(self, device: str = "cuda:0"):
        self.device = torch.device(device)  # device object
        self.device_name = torch.cuda.get_device_name(self.device)
        self.device_index = torch.cuda.current_device()
        self.device_capacity = torch.cuda.get_device_properties(
            self.device
        ).total_memory
        self.device_capacity_gb = format_to_gb(self.device_capacity)
        self.num_retries = 0
        self.num_ooms = 0
        self.peak_active_memory = 0
        self.peak_allocated_memory = 0
        self.peak_reserved_memory = 0
        self.curr_reserved_memory = 0

        self.device_reserved_memory_usage = 0
        self.device_reserved_memory_gb = 0
        self.device_reserved_memory_pct = 0

        self.device_active_memory_usage = 0
        self.device_active_memory_gb = 0
        self.device_active_memory_pct = 0

        # current stats
        self.device_alloc_memory_usage = torch.cuda.memory_allocated(self.device)
        self.device_alloc_memory_gb = format_to_gb(self.device_alloc_memory_usage)
        self.device_alloc_memory_pct = convert_to_gpu_pct(
            self.device_alloc_memory_usage, self.device_capacity
        )

        # reset stats, clear cache
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

    def get_pct_memory(self, memory_num):
        pct_memory = memory_num / self.device_capacity
        pct_memory = round(100 * (pct_memory), 2)
        return pct_memory

    def get_gb_memory(self, memory_num):
        gb_memory = memory_num / _gb_in_bytes
        gb_memory = round(gb_memory, 2)
        return gb_memory

    def get_current_stats(self, return_data: bool = False):
        """
        get the CudaCachingAllocator stats for the current device
        return_data: bool, if True, return the data as a named tuple
        """
        curr_mem = torch.cuda.memory_stats(self.device)

        self.device_alloc_memory_usage = curr_mem["allocated_bytes.all.current"]
        self.device_alloc_memory_gb = format_to_gb(self.device_alloc_memory_usage)
        self.device_alloc_memory_pct = convert_to_gpu_pct(
            self.device_alloc_memory_usage, self.device_capacity
        )

        self.device_reserved_memory_usage = curr_mem["reserved_bytes.all.current"]
        self.device_reserved_memory_gb = format_to_gb(self.device_reserved_memory_usage)
        self.device_reserved_memory_pct = convert_to_gpu_pct(
            self.device_reserved_memory_usage, self.device_capacity
        )

        self.device_active_memory_usage = curr_mem["active_bytes.all.current"]
        self.device_active_memory_gb = format_to_gb(self.device_active_memory_usage)
        self.device_active_memory_pct = convert_to_gpu_pct(
            self.device_active_memory_usage, self.device_capacity
        )

        display_str = ""
        display_str += f"Current Memory: {self.device_name} ({self.device_index}): Reserved: {self.device_reserved_memory_pct}%, "
        display_str += f"Alloc {self.device_alloc_memory_pct}%, Active: {self.device_active_memory_pct}%\n"

        self.get_peak_stats(curr_mem)

        peak_active_pct = self.get_pct_memory(self.peak_active_memory)
        peak_allocated_pct = self.get_pct_memory(self.peak_allocated_memory)
        peak_reserved_pct = self.get_pct_memory(self.peak_reserved_memory)
        display_str += f"Peak Memory: Reserved {peak_reserved_pct}%, Alloc {peak_allocated_pct}%, Active: {peak_active_pct}%\n"

        display_str += f"num retries: {self.num_retries}, num ooms: {self.num_ooms}"
        if self.num_retries > 0:
            display_str += f"\nWARNING: {self.num_retries} retries -- recommend lowering batch size for max performance\n"

        if not return_data:
            return display_str

        # return named tuple (current and peak percentages, in GPUMemStats field order)
        curr_mem_stats = GPUMemStats(
            self.device_alloc_memory_pct,
            peak_allocated_pct,
            self.device_reserved_memory_pct,
            peak_reserved_pct,
            self.device_active_memory_pct,
            peak_active_pct,
            self.num_retries,
        )
        return curr_mem_stats

    def start_monitoring(self):
        """reset all monitoring stats"""
        self.reset_peak_stats()

    def get_peak_stats(self, cuda_info=None):
        """capture current peak memory stats"""
        if not cuda_info:
            cuda_info = torch.cuda.memory_stats()

        self.peak_active_memory = cuda_info.get("active_bytes.all.peak", 0)
        self.peak_allocated_memory = cuda_info.get("allocated_bytes.all.peak", 0)
        self.peak_reserved_memory = cuda_info.get("reserved_bytes.all.peak", 0)

        self.num_retries = cuda_info.get("num_alloc_retries", 0)
        self.num_ooms = cuda_info.get("num_ooms", 0)

    def reset_peak_stats(self):
        """reset peak memory stats"""
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        self.num_retries = 0
        self.num_ooms = 0
        self.active_peak_memory_utilization_str = ""
        self.peak_memory_utilization_str = ""
        self.peak_reserved_memory_utilization_str = ""

    def __str__(self):
        _ = self.get_current_stats()
        display_str = f"{self.device_name} ({self.device_index}): {self.device_capacity_gb} GB capacity, "
        display_str += f"{self.device_alloc_memory_gb} GB in-use, {self.device_alloc_memory_pct}% in-use"
        return f"{display_str}"


def get_num_params(model: nn.Module, only_trainable: bool = False) -> int:
    """
    Get the total model params
    Args : only_trainable: whether to only count trainable params
    """
    param_list = list(model.parameters())
    if only_trainable:
        param_list = [p for p in param_list if p.requires_grad]
    unique_params = {p.data_ptr(): p for p in param_list}.values()
    return sum(p.numel() for p in unique_params)
```
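The GPUMemStats namedtuple above is described as intended for TensorBoard logging; a hypothetical sketch of that wiring (not part of this PR; the writer setup and tag names are assumptions) might look like:

```python
# Hypothetical sketch (not in this PR): feeding the GPUMemStats tuple returned by
# get_current_stats(return_data=True) into TensorBoard, as the namedtuple comment suggests.
from torch.utils.tensorboard import SummaryWriter

from torchtrain.metrics import GPUMemoryMonitor

writer = SummaryWriter(log_dir="./tb_logs")  # assumed log directory
gpu_metrics = GPUMemoryMonitor("cuda")

step = 0  # placeholder training step
stats = gpu_metrics.get_current_stats(return_data=True)
for field, value in zip(stats._fields, stats):
    writer.add_scalar(f"gpu_memory/{field}", value, step)  # assumed tag naming
writer.close()
```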
train.py (11 additions)

```diff
@@ -18,6 +18,7 @@
 from torchtrain.datasets import create_tokenizer, dataloader_fn
 from torchtrain.logging_utils import init_logger, rank0_log
 from torchtrain.lr_scheduling import get_lr_scheduler
+from torchtrain.metrics import get_num_params, GPUMemoryMonitor
 
 from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtrain.parallelisms import models_parallelize_fns, ParallelDims
@@ -105,6 +106,14 @@ def main(args):
 
     model = model_cls.from_model_args(model_config)
 
+    # log model size
+    model_param_count = get_num_params(model)
+    rank0_log(
+        f"Model {model_name} {args.model_conf} size: {model_param_count:,} total parameters"
+    )
+    gpu_metrics = GPUMemoryMonitor("cuda")
+    rank0_log(f"GPU memory usage: {gpu_metrics}")
+
     # apply PTD parallelisms + AC
     model = models_parallelize_fns[model_name](model, world_mesh, parallel_dims, args)
 
@@ -193,6 +202,8 @@ def main(args):
 
         checkpoint.save(train_state.step, force=(train_state.step == args.steps))
 
+    rank0_log(f"{gpu_metrics.get_current_stats()}")
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
```