diff --git a/src/itwinai/cli.py b/src/itwinai/cli.py
index fc42d94f..9bb6b27e 100644
--- a/src/itwinai/cli.py
+++ b/src/itwinai/cli.py
@@ -19,19 +19,63 @@
 app = typer.Typer(pretty_exceptions_enable=False)
 
 
+@app.command()
+def generate_gpu_energy_plot(
+    log_dir: str = "scalability_metrics/gpu_energy_data",
+    pattern: str = r"gpu_energy_data.*\.csv$",
+    output_file: str = "plots/gpu_energy_plot.png",
+) -> None:
+    """Generate a GPU energy plot showing the energy expenditure in Watt hours for
+    each combination of strategy and number of GPUs.
+
+    Args:
+        log_dir: The directory where the csv logs are stored. Defaults to
+            ``scalability_metrics/gpu_energy_data``.
+        pattern: A regex pattern to recognize the file names in the 'log_dir' folder.
+            Defaults to ``gpu_energy_data.*\\.csv$``. Set it to 'None' to match all
+            files in the given folder.
+        output_file: The path to where the resulting plot should be saved. Defaults to
+            ``plots/gpu_energy_plot.png``.
+
+    """
+    import matplotlib.pyplot as plt
+
+    from itwinai.torch.monitoring.plotting import gpu_energy_plot, read_energy_df
+
+    log_dir_path = Path(log_dir)
+    if not log_dir_path.exists():
+        raise ValueError(
+            f"The provided log_dir, '{log_dir_path.resolve()}', does not exist."
+        )
+
+    if pattern.lower() == "none":
+        pattern = None
+
+    gpu_utilization_df = read_energy_df(pattern=pattern, log_dir=log_dir_path)
+    gpu_energy_plot(gpu_utilization_df=gpu_utilization_df)
+
+    output_path = Path(output_file)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    plt.savefig(output_path)
+    print(f"\nSaved GPU energy plot at '{output_path.resolve()}'.")
+
+
 @app.command()
 def generate_communication_plot(
-    log_dir: str = "profiling_logs",
-    pattern: str = r"profile_(\w+)_(\d+)_(\d+)\.csv$",
-    output_file: str = "plots/comm_plot.png",
+    log_dir: str = "scalability_metrics/communication_data",
+    pattern: str = r"(.+)_(\d+)_(\d+)\.csv$",
+    output_file: str = "plots/communication_plot.png",
 ) -> None:
     """Generate stacked plot showing computation vs. communication fraction. Stores it
+    to output_file.
 
     Args:
-        log_dir: The directory where the csv logs are stored. Defauls to
+        log_dir: The directory where the csv logs are stored. Defaults to
             ``profiling_logs``.
         pattern: A regex pattern to recognize the file names in the 'log_dir' folder.
-            Defaults to ``profile_(\\w+)_(\\d+)_(\\d+)\\.csv$``.
+            Defaults to ``(.+)_(\\d+)_(\\d+)\\.csv$``. Set it to 'None' to match all
+            files in the given folder.
         output_file: The path to where the resulting plot should be saved. Defaults to
             ``plots/comm_plot.png``.
     """
@@ -45,13 +89,17 @@ def generate_communication_plot(
 
     log_dir_path = Path(log_dir)
     if not log_dir_path.exists():
-        raise IOError(
+        raise ValueError(
             f"The directory '{log_dir_path.resolve()}' does not exist, so could not"
             f"extract profiling logs. Make sure you are running this command in the "
-            f"same directory as the logging dir."
+            f"same directory as the logging dir or are passing a sufficient relative "
+            f"path."
         )
 
-    df = create_combined_comm_overhead_df(logs_dir=log_dir_path, pattern=pattern)
+    if pattern.lower() == "none":
+        pattern = None
+
+    df = create_combined_comm_overhead_df(log_dir=log_dir_path, pattern=pattern)
     values = get_comp_fraction_full_array(df, print_table=True)
 
     strategies = sorted(df["strategy"].unique())
@@ -67,7 +115,7 @@ def generate_communication_plot(
     output_path.parent.mkdir(parents=True, exist_ok=True)
 
     plt.savefig(output_path)
-    print(f"\nSaved computation vs.
communication plot at '{output_path.resolve()}'") + print(f"\nSaved computation vs. communication plot at '{output_path.resolve()}'.") @app.command() diff --git a/src/itwinai/torch/distributed.py b/src/itwinai/torch/distributed.py index 559ba2b0..42a5cf06 100644 --- a/src/itwinai/torch/distributed.py +++ b/src/itwinai/torch/distributed.py @@ -37,6 +37,9 @@ class TorchDistributedStrategy(DistributedStrategy): #: Defaults to False. is_initialized: bool = False + # Provides the name of the strategy for logging purposes etc. + name: str + @property def is_main_worker(self) -> bool: """Checks if local worker has global rank equal to zero. @@ -46,12 +49,14 @@ def is_main_worker(self) -> bool: """ if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) return self.global_rank() == 0 @abc.abstractmethod def init(self) -> None: """Initializes the chosen distributed backend""" + # @abc.abstractmethod # def distributed_engine( # self, model: nn.Module, optimizer: Optimizer, @@ -61,8 +66,10 @@ def init(self) -> None: @abc.abstractmethod def distributed( - self, model: nn.Module, optimizer: Optimizer, - lr_scheduler: Optional[LRScheduler] = None + self, + model: nn.Module, + optimizer: Optimizer, + lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[nn.Module, Optimizer, Optional[LRScheduler]]: """Setup model, optimizer and scheduler for distributed.""" @@ -109,7 +116,8 @@ def device(self) -> str: """ if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) return f"cuda:{self.local_rank()}" def set_device(self): @@ -119,18 +127,24 @@ def set_device(self): torch.cuda.set_device(self.local_rank()) def create_dataloader( - self, dataset: Dataset[T_co], batch_size: Optional[int] = 1, + self, + dataset: Dataset[T_co], + batch_size: Optional[int] = 1, shuffle: Optional[bool] = None, sampler: Union[Sampler, Iterable, None] = None, batch_sampler: Union[Sampler[List], Iterable[List], None] = None, - num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None, - pin_memory: bool = False, drop_last: bool = False, + num_workers: int = 0, + collate_fn: Optional[_collate_fn_t] = None, + pin_memory: bool = False, + drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None, - multiprocessing_context=None, generator=None, - *, prefetch_factor: Optional[int] = None, + multiprocessing_context=None, + generator=None, + *, + prefetch_factor: Optional[int] = None, persistent_workers: bool = False, - pin_memory_device: str = "" + pin_memory_device: str = "", ): """Create a distributed DataLoader by using ``DistributedSampler`` as random sampler. @@ -271,22 +285,22 @@ def create_dataloader( https://pytorch.org/docs/stable/data.html#multi-process-data-loading .. _Dataset Types: https://pytorch.org/docs/stable/data.html#dataset-types - """ + """ if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." 
+ ) if batch_sampler is not None: - print( - "WARNING: batch_sampler is ignored by TorchDistributedStrategy" - ) + print("WARNING: batch_sampler is ignored by TorchDistributedStrategy") if self.is_distributed: if sampler is None: sampler = DistributedSampler( - dataset, num_replicas=self.global_world_size(), + dataset, + num_replicas=self.global_world_size(), rank=self.global_rank(), - shuffle=shuffle + shuffle=shuffle, ) elif not isinstance(sampler, DistributedSampler): raise RuntimeError( @@ -294,14 +308,20 @@ def create_dataloader( ) # shuffle and batch_sampler must be unset return DataLoader( - dataset=dataset, batch_size=batch_size, sampler=sampler, - num_workers=num_workers, collate_fn=collate_fn, - pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, + dataset=dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + collate_fn=collate_fn, + pin_memory=pin_memory, + drop_last=drop_last, + timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, - generator=generator, prefetch_factor=prefetch_factor, + generator=generator, + prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, - pin_memory_device=pin_memory_device + pin_memory_device=pin_memory_device, ) @abc.abstractmethod @@ -359,11 +379,12 @@ class TorchDDPStrategy(TorchDistributedStrategy): """ #: Torch distributed communication backend. - backend: Literal['nccl', 'gloo', 'mpi'] + backend: Literal["nccl", "gloo", "mpi"] - def __init__(self, backend: Literal['nccl', 'gloo', 'mpi']) -> None: + def __init__(self, backend: Literal["nccl", "gloo", "mpi"]) -> None: super().__init__() self.backend = backend + self.name = "torch-ddp" def init(self) -> None: """Initializes the distributed process group and the distributed @@ -375,8 +396,7 @@ def init(self) -> None: which is already initialized. """ if not distributed_resources_available(): - raise RuntimeError( - "Trying to run distributed on insufficient resources.") + raise RuntimeError("Trying to run distributed on insufficient resources.") if self.is_initialized: raise DistributedStrategyError("Strategy was already initialized") dist.init_process_group(backend=self.backend) @@ -409,15 +429,18 @@ def init(self) -> None: # return model_engine def distributed( - self, model: nn.Module, optimizer: Optimizer, + self, + model: nn.Module, + optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None, find_unused_parameters: bool = False, - **kwargs + **kwargs, ) -> Tuple[nn.Module, Optimizer, Optional[LRScheduler]]: """Setup model, optimizer and scheduler for distributed.""" if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) if torch.cuda.is_available(): # device = self.dist_lrank() model = model.to(self.device()) @@ -425,7 +448,7 @@ def distributed( model, device_ids=[self.device()], output_device=self.device(), - find_unused_parameters=find_unused_parameters + find_unused_parameters=find_unused_parameters, ) else: dist_model = model @@ -440,7 +463,8 @@ def global_world_size(self) -> int: """ if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." 
+ ) return dist.get_world_size() def local_world_size(self) -> int: @@ -452,7 +476,8 @@ def local_world_size(self) -> int: """ if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) return torch.cuda.device_count() def global_rank(self) -> int: @@ -464,7 +489,8 @@ def global_rank(self) -> int: """ if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) return dist.get_rank() def local_rank(self) -> int: @@ -475,14 +501,16 @@ def local_rank(self) -> int: """ if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) return dist.get_rank() % torch.cuda.device_count() def clean_up(self) -> None: """Destroys the current process group.""" if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) if torch.cuda.is_available(): dist.barrier() dist.destroy_process_group() @@ -500,7 +528,8 @@ def allgather_obj(self, obj: Any) -> List[Any]: # https://pytorch.org/docs/stable/distributed.html#collective-functions if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) res = [None] * self.global_world_size() dist.all_gather_object(res, obj) return res @@ -521,7 +550,8 @@ def gather_obj(self, obj: Any, dst_rank: int = 0) -> Optional[List[Any]]: # https://pytorch.org/docs/stable/distributed.html#collective-functions if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) if self.global_rank() == dst_rank: res = [None] * self.global_world_size() dist.gather_object(obj, res, dst=dst_rank) @@ -533,7 +563,8 @@ def gather(self, tensor: torch.Tensor, dst_rank: int = 0) -> Optional[List]: # https://pytorch.org/docs/stable/distributed.html#collective-functions if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) # Ensure that the tensor is on the correct device (CUDA) tensor = tensor.to(self.device()) @@ -541,8 +572,10 @@ def gather(self, tensor: torch.Tensor, dst_rank: int = 0) -> Optional[List]: dist.gather(tensor, dst=dst_rank) return - res = [torch.zeros_like(tensor, device=self.device()) - for _ in range(self.global_world_size())] + res = [ + torch.zeros_like(tensor, device=self.device()) + for _ in range(self.global_world_size()) + ] dist.gather(tensor, gather_list=res, dst=dst_rank) @@ -561,11 +594,12 @@ class DeepSpeedStrategy(TorchDistributedStrategy): """ #: Torch distributed communication backend. 
- backend: Literal['nccl', 'gloo', 'mpi'] + backend: Literal["nccl", "gloo", "mpi"] - def __init__(self, backend: Literal['nccl', 'gloo', 'mpi']) -> None: + def __init__(self, backend: Literal["nccl", "gloo", "mpi"]) -> None: super().__init__() self.backend = backend + self.name = "deepspeed" def init(self) -> None: """Initializes the distributed process group and the distributed @@ -577,18 +611,18 @@ def init(self) -> None: already initialized. """ import deepspeed + self.deepspeed = deepspeed if not distributed_resources_available(): - raise RuntimeError( - "Trying to run distributed on insufficient resources.") + raise RuntimeError("Trying to run distributed on insufficient resources.") if self.is_initialized: raise DistributedStrategyError("Strategy was already initialized") # https://github.com/Lightning-AI/pytorch-lightning/issues/13567 - ompi_lrank = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') - os.environ['OMPI_COMM_WORLD_LOCAL_RANK'] = os.environ.get( - 'LOCAL_RANK', ompi_lrank + ompi_lrank = os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK") + os.environ["OMPI_COMM_WORLD_LOCAL_RANK"] = os.environ.get( + "LOCAL_RANK", ompi_lrank ) # https://deepspeed.readthedocs.io/en/latest/initialize.html#training-initialization @@ -598,10 +632,12 @@ def init(self) -> None: self.set_device() def distributed( - self, model: nn.Module, optimizer: Optional[Optimizer] = None, + self, + model: nn.Module, + optimizer: Optional[Optimizer] = None, lr_scheduler: Optional[LRScheduler] = None, model_parameters: Optional[Any] = None, - **init_kwargs + **init_kwargs, ) -> Tuple[nn.Module, Optimizer, Optional[LRScheduler]]: """Setup model, optimizer and scheduler for distributed.""" if not self.is_initialized: @@ -615,7 +651,7 @@ def distributed( optimizer=optimizer, lr_scheduler=lr_scheduler, dist_init_required=True, - **init_kwargs + **init_kwargs, ) return distrib_model, optimizer, lr_scheduler @@ -627,7 +663,8 @@ def global_world_size(self) -> int: """ if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) return dist.get_world_size() def local_world_size(self) -> int: @@ -639,7 +676,8 @@ def local_world_size(self) -> int: """ if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) return torch.cuda.device_count() def global_rank(self) -> int: @@ -651,7 +689,8 @@ def global_rank(self) -> int: """ if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) return dist.get_rank() def local_rank(self) -> int: @@ -662,14 +701,16 @@ def local_rank(self) -> int: """ if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) return dist.get_rank() % torch.cuda.device_count() def clean_up(self) -> None: """Destroys the current process group.""" if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." 
+ ) # deepspeed.sys.exit() # disabled as it kills the execution def allgather_obj(self, obj: Any) -> List[Any]: @@ -685,7 +726,8 @@ def allgather_obj(self, obj: Any) -> List[Any]: # https://pytorch.org/docs/stable/distributed.html#collective-functions if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) res = [None] * self.global_world_size() dist.all_gather_object(res, obj) return res @@ -706,7 +748,8 @@ def gather_obj(self, obj: Any, dst_rank: int = 0) -> Optional[List[Any]]: # https://pytorch.org/docs/stable/distributed.html#collective-functions if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) if self.global_rank() == dst_rank: res = [None] * self.global_world_size() dist.gather_object(obj, res, dst=dst_rank) @@ -718,7 +761,8 @@ def gather(self, tensor: torch.Tensor, dst_rank: int = 0) -> Optional[List]: # https://pytorch.org/docs/stable/distributed.html#collective-functions if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) # Ensure that the tensor is on the correct device (CUDA) tensor = tensor.to(self.device()) @@ -726,8 +770,10 @@ def gather(self, tensor: torch.Tensor, dst_rank: int = 0) -> Optional[List]: dist.gather(tensor, dst=dst_rank) return - res = [torch.zeros_like(tensor, device=self.device()) - for _ in range(self.global_world_size())] + res = [ + torch.zeros_like(tensor, device=self.device()) + for _ in range(self.global_world_size()) + ] dist.gather(tensor, gather_list=res, dst=dst_rank) @@ -738,6 +784,10 @@ def gather(self, tensor: torch.Tensor, dst_rank: int = 0) -> Optional[List]: class HorovodStrategy(TorchDistributedStrategy): """Horovod distributed strategy class.""" + def __init__(self): + super().__init__() + self.name = "horovod" + def init(self) -> None: """Initializes the Horovod distributed backend. @@ -747,12 +797,12 @@ def init(self) -> None: already initialized. """ if not distributed_resources_available(): - raise RuntimeError( - "Trying to run distributed on insufficient resources.") + raise RuntimeError("Trying to run distributed on insufficient resources.") if self.is_initialized: raise DistributedStrategyError("Strategy was already initialized") import horovod.torch as hvd + self.hvd = hvd self.hvd.init() @@ -761,39 +811,38 @@ def init(self) -> None: self.set_device() def distributed( - self, model: nn.Module, optimizer: Optional[Optimizer] = None, + self, + model: nn.Module, + optimizer: Optional[Optimizer] = None, lr_scheduler: Optional[LRScheduler] = None, - **optim_kwargs + **optim_kwargs, ) -> Tuple[nn.Module, Optimizer, Optional[LRScheduler]]: """Setup model, optimizer and scheduler for distributed.""" if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." 
+ ) model.to(self.device()) # Scale learning rate # https://github.com/horovod/horovod/issues/1653#issuecomment-574764452 lr_scaler = 1 - if optim_kwargs.get('op') == self.hvd.Adasum: + if optim_kwargs.get("op") == self.hvd.Adasum: lr_scaler = self.hvd.local_size() - elif optim_kwargs.get('op') == self.hvd.Average: + elif optim_kwargs.get("op") == self.hvd.Average: lr_scaler = self.hvd.size() for g in optimizer.param_groups: - g['lr'] *= lr_scaler + g["lr"] *= lr_scaler self._broadcast_params(model, optimizer) distOptimizer = self.hvd.DistributedOptimizer( - optimizer, - named_parameters=model.named_parameters(), - **optim_kwargs + optimizer, named_parameters=model.named_parameters(), **optim_kwargs ) return model, distOptimizer, lr_scheduler - def _broadcast_params( - self, model: nn.Module, optimizer: optim.Optimizer - ) -> None: + def _broadcast_params(self, model: nn.Module, optimizer: optim.Optimizer) -> None: """Broadcasts variables from root rank to all other processes. Args: @@ -813,7 +862,8 @@ def global_world_size(self) -> int: """ if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) return self.hvd.size() def local_world_size(self) -> int: @@ -825,7 +875,8 @@ def local_world_size(self) -> int: """ if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) return self.hvd.local_size() def global_rank(self) -> int: @@ -837,7 +888,8 @@ def global_rank(self) -> int: """ if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) return self.hvd.rank() def local_rank(self) -> int: @@ -848,14 +900,16 @@ def local_rank(self) -> int: """ if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) return self.hvd.local_rank() def clean_up(self) -> None: """Shuts Horovod down.""" if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) self.hvd.shutdown() def allgather_obj(self, obj: Any) -> list[Any]: @@ -920,6 +974,10 @@ class NonDistributedStrategy(TorchDistributedStrategy): is_distributed: bool = True is_distributed: bool = False + def __init__(self): + super().__init__() + self.name = "non-distributed" + def init(self) -> None: """If CUDA is available set CUDA device, and do nothing more. @@ -941,20 +999,24 @@ def device(self) -> str: """ if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." + ) if torch.cuda.is_available(): return super().device() return "cpu" def distributed( - self, model: nn.Module, optimizer: Optional[Optimizer] = None, + self, + model: nn.Module, + optimizer: Optional[Optimizer] = None, lr_scheduler: Optional[LRScheduler] = None, - **kwargs + **kwargs, ) -> Tuple[nn.Module, Optimizer, Optional[LRScheduler]]: """Do nothing and return model, optimizer and scheduler.""" if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. 
Use the init method.") + "Strategy has not been initialized. Use the init method." + ) if torch.cuda.is_available(): model = model.cuda() return model, optimizer, lr_scheduler diff --git a/src/itwinai/torch/monitoring/monitoring.py b/src/itwinai/torch/monitoring/monitoring.py new file mode 100644 index 00000000..9857ccb7 --- /dev/null +++ b/src/itwinai/torch/monitoring/monitoring.py @@ -0,0 +1,176 @@ +import functools +import time +from multiprocessing import Manager, Process +from pathlib import Path +from typing import Any, Callable, Dict, List + +import pandas as pd +import pynvml +from pynvml import nvmlDeviceGetHandleByIndex, nvmlInit + +from itwinai.torch.trainer import TorchTrainer + +logging_columns = [ + "sample_idx", + "utilization", + "power", + "local_rank", + "node_idx", + "num_global_gpus", + "strategy", + "probing_interval", +] + + +def probe_gpu_utilization_loop( + node_idx: int, + num_local_gpus: int, + num_global_gpus: int, + strategy_name: str, + log_dict: Any, + stop_flag: Any, + probing_interval: int = 2, + warmup_time: int = 5, +) -> None: + """Logs the GPU utilization across all availble GPUs on a single node. Is meant to + be called by multiprocessing's Process and expects variables to be shared using + a multiprocessing.Manager object. Logs utilization into `log_dict` until + stop_flag.value is set to True. + + Args: + node_idx: The index of the compute node that the function is called by, used + for logging purposes. + num_local_gpus: Number of GPUs on the current compute node. + num_global_gpus: Number of GPUs on all nodes combined. + strategy: Which distributed strategy is being used, e.g. "ddp" or "horovod". + log_dict: Dictionary for storing logging data on. Should be managed by a + multiprocessing.Manager object. + stop_flag: Shared value telling the function when to stop logging. Should be + managed by a multiprocessing.Manager object. + probing_interval: How long to wait between each time a read of the GPU + utilization is done. + warmup_time: How long to wait before logging starts, allowing the training to + properly start before reading. 
+ + """ + + if not set(logging_columns).issubset(set(log_dict.keys())): + missing_columns = set(logging_columns) - set(log_dict.keys()) + raise ValueError( + f"log_dict is missing the following columns: {missing_columns}" + ) + + nvmlInit() + time.sleep(warmup_time) + + sample_idx = 0 + while not stop_flag.value: + for idx in range(num_local_gpus): + handle = nvmlDeviceGetHandleByIndex(idx) + utilization_rates = pynvml.nvmlDeviceGetUtilizationRates(handle) + + gpu_util = utilization_rates.gpu + power = pynvml.nvmlDeviceGetPowerUsage(handle) + power = power / 1000 # mW -> W + + log_dict["sample_idx"].append(sample_idx) + log_dict["utilization"].append(gpu_util) + log_dict["power"].append(power) + log_dict["local_rank"].append(idx) + log_dict["node_idx"].append(node_idx) + log_dict["num_global_gpus"].append(num_global_gpus) + log_dict["strategy"].append(strategy_name) + log_dict["probing_interval"].append(probing_interval) + + sample_idx += 1 + + time.sleep(probing_interval) + + +def measure_gpu_utilization(method: Callable) -> Callable: + """Decorator for measuring GPU utilization and storing it to a .csv file.""" + + def write_logs_to_file(utilization_logs: List[Dict], output_path: Path) -> None: + dataframes = [] + for log in utilization_logs: + if len(log) == 0: + continue + dataframes.append(pd.DataFrame(log)) + + log_df = pd.concat(dataframes) + log_df.to_csv(output_path, index=False) + print(f"Writing GPU energy dataframe to '{output_path}'.") + + @functools.wraps(method) + def measured_method(self: TorchTrainer, *args, **kwargs) -> Any: + gpu_probing_interval = 1 + warmup_time = 5 + + strategy = self.strategy + strategy_name = strategy.name + + local_rank = strategy.local_rank() + global_rank = strategy.global_rank() + num_global_gpus = strategy.global_world_size() + num_local_gpus = strategy.local_world_size() + node_idx = global_rank // num_local_gpus + + output_path = Path( + f"scalability_metrics/gpu_energy_data_{strategy_name}_{num_global_gpus}.csv" + ) + output_path.parent.mkdir(exist_ok=True, parents=True) + + gpu_monitor_process = None + manager = None + stop_flag = None + data = None + + # Starting a child process once per node + if local_rank == 0: + + # Setting up shared variables for the child process + manager = Manager() + data = manager.dict() + for col in logging_columns: + data[col] = manager.list() + stop_flag = manager.Value("i", False) + + gpu_monitor_process = Process( + target=probe_gpu_utilization_loop, + kwargs={ + "node_idx": node_idx, + "num_local_gpus": num_local_gpus, + "num_global_gpus": num_global_gpus, + "strategy_name": strategy_name, + "log_dict": data, + "stop_flag": stop_flag, + "probing_interval": gpu_probing_interval, + "warmup_time": warmup_time, + }, + ) + gpu_monitor_process.start() + + local_utilization_log = {} + try: + result = method(self, *args, **kwargs) + finally: + if local_rank == 0: + stop_flag.value = True + grace_period = 5 # extra time to let process finish gracefully + gpu_monitor_process.join(timeout=gpu_probing_interval + grace_period) + + # Converting the shared log to non-shared log + local_utilization_log = {key: list(data[key]) for key in data.keys()} + manager.shutdown() + + global_utilization_log = strategy.gather_obj(local_utilization_log, dst_rank=0) + if strategy.is_main_worker: + output_dir = Path("scalability_metrics/gpu_energy_data") + output_dir.mkdir(exist_ok=True, parents=True) + output_path = output_dir / f"{strategy_name}_{num_global_gpus}.csv" + + write_logs_to_file(global_utilization_log, output_path) + + 
return result + + return measured_method diff --git a/src/itwinai/torch/monitoring/plotting.py b/src/itwinai/torch/monitoring/plotting.py new file mode 100644 index 00000000..554f5dc4 --- /dev/null +++ b/src/itwinai/torch/monitoring/plotting.py @@ -0,0 +1,140 @@ +from pathlib import Path +from re import Match, Pattern, compile +from typing import Optional, Tuple, Union + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from matplotlib.axes import Axes +from matplotlib.figure import Figure +from scipy.constants import hour as SECONDS_IN_HOUR + +matplotlib.use("Agg") + + +def read_energy_df(pattern: Optional[str], log_dir: Path) -> pd.DataFrame: + """Read files matching the given regex pattern from directory and converting them + into a Pandas DataFrame. If pattern is None, we assume a match on all files. + Expects that the existence of ``log_dir`` is handled before calling this function. + + Args: + pattern: The regex string used to match files. + log_dir: The directory to search for files in. + + Raises: + ValueError: If no matching files are found in the given logging directory. + """ + + pattern_re: Optional[Pattern] = None + if pattern is not None: + pattern_re = compile(pattern) + + # Load and concatenate dataframes + dataframes = [] + for entry in log_dir.iterdir(): + match: Union[bool, Match] = True + if pattern_re is not None: + match = pattern_re.search(str(entry)) + + if not match: + continue + + print(f"Loading data from file: '{entry}' when creating energy DataFrame") + df = pd.read_csv(entry) + dataframes.append(df) + + if len(dataframes) == 0: + if pattern is None: + error_message = f"Unable to find any files in {log_dir.resolve()}!" + else: + error_message = ( + f"No files matched pattern, '{pattern}', in log_dir, " + f"{log_dir.resolve()}!" + ) + raise ValueError(error_message) + + return pd.concat(dataframes) + + +def calculate_aggregated_energy_expenditure( + gpu_utilization_df: pd.DataFrame, +) -> pd.DataFrame: + """Calculates the total energy expenditure in Watt hours for each strategy and + number of GPUs. Expects that the existence of the appropriate DataFrame columns is + handled before calling this function. + + Returns: + pd.DataFrame: A DataFrame containing the total expenditure in Watt hours for + each strategy and number of GPUs, with the columns ``strategy``, + ``num_global_gpus`` and ``total_energy_wh``. + """ + energy_data = [] + + grouped_df = gpu_utilization_df.groupby(["strategy", "num_global_gpus"]) + for (strategy, num_gpus), group in grouped_df: + + if len(group["probing_interval"].unique()) != 1: + raise ValueError( + f"probing_interval must have the same value for each strategy and " + f"number of GPUs, but was heterogeneous for strategy: {strategy} " + f"and number of GPUs: {num_gpus}." + ) + + probing_interval = group["probing_interval"].iloc[0] + total_energy_wh = group["power"].sum() * probing_interval / SECONDS_IN_HOUR + energy_data.append( + { + "strategy": strategy, + "num_global_gpus": num_gpus, + "total_energy_wh": total_energy_wh, + } + ) + return pd.DataFrame(energy_data) + + +def gpu_energy_plot(gpu_utilization_df: pd.DataFrame) -> Tuple[Figure, Axes]: + """Makes an energy bar plot of the GPU utilization dataframe, showing the total + energy expenditure for each strategy and number of GPUs in Watt hours. 
+ """ + required_columns = {"strategy", "power", "num_global_gpus", "probing_interval"} + if not required_columns.issubset(set(gpu_utilization_df.columns)): + missing_columns = set(required_columns) - set(set(gpu_utilization_df.columns)) + raise ValueError( + f"DataFrame is missing the following columns: {missing_columns}" + ) + sns.set_theme() + + energy_df = calculate_aggregated_energy_expenditure(gpu_utilization_df) + + strategies = energy_df["strategy"].unique() + unique_gpu_counts = np.array(energy_df["num_global_gpus"].unique()) + + fig, ax = plt.subplots() + x = np.arange(len(unique_gpu_counts)) + + bar_width = 1 / (len(strategies) + 1) + static_offset = (len(strategies) - 1) / 2 + for strategy_idx, strategy in enumerate(strategies): + dynamic_bar_offset = strategy_idx - static_offset + strategy_data = energy_df[energy_df["strategy"] == strategy] + + # Ensuring the correct spacing of the bars + strategy_num_gpus = len(strategy_data["num_global_gpus"]) + + ax.bar( + x=x[:strategy_num_gpus] + dynamic_bar_offset * bar_width, + height=strategy_data["total_energy_wh"], + width=bar_width, + label=strategy, + ) + + ax.set_xlabel("Num GPUs") + ax.set_ylabel("Energy Consumption (Wh)") + ax.set_title("Energy Consumption by Strategy and Number of GPUs") + ax.set_xticks(x) + ax.set_xticklabels(unique_gpu_counts) + ax.legend(title="Strategy") + + return fig, ax diff --git a/src/itwinai/torch/profiling/communication_plot.py b/src/itwinai/torch/profiling/communication_plot.py index 23dec5a0..285a62da 100644 --- a/src/itwinai/torch/profiling/communication_plot.py +++ b/src/itwinai/torch/profiling/communication_plot.py @@ -1,6 +1,6 @@ from pathlib import Path -from re import Pattern, compile -from typing import Any, List, Tuple +from re import Match, Pattern, compile +from typing import Any, List, Optional, Tuple, Union import matplotlib import matplotlib.pyplot as plt @@ -85,7 +85,7 @@ def create_stacked_plot( fig, ax = plt.subplots() # Creating an offset to "center" around zero - static_offset = len(strategy_labels) / 2 - 0.5 + static_offset = (len(strategy_labels) - 1) / 2 for strategy_idx in range(len(strategy_labels)): dynamic_bar_offset = strategy_idx - static_offset @@ -142,15 +142,21 @@ def create_stacked_plot( return fig, ax -def create_combined_comm_overhead_df(logs_dir: Path, pattern: str) -> pd.DataFrame: +def create_combined_comm_overhead_df( + log_dir: Path, pattern: Optional[str] +) -> pd.DataFrame: """Reads and combines all files in a folder that matches the given regex pattern - into a single DataFrame. The files must be formatted as csv files. + into a single DataFrame. The files must be formatted as csv files. If pattern is + None, we assume a match on all files. Raises: ValueError: If not all expected columns are found in the stored DataFrame. ValueError: If no matching files are found in the given logging directory. 
""" - re_pattern: Pattern = compile(pattern) + re_pattern: Optional[Pattern] = None + if pattern is not None: + re_pattern = compile(pattern) + dataframes = [] expected_columns = { "strategy", @@ -159,8 +165,11 @@ def create_combined_comm_overhead_df(logs_dir: Path, pattern: str) -> pd.DataFra "name", "self_cuda_time_total", } - for entry in logs_dir.iterdir(): - match = re_pattern.search(str(entry)) + for entry in log_dir.iterdir(): + match: Union[bool, Match] = True + if re_pattern is not None: + match = re_pattern.search(str(entry)) + if not match: continue @@ -168,15 +177,22 @@ def create_combined_comm_overhead_df(logs_dir: Path, pattern: str) -> pd.DataFra if not expected_columns.issubset(df.columns): missing_columns = expected_columns - set(df.columns) raise ValueError( - f"Invalid data format! File at '{match.string}' doesn't contain all" + f"Invalid data format! File at '{str(entry)}' doesn't contain all" f" necessary columns. \nMissing columns: {missing_columns}" ) dataframes.append(df) + if len(dataframes) == 0: - raise ValueError( - f"No matching files found in '{logs_dir.resolve()}' for pattern '{pattern}'" - ) + if pattern is None: + error_message = f"Unable to find any files in {log_dir.resolve()}!" + else: + error_message = ( + f"No files matched pattern, '{pattern}', in log_dir, " + f"{log_dir.resolve()}!" + ) + raise ValueError(error_message) + return pd.concat(dataframes) diff --git a/src/itwinai/torch/profiling/profiler.py b/src/itwinai/torch/profiling/profiler.py index 78c740e7..7ff43665 100644 --- a/src/itwinai/torch/profiling/profiler.py +++ b/src/itwinai/torch/profiling/profiler.py @@ -2,18 +2,12 @@ import functools from pathlib import Path -from typing import Any, Callable, Iterable +from typing import Any, Callable, Iterable, Tuple import matplotlib import pandas as pd from torch.profiler import ProfilerActivity, profile, schedule -from itwinai.torch.distributed import ( - DeepSpeedStrategy, - HorovodStrategy, - NonDistributedStrategy, - TorchDDPStrategy, -) from itwinai.torch.trainer import TorchTrainer # Doing this because otherwise I get an error about X11 Forwarding which I believe @@ -44,17 +38,59 @@ def gather_profiling_data(key_averages: Iterable) -> pd.DataFrame: ) return pd.DataFrame(profiling_data) + def adjust_wait_and_warmup_epochs( + training_epochs: int, wait_epochs: int, warmup_epochs: int + ) -> Tuple[int, int, int]: + """Validates if the given wait and warmup epochs are compatible and if not, + adjusts them so they fit. The largest one is iteratively decreased until + a compatible value is reached. + + Returns: + int: The resulting number of epochs for doing active profiling + int: The resulting number of wait epochs, possibly adjusted + int: The resulting number of warmup epochs, possibly adjusted + """ + active_epochs = training_epochs - wait_epochs - warmup_epochs + if active_epochs > 0: + return active_epochs, wait_epochs, warmup_epochs + + # This can probably be done with a simple math expression, but this was + # simpler to implement and won't really cause much overhead anyway... + while active_epochs <= 0: + if wait_epochs > warmup_epochs: + wait_epochs -= 1 + else: + warmup_epochs -= 1 + active_epochs = training_epochs - wait_epochs - warmup_epochs + + if wait_epochs < 0 or warmup_epochs < 0: + raise ValueError( + f"Unable to adjust wait and warmup epochs to accomodate the" + f"given number of training epochs. 
Was given the following values: " + f"Training epochs: {training_epochs}, wait epochs: {wait_epochs}" + f", warmup epochs: {warmup_epochs}" + ) + print( + f"Warning: adjusted the given wait and warmup epochs for the profiler - " + f"wait epochs: {wait_epochs}, warmup epochs: {warmup_epochs}." + ) + return active_epochs, wait_epochs, warmup_epochs + @functools.wraps(method) def profiled_method(self: TorchTrainer, *args, **kwargs) -> Any: + active_epochs, wait_epochs, warmup_epochs = adjust_wait_and_warmup_epochs( + training_epochs=self.epochs, + wait_epochs=self.profiling_wait_epochs, + warmup_epochs=self.profiling_warmup_epochs, + ) profiler = profile( activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], with_modules=True, schedule=schedule( - # skip_first=1 - wait=1, - warmup=2, - active=100, + wait=wait_epochs, + warmup=warmup_epochs, + active=active_epochs, ), ) profiler.start() @@ -66,16 +102,7 @@ def profiled_method(self: TorchTrainer, *args, **kwargs) -> Any: profiler.stop() strategy = self.strategy - if isinstance(strategy, NonDistributedStrategy): - strategy_str = "non-dist" - elif isinstance(strategy, TorchDDPStrategy): - strategy_str = "ddp" - elif isinstance(strategy, DeepSpeedStrategy): - strategy_str = "deepspeed" - elif isinstance(strategy, HorovodStrategy): - strategy_str = "horovod" - else: - strategy_str = "unk" + strategy_name = strategy.name global_rank = strategy.global_rank() num_gpus_global = strategy.global_world_size() @@ -83,19 +110,18 @@ def profiled_method(self: TorchTrainer, *args, **kwargs) -> Any: # Extracting and storing the profiling data key_averages = profiler.key_averages() profiling_dataframe = gather_profiling_data(key_averages=key_averages) - profiling_dataframe["strategy"] = strategy_str + profiling_dataframe["strategy"] = strategy_name profiling_dataframe["num_gpus"] = num_gpus_global profiling_dataframe["global_rank"] = global_rank - profiling_log_dir = Path("profiling_logs") + profiling_log_dir = Path("scalability_metrics/communication_data") profiling_log_dir.mkdir(parents=True, exist_ok=True) - filename = f"profile_{strategy_str}_{num_gpus_global}_{global_rank}.csv" + filename = f"{strategy_name}_{num_gpus_global}_{global_rank}.csv" output_path = profiling_log_dir / filename - print(f"Writing profiling dataframe to {output_path}") + print(f"Writing communication profiling dataframe to '{output_path}'.") profiling_dataframe.to_csv(output_path) - strategy.clean_up() return result diff --git a/src/itwinai/torch/trainer.py b/src/itwinai/torch/trainer.py index fa73289a..d8cbdf7e 100644 --- a/src/itwinai/torch/trainer.py +++ b/src/itwinai/torch/trainer.py @@ -4,7 +4,6 @@ import sys from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union -import horovod.torch as hvd import lightning as L import matplotlib.pyplot as plt import numpy as np @@ -114,7 +113,9 @@ def __init__( metrics: Optional[Dict[str, Metric]] = None, checkpoints_location: str = "checkpoints", checkpoint_every: Optional[int] = None, - name: Optional[str] = None + name: Optional[str] = None, + profiling_wait_epochs: int = 1, + profiling_warmup_epochs: int = 2 ) -> None: super().__init__(name) self.save_parameters(**self.locals2params(locals())) @@ -137,6 +138,8 @@ def __init__( os.makedirs(self.checkpoints_location, exist_ok=True) self.checkpoint_every = checkpoint_every self.profiler = None + self.profiling_wait_epochs = profiling_wait_epochs + self.profiling_warmup_epochs = profiling_warmup_epochs @property def strategy(self) -> TorchDistributedStrategy: @@ 
-158,17 +161,17 @@ def device(self) -> str: def _detect_strategy(self, strategy: str) -> TorchDistributedStrategy: if strategy is None or not distributed_resources_available(): print("WARNING: falling back to non-distributed strategy.") - dist_str = NonDistributedStrategy() + strategy_obj = NonDistributedStrategy() elif strategy == 'ddp': - dist_str = TorchDDPStrategy(backend='nccl') + strategy_obj = TorchDDPStrategy(backend='nccl') elif strategy == 'horovod': - dist_str = HorovodStrategy() + strategy_obj = HorovodStrategy() elif strategy == 'deepspeed': - dist_str = DeepSpeedStrategy(backend='nccl') + strategy_obj = DeepSpeedStrategy(backend='nccl') else: raise NotImplementedError( f"Strategy '{strategy}' is not recognized/implemented.") - return dist_str + return strategy_obj def _init_distributed_strategy(self) -> None: if not self.strategy.is_initialized: @@ -222,6 +225,31 @@ def _loss_from_config(self) -> None: "create_model_loss_optimizer method for more flexibility." ) + def get_default_distributed_kwargs(self) -> Dict: + """Gives the default kwargs for the trainer's strategy's distributed() method.""" + + if isinstance(self.strategy, DeepSpeedStrategy): + # Batch size definition is not optional for DeepSpeedStrategy! + distribute_kwargs = dict( + config_params=dict( + train_micro_batch_size_per_gpu=self.config.batch_size + ) + ) + elif isinstance(self.strategy, HorovodStrategy): + import horovod as hvd + distribute_kwargs = dict( + compression=( + hvd.Compression.fp16 if self.config.fp16_allreduce + else hvd.Compression.none + ), + op=hvd.Adasum if self.config.use_adasum else hvd.Average, + gradient_predivide_factor=self.config.gradient_predivide_factor + ) + else: + distribute_kwargs = {} + + return distribute_kwargs + def create_model_loss_optimizer(self) -> None: """ Instantiate a torch model, loss, optimizer, and LR scheduler using the @@ -248,26 +276,7 @@ def create_model_loss_optimizer(self) -> None: self._loss_from_config() # IMPORTANT: model, optimizer, and scheduler need to be distributed - - # First, define strategy-wise optional configurations - if isinstance(self.strategy, DeepSpeedStrategy): - # Batch size definition is not optional for DeepSpeedStrategy! 
- distribute_kwargs = dict( - config_params=dict( - train_micro_batch_size_per_gpu=self.config.batch_size - ) - ) - elif isinstance(self.strategy, HorovodStrategy): - distribute_kwargs = dict( - compression=( - hvd.Compression.fp16 if self.config.fp16_allreduce - else hvd.Compression.none - ), - op=hvd.Adasum if self.config.use_adasum else hvd.Average, - gradient_predivide_factor=self.config.gradient_predivide_factor - ) - else: - distribute_kwargs = {} + distribute_kwargs = self.get_default_distributed_kwargs() # Distributed model, optimizer, and scheduler ( @@ -375,7 +384,7 @@ def execute( if self.logger: self.logger.destroy_logger_context() - # self.strategy.clean_up() + self.strategy.clean_up() return train_dataset, validation_dataset, test_dataset, self.model def _set_epoch_dataloaders(self, epoch: int): @@ -527,13 +536,12 @@ def train(self): # Checkpointing current best model worker_val_losses = self.strategy.gather(val_loss, dst_rank=0) if self.strategy.is_main_worker: - avg_loss = torch.mean( - torch.stack(worker_val_losses) - ).detach().cpu() - if avg_loss < best_loss: + avg_loss = torch.mean(torch.stack(worker_val_losses)).detach().cpu() + if avg_loss < best_loss and self.checkpoint_every is not None: ckpt_name = "best_model.pth" self.save_checkpoint( - name=ckpt_name, epoch=epoch, loss=avg_loss) + name=ckpt_name, epoch=epoch, loss=avg_loss + ) best_loss = avg_loss if self.test_every and epoch_n % self.test_every == 0: diff --git a/use-cases/eurac/config.yaml b/use-cases/eurac/config.yaml index 8912e898..47b7101f 100644 --- a/use-cases/eurac/config.yaml +++ b/use-cases/eurac/config.yaml @@ -6,7 +6,7 @@ tmp_stats: /p/scratch/intertwin/datasets/eurac/stats experiment: "drought use case lstm" run_name: "alps_test" -epochs: 5 +epochs: 3 random_seed: 1010 lr: 0.001 batch_size: 256 @@ -57,6 +57,8 @@ rnn_training_pipeline: strategy: ${strategy} epochs: ${epochs} random_seed: ${random_seed} + profiling_wait_epochs: 0 + profiling_warmup_epochs: 0 logger: class_path: itwinai.loggers.LoggersCollection init_args: diff --git a/use-cases/eurac/plots/comm_plot.png b/use-cases/eurac/plots/comm_plot.png deleted file mode 100644 index b406b1a0..00000000 Binary files a/use-cases/eurac/plots/comm_plot.png and /dev/null differ diff --git a/use-cases/eurac/plots/communication_plot.png b/use-cases/eurac/plots/communication_plot.png new file mode 100644 index 00000000..c40459d0 Binary files /dev/null and b/use-cases/eurac/plots/communication_plot.png differ diff --git a/use-cases/eurac/plots/gpu_energy_plot.png b/use-cases/eurac/plots/gpu_energy_plot.png new file mode 100644 index 00000000..618ec8c8 Binary files /dev/null and b/use-cases/eurac/plots/gpu_energy_plot.png differ diff --git a/use-cases/eurac/trainer.py b/use-cases/eurac/trainer.py index 628ac66d..770c8259 100644 --- a/use-cases/eurac/trainer.py +++ b/use-cases/eurac/trainer.py @@ -28,6 +28,7 @@ from itwinai.torch.trainer import TorchTrainer from itwinai.torch.type import Metric from itwinai.torch.profiling.profiler import profile_torch_trainer +from itwinai.torch.monitoring.monitoring import measure_gpu_utilization class RNNDistributedTrainer(TorchTrainer): @@ -92,7 +93,8 @@ def __init__( self.save_parameters(**self.locals2params(locals())) @suppress_workers_print - @profile_torch_trainer + # @profile_torch_trainer + # @measure_gpu_utilization def execute( self, train_dataset: Dataset, @@ -146,6 +148,8 @@ def set_epoch(self, epoch: int): self.train_loader.sampler.set_epoch(epoch) self.val_loader.sampler.set_epoch(epoch) + 
@profile_torch_trainer
+    @measure_gpu_utilization
     def train(self):
         """Override version of hython to support distributed strategy."""
         # Tracking epoch times for scaling test
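
Usage sketch (not part of the patch, a minimal illustration under assumed paths): the `@measure_gpu_utilization` decorator applied to `train()` above writes per-node CSV files under `scalability_metrics/gpu_energy_data/`, and the helpers added in `src/itwinai/torch/monitoring/plotting.py` aggregate them into the bar plot that the new `generate_gpu_energy_plot` CLI command saves. The snippet below drives those helpers directly from Python; the regex and output path simply mirror the defaults introduced in this patch.

from pathlib import Path

from itwinai.torch.monitoring.plotting import gpu_energy_plot, read_energy_df

# Directory holding the CSVs produced on each node by @measure_gpu_utilization.
log_dir = Path("scalability_metrics/gpu_energy_data")

# pattern=None would match every file in log_dir; here only CSV files are kept.
gpu_utilization_df = read_energy_df(pattern=r".*\.csv$", log_dir=log_dir)

# Aggregates Watt-hour totals per strategy and GPU count, then draws the bar plot.
fig, ax = gpu_energy_plot(gpu_utilization_df=gpu_utilization_df)

output_path = Path("plots/gpu_energy_plot.png")
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path)
print(f"Saved GPU energy plot at '{output_path.resolve()}'.")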