diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index a0694c41c5..47946d3037 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -105,8 +105,7 @@ def get_trainer( local_rank = os.environ.get("LOCAL_RANK") if local_rank is not None: local_rank = int(local_rank) - assert dist.is_nccl_available() - dist.init_process_group(backend="nccl") + dist.init_process_group(backend="cuda:nccl,cpu:gloo") def prepare_trainer_input_single( model_params_single, data_dict_single, rank=0, seed=None