Skip to content

Commit

Permalink
Avoid importing apex transformer automatically and make error message…
Browse files Browse the repository at this point in the history
…s more clear when apex.transformer is explicitly called on unsupported platform
  • Loading branch information
nWEIdia committed May 8, 2024
1 parent a7de60e commit 64aea56
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 1 deletion.
1 change: 0 additions & 1 deletion apex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
# load time) the error message is timely and visible.
from . import optimizers
from . import normalization
from . import transformer


# Logging utilities for apex.transformer module
Expand Down
1 change: 1 addition & 0 deletions apex/transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# The following 4 lines are for backward comparability with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
assert torch.distributed.is_available(), "PyTorch Distributed is Not available or Disabled."
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base

def ensure_divisibility(numerator, denominator):
Expand Down

0 comments on commit 64aea56

Please sign in to comment.