From ae85e970cff7da781b3a7b43c1dcd50ca82df4f1 Mon Sep 17 00:00:00 2001 From: Less Wright Date: Mon, 26 Feb 2024 10:16:35 -0800 Subject: [PATCH 1/5] add iter time tracking via cuda events, add data loading times, add columnar display to show both, show avg iter & data loading times at end of training (#87) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds basic perf timing and display for 'per iter' and 'final iter average' display. (in part based on Andrew's comment about having to open the trace to compare iter timing). 1. tracking list is housed in TrainState, but I do not save it as part of the state dict as I view this as useful but not saveable info. 2. iter times are tracked after dataloading is done each iter and after optimizer step. The idea is to make this timing expressly the model training iter (not data loading or post iter other metrics calcs). 3. 'time' is now displayed at each iter along with the usual loss and lr. 4. at the end of training, assuming more than 3 iters run, then the average iter time is calculated by igoring the first three iters (consider these as warmup esp as cudaCacheAllocator gets warmed up) and displayed. 5. based on @tianyu-l feedback: I have added data loading times as well. I used the same timeit.default_timer() from timeit to be consistent. (cpu side so no synch's needed :) 6 - after fiddling with printf width formatting options, added beautiful aligned columnar display for the per iter updates: Now: Screenshot 2024-02-26 at 9 39 25 AM before: Screenshot 2024-02-26 at 8 39 46 AM --- train.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index f3f3389e..904ebf17 100644 --- a/train.py +++ b/train.py @@ -35,6 +35,8 @@ class TrainState: step: int = 0 current_loss: float = -1 losses: List[float] = field(default_factory=list) + iter_times: List[float] = field(default_factory=list) + data_load_times: List[float] = field(default_factory=list) def state_dict(self) -> Dict[str, Any]: return { @@ -177,15 +179,22 @@ def main(job_config: JobConfig): ): train_state.step += 1 # get batch + data_load_start = timer() batch = next(iter(data_loader)) input_ids, labels = batch input_ids = input_ids.cuda() labels = labels.cuda() + data_load_time = round(timer() - data_load_start, 4) + train_state.data_load_times.append(data_load_time) nwords_since_last_log += labels.numel() optimizer.zero_grad() # forward + start_timer = torch.cuda.Event(enable_timing=True) + end_timer = torch.cuda.Event(enable_timing=True) + start_timer.record() + pred = model(input_ids) tok_loss = F.cross_entropy( pred.flatten(0, 1), labels.flatten(0, 1), reduction="none" @@ -207,6 +216,13 @@ def main(job_config: JobConfig): # updates the scale for next iteration scaler.update() + # training iteration complete + end_timer.record() + torch.cuda.synchronize() + + curr_iter_time = round(start_timer.elapsed_time(end_timer) * 1e-3, 4) + train_state.iter_times.append(curr_iter_time) + # if profiler is active if torch_profiler: torch_profiler.step() @@ -251,8 +267,8 @@ def main(job_config: JobConfig): time_last_log = timer() rank0_log( - f"step: {train_state.step}, current loss: {round(train_state.current_loss,4)}," - f" lr: {round(float(scheduler.get_last_lr()[0]), 8)}" + f"step: {train_state.step:>2} loss: {round(train_state.current_loss,4):>7}" + f" iter: {curr_iter_time:>7} data: {data_load_time:>5} lr: {round(float(scheduler.get_last_lr()[0]), 8):<6}" ) scheduler.step() @@ -261,6 
+277,13 @@ def main(job_config: JobConfig): ) metric_logger.close() + # calc and show average iter time, disregard first three iterations (warmup) + if len(train_state.iter_times) > 3: + avg_iter_time = np.mean(train_state.iter_times[3:]) + rank0_log(f"Average iter time: {avg_iter_time:.4f} seconds") + avg_data_load_time = np.mean(train_state.data_load_times[3:]) + rank0_log(f"Average data load time: {avg_data_load_time:.4f} seconds") + rank0_log(f"{gpu_metrics.get_current_stats()}") From 97aa4bc84644c95e81005201fa2a3d37dc40865d Mon Sep 17 00:00:00 2001 From: gnadathur Date: Mon, 26 Feb 2024 12:48:19 -0800 Subject: [PATCH 2/5] Fill missing options in toml file wih argparse defaults (#91) Summary: Summary: Follow up on config unification, options not available in config file are picked from command line defaults. Test Plan: ============================= test session starts ============================== platform linux -- Python 3.10.13, pytest-8.0.1, pluggy-1.4.0 -- /home/gnadathur/local/a/pytorch-env/bin/python cachedir: .pytest_cache rootdir: /data/users/gnadathur/a/torchtrain configfile: pyproject.toml plugins: cov-4.1.0 collecting ... collected 3 items test/test_job_config.py::TestJobConfig::test_command_line_args PASSED [ 33%] test/test_job_config.py::TestJobConfig::test_job_config_file PASSED [ 66%] test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist PASSED [100%] ---------- coverage: platform linux, python 3.10.13-final-0 ---------- Coverage XML written to file coverage.xml ============================= slowest 20 durations ============================= 0.00s call test/test_job_config.py::TestJobConfig::test_job_config_file 0.00s call test/test_job_config.py::TestJobConfig::test_command_line_args 0.00s call test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist 0.00s setup test/test_job_config.py::TestJobConfig::test_command_line_args 0.00s teardown test/test_job_config.py::TestJobConfig::test_command_line_args 0.00s setup test/test_job_config.py::TestJobConfig::test_job_config_file 0.00s setup test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist 0.00s teardown test/test_job_config.py::TestJobConfig::test_job_config_file 0.00s teardown test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist ============================== 3 passed in 0.06s =============================== Test Plan: Reviewers: Subscribers: Tasks: Tags: --------- Co-authored-by: gnadathur --- test/test_job_config.py | 4 ++-- torchtrain/config_manager.py | 11 +++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/test/test_job_config.py b/test/test_job_config.py index ccdaf206..0e3d9c63 100644 --- a/test/test_job_config.py +++ b/test/test_job_config.py @@ -9,12 +9,12 @@ class TestJobConfig: def test_command_line_args(self): config = JobConfig() config.parse_args([]) - assert config.model.name == "llama" + assert config.training.steps == -1 def test_job_config_file(self): config = JobConfig() config.parse_args(["--job.config_file", "./train_configs/debug_model.toml"]) - assert config.model.name == "llama" + assert config.training.steps == 10 def test_job_file_does_not_exist(self): with pytest.raises(FileNotFoundError): diff --git a/torchtrain/config_manager.py b/torchtrain/config_manager.py index ff8afe8f..613f9411 100644 --- a/torchtrain/config_manager.py +++ b/torchtrain/config_manager.py @@ -1,3 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. import argparse @@ -17,16 +20,16 @@ class JobConfig: Semantics: - Default config is loaded from a toml file. If no toml file is provided, then the default config is loaded from argparse defaults. + - if toml file has missing keys, they are filled with argparse defaults. """ def parse_args(self, args_list: list = sys.argv[1:]): args = JobConfig.init_args_from_command_line(args_list) config_file = getattr(args, "job.config_file", None) - if config_file is None: - args_dict = self._args_to_two_level_dict(args) - else: + args_dict = self._args_to_two_level_dict(args) + if config_file is not None: with open(config_file, "rb") as f: - args_dict = tomllib.load(f) + args_dict |= tomllib.load(f) for k, v in args_dict.items(): class_type = type(k.title(), (), v) setattr(self, k, class_type()) From 5dec5360962a575047ac4f297df813c332f3ec5f Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Mon, 26 Feb 2024 15:57:05 -0800 Subject: [PATCH 3/5] support infinite loop over alpaca dataset ghstack-source-id: 38cbc277e2a177bc0baf35450a661835b97a7f22 Pull Request resolved: https://github.com/pytorch/torchtrain/pull/92 --- torchtrain/datasets/alpaca.py | 40 ++++++++++++++++++++++------------- train.py | 4 +++- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/torchtrain/datasets/alpaca.py b/torchtrain/datasets/alpaca.py index 3ee1442d..734792b6 100644 --- a/torchtrain/datasets/alpaca.py +++ b/torchtrain/datasets/alpaca.py @@ -20,6 +20,7 @@ class AlpacaDataset(IterableDataset): seq_len (int): max sequence length world_size (int): number of data parallel processes participating in training rank (int): rank of the current data parallel process + infinite: whether to loop infinitely over the dataset Data input format: { @@ -43,38 +44,47 @@ def __init__( seq_len: int = 2048, world_size: int = 1, rank: int = 0, + infinite: bool = False, **kwargs ) -> None: # TODO: This is a temporary solution for small datasets like Alpaca. # For larger datasets we need to use a more scalable approach. # Setting `streaming=True` works for large dataset, but the speed is slow. 
ds = load_dataset("tatsu-lab/alpaca", split="train") - self.data_iterator = iter(split_dataset_by_node(ds, rank, world_size)) + self._data = split_dataset_by_node(ds, rank, world_size) self._tokenizer = tokenizer self.seq_len = seq_len + self.infinite = infinite def __iter__(self): max_buffer_token_len = 1 + self.seq_len all_tokens: List[int] = [] - for sample in self.data_iterator: - sample_text = sample["text"] - sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True) - all_tokens.extend(sample_tokens) + while True: + for sample in iter(self._data): + sample_text = sample["text"] + sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True) + all_tokens.extend(sample_tokens) - while len(all_tokens) >= max_buffer_token_len: - x = torch.LongTensor(all_tokens[:max_buffer_token_len]) - # batched_x = x.reshape(self.batch_size, -1) - # update tokens to the remaining tokens - all_tokens = all_tokens[max_buffer_token_len:] - input = x[:-1] - label = x[1:] - yield input, label + while len(all_tokens) >= max_buffer_token_len: + x = torch.LongTensor(all_tokens[:max_buffer_token_len]) + # update tokens to the remaining tokens + all_tokens = all_tokens[max_buffer_token_len:] + input = x[:-1] + label = x[1:] + yield input, label + if not self.infinite: + break def build_alpaca_data_loader( - tokenizer: TokenizerIf, batch_size: int, seq_len: int, world_size, rank + tokenizer: TokenizerIf, + batch_size: int, + seq_len: int, + world_size: int, + rank: int, + infinite: bool = True, ): - alpaca_ds = AlpacaDataset(tokenizer, seq_len, world_size, rank) + alpaca_ds = AlpacaDataset(tokenizer, seq_len, world_size, rank, infinite) return DataLoader(alpaca_ds, batch_size=batch_size) diff --git a/train.py b/train.py index 904ebf17..5ce5de37 100644 --- a/train.py +++ b/train.py @@ -167,6 +167,8 @@ def main(job_config: JobConfig): ) checkpoint.load() + data_iterator = iter(data_loader) + with maybe_run_profiler(job_config) as torch_profiler: checkpoint.reset() # variables used to keep info for metrics logging @@ -180,7 +182,7 @@ def main(job_config: JobConfig): train_state.step += 1 # get batch data_load_start = timer() - batch = next(iter(data_loader)) + batch = next(data_iterator) input_ids, labels = batch input_ids = input_ids.cuda() labels = labels.cuda() From 8671c913832a6ab351e8f0db60c749bb4d70f3b4 Mon Sep 17 00:00:00 2001 From: Less Wright Date: Tue, 27 Feb 2024 10:41:40 -0800 Subject: [PATCH 4/5] Add color to console output if local logging, auto avoid color logging on slurm (#93) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds the ability to do colored console outputs in order to highlight the training data outputs. It also adds a check to not use this color formatting on slurm, where it will add 33= instead of the color if not avoided. Note that I've just added some color to highlight the main training data. Users that fork/clone can use it to enhance their outputs as desired. 
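The mechanism is just raw ANSI escape sequences gated on an environment check. A minimal standalone sketch of the same idea (the emphasize() helper and module-level constants are illustrative names, not part of the patch; the actual constants live in the Color/Background/Style dataclasses added to torchtrain/utils.py below):

    import os

    # ANSI escape sequences for foreground color; "\033[39m" restores the default.
    GREEN, RESET = "\033[32m", "\033[39m"

    # On slurm the output lands in a log file where the escape codes show up as
    # literal "33["-style noise, so color is only enabled for local runs.
    use_color = "SLURM_JOB_ID" not in os.environ

    def emphasize(text: str) -> str:
        return f"{GREEN}{text}{RESET}" if use_color else text

    print(f"step: 10 loss: {emphasize('2.7183')}")

The same environment check is what keeps slurm logs free of the stray escape-code characters shown in the "if you don't check this" screenshot below.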
Screenshot 2024-02-26 at 10 20 15 PM Note that on slurm it remains plain: Screenshot 2024-02-26 at 10 46 24 PM if you dont' check this, then it would otherwise look like this (this does not happen with this PR, just showing if we didn't check and credit to Yifu for noting this would be an issue): Screenshot 2024-02-26 at 10 39 23 PM --- torchtrain/utils.py | 35 +++++++++++++++++++++++++++++++++++ train.py | 39 +++++++++++++++++++++++++++++++-------- 2 files changed, 66 insertions(+), 8 deletions(-) diff --git a/torchtrain/utils.py b/torchtrain/utils.py index 9ae71cae..823e8843 100644 --- a/torchtrain/utils.py +++ b/torchtrain/utils.py @@ -1,6 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. +from dataclasses import dataclass from typing import Union import torch @@ -17,3 +18,37 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float: def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float: tensor = torch.tensor(x).cuda() return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh) + + +@dataclass +class Color: + black = "\033[30m" + red = "\033[31m" + green = "\033[32m" + yellow = "\033[33m" + blue = "\033[34m" + magenta = "\033[35m" + cyan = "\033[36m" + white = "\033[37m" + reset = "\033[39m" + + +@dataclass +class Background: + black = "\033[40m" + red = "\033[41m" + green = "\033[42m" + yellow = "\033[43m" + blue = "\033[44m" + magenta = "\033[45m" + cyan = "\033[46m" + white = "\033[47m" + reset = "\033[49m" + + +@dataclass +class Style: + bright = "\033[1m" + dim = "\033[2m" + normal = "\033[22m" + reset = "\033[0m" diff --git a/train.py b/train.py index 5ce5de37..95d42226 100644 --- a/train.py +++ b/train.py @@ -2,6 +2,7 @@ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
import os + from dataclasses import dataclass, field from timeit import default_timer as timer from typing import Any, Dict, List @@ -27,7 +28,11 @@ from torchtrain.parallelisms import models_parallelize_fns, ParallelDims from torchtrain.profiling import maybe_run_profiler -from torchtrain.utils import dist_max, dist_mean +from torchtrain.utils import Color, dist_max, dist_mean + +_is_local_logging = True +if "SLURM_JOB_ID" in os.environ: + _is_local_logging = False @dataclass @@ -119,9 +124,16 @@ def main(job_config: JobConfig): # log model size model_param_count = get_num_params(model) - rank0_log( - f"Model {model_name} {job_config.model.flavor} size: {model_param_count:,} total parameters" - ) + if _is_local_logging: + rank0_log( + f"{Color.blue}Model {model_name} {job_config.model.flavor} {Color.red}size: {model_param_count:,}" + f" total parameters{Color.reset}" + ) + else: + rank0_log( + f"{model_name} {job_config.model.flavor} size: {model_param_count:,} total parameters" + ) + gpu_metrics = GPUMemoryMonitor("cuda") rank0_log(f"GPU memory usage: {gpu_metrics}") @@ -268,10 +280,21 @@ def main(job_config: JobConfig): nwords_since_last_log = 0 time_last_log = timer() - rank0_log( - f"step: {train_state.step:>2} loss: {round(train_state.current_loss,4):>7}" - f" iter: {curr_iter_time:>7} data: {data_load_time:>5} lr: {round(float(scheduler.get_last_lr()[0]), 8):<6}" - ) + if _is_local_logging: + rank0_log( + f"{Color.cyan}step: {train_state.step:>2} {Color.green}loss: {round(train_state.current_loss,4):>7}" + f" {Color.reset}iter: {Color.blue}{curr_iter_time:>7}{Color.reset}" + f" data: {Color.blue}{data_load_time:>5} {Color.reset}" + f"lr: {Color.yellow}{round(float(scheduler.get_last_lr()[0]), 8):<6}{Color.reset}" + ) + else: + rank0_log( + f"step: {train_state.step:>2} loss: {round(train_state.current_loss,4):>7}" + f" iter: {curr_iter_time:>7}" + f" data: {data_load_time:>5} " + f"lr: {round(float(scheduler.get_last_lr()[0]), 8):<6}" + ) + scheduler.step() checkpoint.save( From 5a1689fbfcded01271c25e6bfd3648aac494781c Mon Sep 17 00:00:00 2001 From: Less Wright Date: Tue, 27 Feb 2024 11:51:44 -0800 Subject: [PATCH 5/5] update GPU metrics logging to GiB (gibibytes) (#95) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit this PR updates the GPU metrics to labelling as GiB - we were calculating GiB but calling it GB. 
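The distinction matters because the allocator reports raw bytes: dividing by 10**9 gives decimal gigabytes, while dividing by 1024**3 gives gibibytes, which is what the code was actually computing. A quick standalone illustration (the 80 GiB device size is a made-up example, not a number from the patch):

    total_bytes = 80 * 1024**3     # an 80 GiB device, as reported in bytes
    print(total_bytes / 10**9)     # 85.89934592  -> decimal "GB"
    print(total_bytes / 1024**3)   # 80.0         -> binary "GiB"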
(credit to @awgu for flagging this - issue https://github.com/pytorch/torchtrain/issues/94) function names and member vars in metrics.py have been updated to _gib instead of _gb for clarity, and the logging output now labels as GiB: Screenshot 2024-02-27 at 11 28 23 AM --- torchtrain/metrics.py | 51 ++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/torchtrain/metrics.py b/torchtrain/metrics.py index b2ad3cc9..d56d80a3 100644 --- a/torchtrain/metrics.py +++ b/torchtrain/metrics.py @@ -16,19 +16,20 @@ from torchtrain.logging_utils import rank0_log -_gb_in_bytes = 1024 * 1024 * 1024 -_mb_in_bytes = 1024 * 1024 +# note that GiB (gibibyte) is 1024, vs GB is 1000 +_gib_in_bytes = 1024 * 1024 * 1024 +_mib_in_bytes = 1024 * 1024 -def format_to_gb(item, precision=4): - """quick function to format numbers to gigabyte and round to (default) 4 digit precision""" - metric_num = item / _gb_in_bytes +def _format_to_gib(item, precision=4): + """quick function to format numbers to gibibyte and round to (default) 4 digit precision""" + metric_num = item / _gib_in_bytes metric_num = round(metric_num, ndigits=precision) return metric_num -def convert_to_gpu_pct(value, total_gpu_memory): - return round(100 * (value / total_gpu_memory), 2) +def _convert_to_gpu_pct(value, total_gpu_memory, precision=4): + return round(100 * (value / total_gpu_memory), precision) # named tuple for passing memory stats (as % of device capacity) for Tensorboard logging @@ -58,7 +59,7 @@ def __init__(self, device: str = "cuda:0"): self.device_capacity = torch.cuda.get_device_properties( self.device ).total_memory - self.device_capacity_gb = format_to_gb(self.device_capacity) + self.device_capacity_gib = _format_to_gib(self.device_capacity) self.num_retries = 0 self.num_ooms = 0 self.peak_active_memory = 0 @@ -67,17 +68,17 @@ def __init__(self, device: str = "cuda:0"): self.curr_reserved_memory = 0 self.device_reserved_memory_usage = 0 - self.device_reserved_memory_gb = 0 + self.device_reserved_memory_gib = 0 self.device_reserved_memory_pct = 0 self.device_active_memory_usage = 0 - self.device_active_memory_gb = 0 + self.device_active_memory_gib = 0 self.device_active_memory_pct = 0 # current stats self.device_alloc_memory_usage = torch.cuda.memory_allocated(self.device) - self.device_alloc_memory_gb = format_to_gb(self.device_alloc_memory_usage) - self.device_alloc_memory_pct = convert_to_gpu_pct( + self.device_alloc_memory_gib = _format_to_gib(self.device_alloc_memory_usage) + self.device_alloc_memory_pct = _convert_to_gpu_pct( self.device_alloc_memory_usage, self.device_capacity ) @@ -90,10 +91,8 @@ def get_pct_memory(self, memory_num): pct_memory = round(100 * (pct_memory), 2) return pct_memory - def get_gb_memory(self, memory_num): - gb_memory = memory_num / _gb_in_bytes - gb_memory = round(gb_memory, 2) - return gb_memory + def get_gib_memory(self, memory_num): + return _format_to_gib(memory_num, precision=2) def get_current_stats(self, return_data: bool = False): """ @@ -104,21 +103,23 @@ def get_current_stats(self, return_data: bool = False): curr_mem = torch.cuda.memory_stats(self.device) self.device_alloc_memory_usage = curr_mem["allocated_bytes.all.current"] - self.device_alloc_memory_gb = format_to_gb(self.device_alloc_memory_usage) - self.device_alloc_memory_pct = convert_to_gpu_pct( + self.device_alloc_memory_gib = _format_to_gib(self.device_alloc_memory_usage) + self.device_alloc_memory_pct = _convert_to_gpu_pct( self.device_alloc_memory_usage, 
self.device_capacity ) self.device_reserved_memory_usage = curr_mem["reserved_bytes.all.current"] - self.device_reserved_memory_gb = format_to_gb(self.device_reserved_memory_usage) - self.device_reserved_memory_pct = convert_to_gpu_pct( + self.device_reserved_memory_gib = _format_to_gib( + self.device_reserved_memory_usage + ) + self.device_reserved_memory_pct = _convert_to_gpu_pct( self.device_reserved_memory_usage, self.device_capacity ) self.device_active_memory_usage = curr_mem["active_bytes.all.current"] - self.device_active_memory_gb = format_to_gb(self.device_active_memory_usage) - self.device_active_memory_pct = convert_to_gpu_pct( - self.device_active_memory_usage, self.device_capacity + self.device_active_memory_gib = _format_to_gib(self.device_active_memory_usage) + self.device_active_memory_pct = _convert_to_gpu_pct( + self.device_active_memory_usage, self.device_capacity, precision=2 ) display_str = "" @@ -179,8 +180,8 @@ def reset_peak_stats(self): def __str__(self): _ = self.get_current_stats() - display_str = f"{self.device_name} ({self.device_index}): {self.device_capacity_gb} GB capacity, " - display_str += f"{self.device_alloc_memory_gb} GB in-use, {self.device_alloc_memory_pct}% in-use" + display_str = f"{self.device_name} ({self.device_index}): {self.device_capacity_gib} GiB capacity, " + display_str += f"{self.device_alloc_memory_gib} GiB in-use, {self.device_alloc_memory_pct}% in-use" return f"{display_str}"
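For context, a minimal sketch of how these renamed helpers get used in practice: query the CUDA caching allocator, convert bytes to GiB, and report the share of device capacity. This is an illustrative condensation under assumed names (snapshot_gib is not a function in the repo), not the exact code in metrics.py; the allocator stat keys are the same ones the diff above reads.

    import torch

    def snapshot_gib(device: str = "cuda:0") -> dict:
        """Current allocator stats for `device` as (GiB, % of capacity) pairs."""
        capacity = torch.cuda.get_device_properties(device).total_memory
        stats = torch.cuda.memory_stats(device)
        report = {}
        for label, key in [
            ("allocated", "allocated_bytes.all.current"),
            ("reserved", "reserved_bytes.all.current"),
            ("active", "active_bytes.all.current"),
        ]:
            nbytes = stats[key]
            report[label] = (round(nbytes / 1024**3, 4), round(100 * nbytes / capacity, 2))
        return report

    if __name__ == "__main__":
        if torch.cuda.is_available():
            print(snapshot_gib())   # e.g. {'allocated': (1.23, 1.54), ...}, values are (GiB, %)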