Skip to content

Commit

Permalink
Fix is_torch_tpu_available in ORT Trainer (#2028)
Browse files Browse the repository at this point in the history
  • Loading branch information
regisss authored Sep 18, 2024
1 parent bf1befd commit 2fb5ea5
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions optimum/onnxruntime/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,14 @@
from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled

if check_if_transformers_greater("4.39"):
from transformers.utils import is_torch_xla_available
from transformers.utils import is_torch_xla_available as is_torch_tpu_xla_available

if is_torch_xla_available():
if is_torch_tpu_xla_available():
import torch_xla.core.xla_model as xm
else:
from transformers.utils import is_torch_tpu_available
from transformers.utils import is_torch_tpu_available as is_torch_tpu_xla_available

if is_torch_tpu_available(check_device=False):
if is_torch_tpu_xla_available(check_device=False):
import torch_xla.core.xla_model as xm

if TYPE_CHECKING:
Expand Down Expand Up @@ -735,7 +735,7 @@ def get_dataloader_sampler(dataloader):

if (
args.logging_nan_inf_filter
and not is_torch_tpu_available()
and not is_torch_tpu_xla_available()
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
):
# if loss is nan or inf simply add the average of previous logged losses
Expand Down

0 comments on commit 2fb5ea5

Please sign in to comment.