
Commit

replace cuda with device
ShawnXuan committed Aug 21, 2024
1 parent 1aaccff commit c434086
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions libai/utils/distributed.py
@@ -228,13 +228,7 @@ def device_type(self):
         return self._device_type
 
     def set_device_type(self, device_type):
-        assert device_type in [
-            "cpu",
-            "cuda",
-            "npu",
-            "xpu",
-            "mlu",
-        ], f"not supported for {device_type}"
+        #assert device in ["cpu", "cuda"], f"not supported for device:{device}"
         self._device_type = device_type
 
     def get_layer_ranks(self, layer_idx):
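With the whitelist assert commented out, set_device_type now accepts any device string the underlying framework understands. A minimal sketch of the effect, assuming get_dist_util is this module's accessor for the distributed-util singleton and "npu" is just an example backend:

    from libai.utils import distributed as dist

    dist_util = dist.get_dist_util()   # assumed accessor for the dist-util singleton
    dist_util.set_device_type("npu")   # no longer limited to cpu/cuda/npu/xpu/mlu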
@@ -441,10 +435,10 @@ def convert_to_distributed_default_setting(t):
     return t.to_global(placement=flow.placement(device_type, ranks=t.placement.ranks))
 
 
-def ttol(tensor, pure_local=False, ranks=None):
+def ttol(tensor, pure_local=False, device="cuda", ranks=None):
     """Global tensor to local tensor."""
     if tensor.is_global:
-        placement = tensor.placement if not ranks else flow.placement("cuda", ranks)
+        placement = tensor.placement if not ranks else flow.placement(device, ranks)
         if pure_local:
             tensor = tensor.to_global(placement=placement).to_local()
         else:
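A hedged usage sketch of the new keyword: the default stays "cuda", so existing callers keep their behavior, while callers on another backend can pass a matching device (global_tensor below is an assumed OneFlow global tensor):

    from libai.utils.distributed import ttol

    # Unchanged path: ranks given, placement rebuilt on the default "cuda".
    local_t = ttol(global_tensor, pure_local=True, ranks=[0])

    # New path: rebuild the placement on another backend, e.g. NPU.
    local_t_npu = ttol(global_tensor, pure_local=True, device="npu", ranks=[0])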
@@ -465,7 +459,7 @@ def tton(tensor, local_only=False, ranks=None):
 
 def tensor_to_rank0(tensor, device="cuda", to_local=False):
     """Global tensor to rank0."""
-    assert device in ["cpu", "cuda"], f"not supported for device:{device}"
+    #assert device in ["cpu", "cuda"], f"not supported for device:{device}"
     if tensor.is_global:
         # Consider if it's 2d mesh, ranks should be [[0]] instead of [0]
         placement = flow.placement(device, ranks=[0] if tensor.placement.ranks.ndim == 1 else [[0]])
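The same relaxation applies here: with the assert commented out, tensor_to_rank0 no longer rejects devices other than "cpu" and "cuda". A minimal sketch (global_tensor is again an assumed global tensor):

    from libai.utils.distributed import tensor_to_rank0

    # Previously this raised an AssertionError for non-cpu/cuda devices;
    # now the rank-0 placement is simply rebuilt on the requested backend.
    t0 = tensor_to_rank0(global_tensor, device="npu", to_local=True)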
