In PyTorch 2.5, inline_inbuilt_nn_modules was enabled by default, leading to performance degradation in certain models (e.g. ALBERT_XXL). A configuration option has been introduced to turn off this behavior.
chaojun-zhang committed Dec 17, 2024
1 parent dd54fda commit 4c1b7f8
Showing 1 changed file with 10 additions and 0 deletions.
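For context, the Dynamo setting referenced in the commit message can also be toggled directly in PyTorch; a minimal sketch of the underlying switch (not part of this change) might look like this:

```python
import torch._dynamo

# Turn off inlining of built-in nn.Modules to restore the pre-2.5 Dynamo
# behavior; this is the switch the new training argument exposes.
torch._dynamo.config.inline_inbuilt_nn_modules = False
```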
10 changes: 10 additions & 0 deletions optimum/habana/transformers/training_args.py
@@ -101,6 +101,8 @@ class GaudiTrainingArguments(TrainingArguments):
Whether to use compiled autograd for training. Currently only for summarization models.
compile_dynamic (`bool|None`, *optional*, defaults to `None`):
Set value of 'dynamic' parameter for torch.compile.
inline_inbuilt_nn_modules (`bool|None`, *optional*, defaults to `None`):
Set value of 'inline_inbuilt_nn_modules' parameter for torch._dynamo.config.
disable_tensor_cache_hpu_graphs (`bool`, *optional*, defaults to `False`):
Whether to disable tensor cache when using hpu graphs. If True, tensors won't be cached in hpu graph and memory can be saved.
max_hpu_graphs (`int`, *optional*):
@@ -170,6 +172,11 @@ class GaudiTrainingArguments(TrainingArguments):
metadata={"help": ("Set value of 'dynamic' parameter for torch.compile.")},
)

inline_inbuilt_nn_modules: Optional[bool] = field(
default=None,
metadata={"help": ("Set value of 'inline_inbuilt_nn_modules' parameter for torch._dynamo.config.")},
)

disable_tensor_cache_hpu_graphs: Optional[bool] = field(
default=False,
metadata={"help": "Whether to use a tensor cache for hpu graphs."},
@@ -860,6 +867,9 @@ def _setup_devices(self) -> "torch.device":
if self.sdp_on_bf16:
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)

if self.inline_inbuilt_nn_modules is not None:
torch._dynamo.config.inline_inbuilt_nn_modules = self.inline_inbuilt_nn_modules

logger.info("PyTorch: setting up devices")
if not is_accelerate_available():
raise ImportError(
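As a usage sketch (the output directory and the other arguments shown are illustrative, and it assumes `GaudiTrainingArguments` is importable from `optimum.habana`), the new option can be passed like any other training argument:

```python
from optimum.habana import GaudiTrainingArguments

# Illustrative configuration; only inline_inbuilt_nn_modules is the new field.
# Setting it to False works around the PyTorch 2.5 regression observed on
# models such as ALBERT_XXL.
training_args = GaudiTrainingArguments(
    output_dir="./results",            # illustrative path
    use_habana=True,
    use_lazy_mode=False,               # torch.compile path, illustrative
    torch_compile_backend="hpu_backend",
    inline_inbuilt_nn_modules=False,   # disable Dynamo's nn.Module inlining
)
```

When the argument is left at its default of `None`, `torch._dynamo.config` is not touched and PyTorch's own default applies.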
