From c434086aca9544b9fbfe9344a86a2517371df33b Mon Sep 17 00:00:00 2001
From: ShawnXuan
Date: Wed, 21 Aug 2024 14:50:36 +0800
Subject: [PATCH] replace cuda to device

---
 libai/utils/distributed.py | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/libai/utils/distributed.py b/libai/utils/distributed.py
index e36c3c82c..f1f8cd59b 100644
--- a/libai/utils/distributed.py
+++ b/libai/utils/distributed.py
@@ -228,13 +228,7 @@ def device_type(self):
         return self._device_type
 
     def set_device_type(self, device_type):
-        assert device_type in [
-            "cpu",
-            "cuda",
-            "npu",
-            "xpu",
-            "mlu",
-        ], f"not supported for {device_type}"
+        #assert device in ["cpu", "cuda"], f"not supported for device:{device}"
         self._device_type = device_type
 
     def get_layer_ranks(self, layer_idx):
@@ -441,10 +435,10 @@ def convert_to_distributed_default_setting(t):
     return t.to_global(placement=flow.placement(device_type, ranks=t.placement.ranks))
 
 
-def ttol(tensor, pure_local=False, ranks=None):
+def ttol(tensor, pure_local=False, device="cuda", ranks=None):
     """Global tensor to local tensor."""
     if tensor.is_global:
-        placement = tensor.placement if not ranks else flow.placement("cuda", ranks)
+        placement = tensor.placement if not ranks else flow.placement(device, ranks)
         if pure_local:
             tensor = tensor.to_global(placement=placement).to_local()
         else:
@@ -465,7 +459,7 @@ def tton(tensor, local_only=False, ranks=None):
 
 def tensor_to_rank0(tensor, device="cuda", to_local=False):
     """Global tensor to rank0."""
-    assert device in ["cpu", "cuda"], f"not supported for device:{device}"
+    #assert device in ["cpu", "cuda"], f"not supported for device:{device}"
    if tensor.is_global:
         # Consider if it's 2d mesh, ranks should be [[0]] instead of [0]
         placement = flow.placement(device, ranks=[0] if tensor.placement.ranks.ndim == 1 else [[0]])
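
Usage note (not part of the patch): a minimal sketch of how the new `device` argument to `ttol` could be exercised after this change. It assumes a OneFlow install with the named backend available and `libai.utils.distributed` importable; the tensor shape, the rank list, and the "npu" device string are illustrative only.

import oneflow as flow

from libai.utils import distributed as dist

# Build a small global tensor placed on rank 0 of the current accelerator
# (assumes a GPU-enabled OneFlow build; "cuda" here is illustrative).
t = flow.ones(
    (2, 2),
    placement=flow.placement("cuda", ranks=[0]),
    sbp=flow.sbp.broadcast,
)

# Before this patch: passing ranks always re-placed the tensor onto "cuda".
local_default = dist.ttol(t, ranks=[0])

# With this patch: the backend used for the re-placement can be chosen,
# e.g. "npu" or "xpu" on non-NVIDIA hardware (requires that backend build).
local_npu = dist.ttol(t, device="npu", ranks=[0])

# tensor_to_rank0 keeps its existing device argument, now without the assert.
on_rank0 = dist.tensor_to_rank0(t, device="cuda", to_local=True)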