Skip to content

Commit

Permalink
using tensor device in eval distributed (#557)
Browse files Browse the repository at this point in the history
* using tensor device in eval distributed

* tensor_to_rank0 device
Loading branch information
fpzh2011 authored Oct 28, 2024
1 parent 13f1b12 commit 9dcbe3b
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
4 changes: 2 additions & 2 deletions libai/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,12 @@ def inference_on_dataset(

# get valid sample
valid_data = {
key: dist.tensor_to_rank0(value, to_local=True)[:valid_sample]
key: dist.tensor_to_rank0(value, device=value.placement.type, to_local=True)[:valid_sample]
for key, value in data.items()
}
valid_outputs = {}
for key, value in outputs.items():
value = dist.tensor_to_rank0(value, to_local=True)
value = dist.tensor_to_rank0(value, device=value.placement.type, to_local=True)
if value.ndim > 1:
valid_outputs[key] = value[:valid_sample] # Slice if it's batched output
else:
Expand Down
1 change: 0 additions & 1 deletion libai/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,6 @@ def tton(tensor, local_only=False, ranks=None):

def tensor_to_rank0(tensor, device="cuda", to_local=False):
"""Global tensor to rank0."""
# assert device in ["cpu", "cuda"], f"not supported for device:{device}"
if tensor.is_global:
# Consider if it's 2d mesh, ranks should be [[0]] instead of [0]
placement = flow.placement(device, ranks=[0] if tensor.placement.ranks.ndim == 1 else [[0]])
Expand Down

0 comments on commit 9dcbe3b

Please sign in to comment.