diff --git a/torchtrain/metrics.py b/torchtrain/metrics.py
index 092cafae..cad98e85 100644
--- a/torchtrain/metrics.py
+++ b/torchtrain/metrics.py
@@ -122,15 +122,15 @@ def get_current_stats(self, return_data: bool = False):
         )
 
         display_str = ""
-        display_str += f"Current Memory: {self.device_name} ({self.device_index}): Reserved: {self.device_reserved_memory_pct}%,"
-        display_str += f"Alloc {self.device_alloc_memory_pct}%, Active: {self.device_active_memory_pct}%\n"
+        display_str += f"Current Memory: {self.device_name} ({self.device_index}): Reserved: {self.device_reserved_memory_pct}%, "
+        display_str += f"Alloc {self.device_alloc_memory_pct}%, Active: {self.device_active_memory_pct}%\n"
 
         self.get_peak_stats(curr_mem)
 
         peak_active_pct = self.get_pct_memory(self.peak_active_memory)
         peak_allocated_pct = self.get_pct_memory(self.peak_allocated_memory)
         peak_reserved_pct = self.get_pct_memory(self.peak_reserved_memory)
-        display_str += f"Peak Memory: Reserved {peak_reserved_pct}%, Alloc {peak_allocated_pct}%, Active: {peak_active_pct}%\n"
+        display_str += f"Peak Memory: Reserved {peak_reserved_pct}%, Alloc {peak_allocated_pct}%, Active: {peak_active_pct}%\n"
         display_str += f"num retries: {self.num_retries}, num ooms: {self.num_ooms}"
 
         if self.num_retries > 0:
diff --git a/train.py b/train.py
index e922acf3..faabc2d8 100644
--- a/train.py
+++ b/train.py
@@ -219,10 +219,18 @@ def main(args):
                 time_delta * parallel_dims.model_parallel_size
             )
 
+            gpu_mem_stats = gpu_metrics.get_current_stats(return_data=True)
+
             metrics = {
-                "global_avg_loss": global_avg_loss,
-                "global_max_loss": global_max_loss,
+                "loss_metrics/global_avg_loss": global_avg_loss,
+                "loss_metrics/global_max_loss": global_max_loss,
                 "wps": wps,
+                "memory_current/active(%)": gpu_mem_stats.active_curr,
+                "memory_current/allocated(%)": gpu_mem_stats.allocated_curr,
+                "memory_current/reserved(%)": gpu_mem_stats.reserved_curr,
+                "memory_peak/active(%)": gpu_mem_stats.active_peak,
+                "memory_peak/allocated(%)": gpu_mem_stats.allocated_peak,
+                "memory_peak/reserved(%)": gpu_mem_stats.reserved_peak,
             }
             metric_logger.log(metrics, step=train_state.step)
 
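For context: `train.py` now reads six fields off the value returned by `get_current_stats(return_data=True)`. Below is a minimal sketch of the shape that return value could take, assuming a named tuple over the CUDA caching-allocator counters. The `GPUMemStats` name and the `collect_gpu_mem_stats` helper are illustrative, not part of this diff; the field names and the percent-of-capacity units come from the new metric keys above.

```python
from collections import namedtuple

import torch

# Hypothetical container; the six field names mirror exactly what
# train.py reads off gpu_mem_stats in this diff.
GPUMemStats = namedtuple(
    "GPUMemStats",
    [
        "active_curr",
        "allocated_curr",
        "reserved_curr",
        "active_peak",
        "allocated_peak",
        "reserved_peak",
    ],
)


def collect_gpu_mem_stats(device: int = 0) -> GPUMemStats:
    """Illustrative stand-in for get_current_stats(return_data=True):
    CUDA caching-allocator counters as a percentage of device capacity."""
    stats = torch.cuda.memory_stats(device)
    capacity = torch.cuda.get_device_properties(device).total_memory

    def pct(key: str) -> float:
        return round(100 * stats[key] / capacity, 2)

    return GPUMemStats(
        active_curr=pct("active_bytes.all.current"),
        allocated_curr=pct("allocated_bytes.all.current"),
        reserved_curr=pct("reserved_bytes.all.current"),
        active_peak=pct("active_bytes.all.peak"),
        allocated_peak=pct("allocated_bytes.all.peak"),
        reserved_peak=pct("reserved_bytes.all.peak"),
    )
```

A note on the renamed keys: slash-separated metric names such as `loss_metrics/global_avg_loss` and `memory_current/active(%)` let dashboards like TensorBoard group the loss, current-memory, and peak-memory charts into separate sections.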