diff --git a/apex/transformer/utils.py b/apex/transformer/utils.py index 0991bd862..39d5d7668 100644 --- a/apex/transformer/utils.py +++ b/apex/transformer/utils.py @@ -8,7 +8,8 @@ # The following 4 lines are for backward comparability with # older PyTorch. if "all_gather_into_tensor" not in dir(torch.distributed): - assert torch.distributed.is_available(), "PyTorch Distributed is Not available or Disabled." + if not torch.distributed.is_available(): + raise RuntimeError("PyTorch Distributed is Not available or Disabled.") torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base def ensure_divisibility(numerator, denominator):