From 0d499d87c585639920e0ce0bf5e39f4f60c267c3 Mon Sep 17 00:00:00 2001 From: Jianhua Zheng Date: Sat, 26 Oct 2024 16:53:12 +0800 Subject: [PATCH] tensor_to_rank0 device --- libai/evaluation/evaluator.py | 4 ++-- libai/utils/distributed.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/libai/evaluation/evaluator.py b/libai/evaluation/evaluator.py index 1414cdaa0..39c699c7a 100644 --- a/libai/evaluation/evaluator.py +++ b/libai/evaluation/evaluator.py @@ -203,12 +203,12 @@ def inference_on_dataset( # get valid sample valid_data = { - key: dist.tensor_to_rank0(value, to_local=True)[:valid_sample] + key: dist.tensor_to_rank0(value, device=value.placement.type, to_local=True)[:valid_sample] for key, value in data.items() } valid_outputs = {} for key, value in outputs.items(): - value = dist.tensor_to_rank0(value, to_local=True) + value = dist.tensor_to_rank0(value, device=value.placement.type, to_local=True) if value.ndim > 1: valid_outputs[key] = value[:valid_sample] # Slice if it's batched output else: diff --git a/libai/utils/distributed.py b/libai/utils/distributed.py index ea1b3e189..a76372956 100644 --- a/libai/utils/distributed.py +++ b/libai/utils/distributed.py @@ -473,7 +473,7 @@ def tensor_to_rank0(tensor, device="cuda", to_local=False): """Global tensor to rank0.""" if tensor.is_global: # Consider if it's 2d mesh, ranks should be [[0]] instead of [0] - placement = flow.placement(tensor.placement.type, ranks=[0] if tensor.placement.ranks.ndim == 1 else [[0]]) + placement = flow.placement(device, ranks=[0] if tensor.placement.ranks.ndim == 1 else [[0]]) tensor = tensor.to_global( sbp=get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), placement=placement )