Do not aggregate the losses since last log step (#779)
Fixes #763
carmocca authored Jan 16, 2025
1 parent 82f7387 commit 2fa6d83
Showing 2 changed files with 19 additions and 17 deletions.
18 changes: 12 additions & 6 deletions torchtitan/utils.py

@@ -34,14 +34,20 @@ def get_device_info():
 device_type, device_module = get_device_info()


-def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
-    tensor = torch.tensor(x).to(device_type)
-    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh).item()
+def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> float:
+    if isinstance(x, DTensor):
+        # functional collectives do not support DTensor inputs
+        x = x.full_tensor()
+    assert x.numel() == 1  # required by `.item()`
+    return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()


-def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float:
-    tensor = torch.tensor(x).to(device_type)
-    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh).item()
+def dist_max(x: torch.Tensor, mesh: DeviceMesh) -> float:
+    return dist_reduce(x, reduceOp=c10d.ReduceOp.MAX.name, mesh=mesh)
+
+
+def dist_mean(x: torch.Tensor, mesh: DeviceMesh) -> float:
+    return dist_reduce(x, reduceOp=c10d.ReduceOp.AVG.name, mesh=mesh)


 def _warn_overwrite_env(env, val):
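A note on usage (not part of the commit): the sketch below shows how the reworked helpers are meant to be called. Everything outside the two utils calls is an assumption for illustration — the nccl backend, the flat 1-D mesh, and the made-up per-rank loss value; train.py instead passes its world_mesh["dp_cp"] slice.

    # Minimal sketch, intended to be launched with torchrun (one process per GPU).
    import os

    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    from torchtitan import utils

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # A flat 1-D mesh over all ranks; train.py builds a richer mesh and passes
    # its world_mesh["dp_cp"] slice instead.
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))

    # Per-rank scalar loss (made-up value, offset by rank so AVG and MAX differ).
    loss = torch.tensor(2.5 + dist.get_rank(), device="cuda")

    # Both helpers route through the new dist_reduce: a DTensor input is
    # materialized with .full_tensor(), the single-element tensor is
    # all-reduced, and .item() returns a Python float on every rank.
    global_avg_loss = utils.dist_mean(loss, mesh)  # AVG across ranks
    global_max_loss = utils.dist_max(loss, mesh)   # MAX across ranks
    print(global_avg_loss, global_max_loss)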
18 changes: 7 additions & 11 deletions train.py

@@ -231,7 +231,6 @@ def loss_fn(pred, labels):
     )

     # variables used to keep info for metrics logging
-    losses_since_last_log = []
     ntokens_since_last_log = 0
     data_loading_times = []
     time_last_log = time.perf_counter()
@@ -295,10 +294,11 @@ def loss_fn(pred, labels):
                     pp_schedule.step()

             # accumulate losses across pipeline microbatches
+            # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
             loss = (
-                torch.mean(torch.stack(losses))
+                torch.mean(torch.stack(losses)).to(device)
                 if is_last_stage
-                else torch.Tensor([-1.0])
+                else torch.tensor([-1.0], device=device)
             )
         else:
             # Non-PP forward / backward
@@ -330,26 +330,23 @@ def loss_fn(pred, labels):
         # it issues a single all-reduce for all parameters at once for better performance
         float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)

-        losses_since_last_log.append(loss)
-
         # log metrics
         if (
             train_state.step == 1
             or train_state.step % job_config.metrics.log_freq == 0
         ):
-            losses = [loss.item() for loss in losses_since_last_log]
-            avg_loss, max_loss = sum(losses) / len(losses), max(losses)
             if (
                 parallel_dims.dp_replicate_enabled
                 or parallel_dims.dp_shard_enabled
                 or parallel_dims.cp_enabled
             ):
+                loss = loss.detach()
                 global_avg_loss, global_max_loss = (
-                    utils.dist_mean(avg_loss, world_mesh["dp_cp"]),
-                    utils.dist_max(max_loss, world_mesh["dp_cp"]),
+                    utils.dist_mean(loss, world_mesh["dp_cp"]),
+                    utils.dist_max(loss, world_mesh["dp_cp"]),
                 )
             else:
-                global_avg_loss, global_max_loss = avg_loss, max_loss
+                global_avg_loss = global_max_loss = loss.item()

             # update train state
             train_state.log_steps.append(train_state.step)
@@ -399,7 +396,6 @@ def loss_fn(pred, labels):
                 f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
             )

-            losses_since_last_log.clear()
            ntokens_since_last_log = 0
            data_loading_times.clear()
            time_last_log = time.perf_counter()
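What this changes in the logged numbers: before, the reported loss was the host-side average (and max) of every step's loss since the previous log line; after, it is the current step's loss, detached and reduced across ranks on device. A toy single-rank illustration with made-up values and log_freq = 10:

    # Illustrative numbers only, for one rank with log_freq = 10.
    step_losses = [2.0, 1.9, 1.8, 1.7, 1.6, 1.5, 1.4, 1.3, 1.2, 1.1]

    # Before this commit: mean over all steps since the last log step.
    old_logged = sum(step_losses) / len(step_losses)  # 1.55

    # After this commit: only the loss of the step being logged.
    new_logged = step_losses[-1]  # 1.1

    print(old_logged, new_logged)

The max path changes the same way: from the max over the whole window (and over ranks) to the max over ranks at the logging step only.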
