diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py index 4a2b12593..dbc6ab60b 100644 --- a/optimum/habana/transformers/training_args.py +++ b/optimum/habana/transformers/training_args.py @@ -101,6 +101,8 @@ class GaudiTrainingArguments(TrainingArguments): Whether to use compiled autograd for training. Currently only for summarization models. compile_dynamic (`bool|None`, *optional*, defaults to `None`): Set value of 'dynamic' parameter for torch.compile. + cache_size_limit (`int`, *optional*, defaults to `None`): + Set value of 'cache_size_limit' parameter for torch._dynamo.config. disable_tensor_cache_hpu_graphs (`bool`, *optional*, defaults to `False`): Whether to disable tensor cache when using hpu graphs. If True, tensors won't be cached in hpu graph and memory can be saved. max_hpu_graphs (`int`, *optional*): @@ -170,6 +172,11 @@ class GaudiTrainingArguments(TrainingArguments): metadata={"help": ("Set value of 'dynamic' parameter for torch.compile.")}, ) + cache_size_limit: Optional[int] = field( + default=None, + metadata={"help": "Set value of 'cache_size_limit' parameter for torch._dynamo.config."}, + ) + disable_tensor_cache_hpu_graphs: Optional[bool] = field( default=False, metadata={"help": "Whether to use a tensor cache for hpu graphs."}, @@ -860,6 +867,9 @@ def _setup_devices(self) -> "torch.device": if self.sdp_on_bf16: torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) + if self.torch_compile and self.cache_size_limit is not None: + torch._dynamo.config.cache_size_limit = self.cache_size_limit + logger.info("PyTorch: setting up devices") if not is_accelerate_available(): raise ImportError(