From 40c93e95cca38810963dc420d637568f15e9d2c6 Mon Sep 17 00:00:00 2001 From: Less Wright Date: Wed, 14 Feb 2024 16:33:59 -0800 Subject: [PATCH] add model num params display, gpu memory metrics (#56) This PR is the start of adding perf related metrics. 1 - This PR adds function for logging the total num of unique model params, with option for only counting trainable params as well. (for future peft/qlora type work). 2 - logs it with comma formatted logging and model name ala: Screenshot 2024-02-12 at 4 12 22 PM this helps de-mistify for example the size of our debug model as well: Screenshot 2024-02-12 at 4 10 17 PM **additional updates** - added in gpu mem tracking. We want to show the user peak memory stats, as well as monitor and alert for any cudacachealloc retries which are a perf hindrance. Thus, added class GPUMemoryMonitor: usage: 1 - instantiate Screenshot 2024-02-13 at 9 32 11 AM 2 - start of training = start_monitoring() 3 - end of training = stop_monitoring() 4 - show results = get_peak_stats_str() and rank0_log it. Screenshot 2024-02-13 at 9 12 45 AM --- run_llama_train.sh | 4 +- torchtrain/metrics.py | 189 ++++++++++++++++++++++++++++++++++++++++++ train.py | 11 +++ 3 files changed, 202 insertions(+), 2 deletions(-) create mode 100644 torchtrain/metrics.py diff --git a/run_llama_train.sh b/run_llama_train.sh index 607a9135..723ddec8 100755 --- a/run_llama_train.sh +++ b/run_llama_train.sh @@ -23,10 +23,10 @@ CHECKPOINT_FOLDER=${CHECKPOINT_FOLDER:-""} # Please adjust this to a longer interval period. The unit of measurement is in steps. CHECKPOINT_INTERVAL=${CHECKPOINT_INTERVAL:-5} -torchrun --nproc_per_node=${NGPU} \ +torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ train.py --steps 10 \ --model ${MODEL} --model_conf ${MODEL_CONF} \ --pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP} \ ---compile +--compile \ --checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL} diff --git a/torchtrain/metrics.py b/torchtrain/metrics.py new file mode 100644 index 00000000..64888b37 --- /dev/null +++ b/torchtrain/metrics.py @@ -0,0 +1,189 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved + +from collections import namedtuple + +import torch +import torch.nn as nn + +_gb_in_bytes = 1024 * 1024 * 1024 +_mb_in_bytes = 1024 * 1024 + + +def format_to_gb(item, precision=4): + """quick function to format numbers to gigabyte and round to (default) 4 digit precision""" + metric_num = item / _gb_in_bytes + metric_num = round(metric_num, ndigits=precision) + return metric_num + + +def convert_to_gpu_pct(value, total_gpu_memory): + return round(100 * (value / total_gpu_memory), 2) + + +# named tuple for passing memory stats (as % of device capacity) for Tensorboard logging +GPUMemStats = namedtuple( + "GPUMemStats", + [ + "allocated_curr", + "allocated_peak", + "reserved_curr", + "reserved_peak", + "active_curr", + "active_peak", + "num_retries", + ], +) + + +class GPUMemoryMonitor: + """ + Class to monitor GPU memory usage + """ + + def __init__(self, device: str = "cuda:0"): + self.device = torch.device(device) # device object + self.device_name = torch.cuda.get_device_name(self.device) + self.device_index = torch.cuda.current_device() + self.device_capacity = torch.cuda.get_device_properties( + self.device + ).total_memory + self.device_capacity_gb = format_to_gb(self.device_capacity) + self.num_retries = 0 + self.num_ooms = 0 + self.peak_active_memory = 0 + self.peak_allocated_memory = 0 + self.peak_reserved_memory = 0 + self.curr_reserved_memory = 0 + + self.device_reserved_memory_usage = 0 + self.device_reserved_memory_gb = 0 + self.device_reserved_memory_pct = 0 + + self.device_active_memory_usage = 0 + self.device_active_memory_gb = 0 + self.device_active_memory_pct = 0 + + # current stats + self.device_alloc_memory_usage = torch.cuda.memory_allocated(self.device) + self.device_alloc_memory_gb = format_to_gb(self.device_alloc_memory_usage) + self.device_alloc_memory_pct = convert_to_gpu_pct( + self.device_alloc_memory_usage, self.device_capacity + ) + + # reset stats, clear cache + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + + def get_pct_memory(self, memory_num): + pct_memory = memory_num / self.device_capacity + pct_memory = round(100 * (pct_memory), 2) + return pct_memory + + def get_gb_memory(self, memory_num): + gb_memory = memory_num / _gb_in_bytes + gb_memory = round(gb_memory, 2) + return gb_memory + + def get_current_stats(self, return_data: bool = False): + """ + get the CudaCachingAllocator stats for the current device + + return_data: bool, if True, return the data as a named tuple + """ + curr_mem = torch.cuda.memory_stats(self.device) + + self.device_alloc_memory_usage = curr_mem["allocated_bytes.all.current"] + self.device_alloc_memory_gb = format_to_gb(self.device_alloc_memory_usage) + self.device_alloc_memory_pct = convert_to_gpu_pct( + self.device_alloc_memory_usage, self.device_capacity + ) + + self.device_reserved_memory_usage = curr_mem["reserved_bytes.all.current"] + self.device_reserved_memory_gb = format_to_gb(self.device_reserved_memory_usage) + self.device_reserved_memory_pct = convert_to_gpu_pct( + self.device_reserved_memory_usage, self.device_capacity + ) + + self.device_active_memory_usage = curr_mem["active_bytes.all.current"] + self.device_active_memory_gb = format_to_gb(self.device_active_memory_usage) + self.device_active_memory_pct = convert_to_gpu_pct( + self.device_active_memory_usage, self.device_capacity + ) + + display_str = "" + display_str += f"Current Memory: {self.device_name} ({self.device_index}): Reserved: {self.device_reserved_memory_pct}%," + display_str += f"Alloc {self.device_alloc_memory_pct}%, Active: {self.device_active_memory_pct}%\n" + + self.get_peak_stats(curr_mem) + + peak_active_pct = self.get_pct_memory(self.peak_active_memory) + peak_allocated_pct = self.get_pct_memory(self.peak_allocated_memory) + peak_reserved_pct = self.get_pct_memory(self.peak_reserved_memory) + display_str += f"Peak Memory: Reserved {peak_reserved_pct}%, Alloc {peak_allocated_pct}%, Active: {peak_active_pct}%\n" + + display_str += f"num retries: {self.num_retries}, num ooms: {self.num_ooms}" + if self.num_retries > 0: + display_str += f"\nWARNING: {self.num_retries} retries -- recommend lowering batch size for max performance\n" + + if not return_data: + return display_str + + # return named tuple + curr_mem_stats = GPUMemStats( + self.device_alloc_memory_pct, + peak_active_pct, + self.device_reserved_memory_pct, + peak_reserved_pct, + self.device_active_memory_pct, + peak_active_pct, + self.num_retries, + ) + return curr_mem_stats + + def start_monitoring(self): + """reset all monitoring stats""" + self.reset_peak_stats() + + def get_peak_stats(self, cuda_info=None): + """capture current peak memory stats""" + if not cuda_info: + cuda_info = torch.cuda.memory_stats() + + self.peak_active_memory = cuda_info.get("active_bytes.all.peak", 0) + self.peak_allocated_memory = cuda_info.get("allocated_bytes.all.peak", 0) + self.peak_reserved_memory = cuda_info.get("reserved_bytes.all.peak", 0) + + self.num_retries = cuda_info.get("num_alloc_retries", 0) + self.num_ooms = cuda_info.get("num_ooms", 0) + + def reset_peak_stats(self): + """reset peak memory stats""" + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + self.num_retries = 0 + self.num_ooms = 0 + self.active_peak_memory_utilization_str = "" + self.peak_memory_utilization_str = "" + self.peak_reserved_memory_utilization_str = "" + + def __str__(self): + _ = self.get_current_stats() + display_str = f"{self.device_name} ({self.device_index}): {self.device_capacity_gb} GB capacity, " + display_str += f"{self.device_alloc_memory_gb} GB in-use, {self.device_alloc_memory_pct}% in-use" + return f"{display_str}" + + +def get_num_params(model: nn.Module, only_trainable: bool = False) -> int: + """ + Get the total model params + Args : only_trainable: whether to only count trainable params + """ + param_list = list(model.parameters()) + if only_trainable: + param_list = [p for p in param_list if p.requires_grad] + unique_params = {p.data_ptr(): p for p in param_list}.values() + return sum(p.numel() for p in unique_params) diff --git a/train.py b/train.py index e2f70031..a0a0f891 100644 --- a/train.py +++ b/train.py @@ -18,6 +18,7 @@ from torchtrain.datasets import create_tokenizer, dataloader_fn from torchtrain.logging_utils import init_logger, rank0_log from torchtrain.lr_scheduling import get_lr_scheduler +from torchtrain.metrics import get_num_params, GPUMemoryMonitor from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config from torchtrain.parallelisms import models_parallelize_fns, ParallelDims @@ -105,6 +106,14 @@ def main(args): model = model_cls.from_model_args(model_config) + # log model size + model_param_count = get_num_params(model) + rank0_log( + f"Model {model_name} {args.model_conf} size: {model_param_count:,} total parameters" + ) + gpu_metrics = GPUMemoryMonitor("cuda") + rank0_log(f"GPU memory usage: {gpu_metrics}") + # apply PTD parallelisms + AC model = models_parallelize_fns[model_name](model, world_mesh, parallel_dims, args) @@ -193,6 +202,8 @@ def main(args): checkpoint.save(train_state.step, force=(train_state.step == args.steps)) + rank0_log(f"{gpu_metrics.get_current_stats()}") + if __name__ == "__main__": parser = argparse.ArgumentParser(description="TorchTrain arg parser.")