Skip to content

Commit

Permalink
tensor_to_rank0 device
Browse files — Browse the repository at this point in the history
  • Branch information (loading…)
fpzh2011 committed Oct 26, 2024
1 parent 7107548 commit 0d499d8
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions libai/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,12 @@ def inference_on_dataset(

# get valid sample
valid_data = {
key: dist.tensor_to_rank0(value, to_local=True)[:valid_sample]
key: dist.tensor_to_rank0(value, device=value.placement.type, to_local=True)[:valid_sample]
for key, value in data.items()
}
valid_outputs = {}
for key, value in outputs.items():
value = dist.tensor_to_rank0(value, to_local=True)
value = dist.tensor_to_rank0(value, device=value.placement.type, to_local=True)
if value.ndim > 1:
valid_outputs[key] = value[:valid_sample] # Slice if it's batched output
else:
Expand Down
2 changes: 1 addition & 1 deletion libai/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def tensor_to_rank0(tensor, device="cuda", to_local=False):
"""Global tensor to rank0."""
if tensor.is_global:
# Consider if it's 2d mesh, ranks should be [[0]] instead of [0]
placement = flow.placement(tensor.placement.type, ranks=[0] if tensor.placement.ranks.ndim == 1 else [[0]])
placement = flow.placement(device, ranks=[0] if tensor.placement.ranks.ndim == 1 else [[0]])
tensor = tensor.to_global(
sbp=get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), placement=placement
)
Expand Down

0 comments on commit 0d499d8

Please sign in to comment.