add max and abs to selective AC config (#701)
Summary:

We usually want to save the result of max(abs(tensor)), since the memory
needed to store it (a single scalar per tensor) is negligible.
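
As a concrete illustration, here is a minimal sketch (illustrative names, not
torchao's actual float8 code) of why this result is cheap to keep: dynamic
float8 scaling derives its scale from amax = max(abs(tensor)), and that amax is
a single scalar.

```python
import torch

# Minimal sketch of amax-based dynamic float8 scaling (illustrative, not
# torchao's actual implementation). The point: the saved quantity is one scalar.
def fp8_dynamic_scale(t: torch.Tensor) -> torch.Tensor:
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    amax = t.abs().max()  # max(abs(tensor)) -- a single scalar per tensor
    return fp8_max / torch.clamp(amax, min=1e-12)  # clamp avoids division by zero
```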

For float8 training with selective op-based AC, this gives a nice small
speedup: about a 1% WPS gain on LLaMa 3 8B pretraining on 8 H100 GPUs with
default settings.

There is no harm in keeping these ops in the save list for non-float8
training, so we simply enable it for everyone, without gating, to keep things
simple.

WPS: 6800 -> 6860

Test Plan:

```
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" \
  ./run_llama_train.sh --float8.enable_float8_linear \
  --float8.scaling_type_input dynamic --float8.scaling_type_weight dynamic \
  --float8.scaling_type_grad_output dynamic --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo authored Nov 26, 2024
1 parent 4d182a1 commit bcebc0d
Showing 1 changed file with 4 additions and 0 deletions.
torchtitan/parallelisms/parallelize_llama.py:

```diff
@@ -236,6 +236,10 @@ def apply_tp(
     torch.ops.aten._scaled_dot_product_efficient_attention.default,
     torch.ops.aten._scaled_dot_product_flash_attention.default,
     torch.ops._c10d_functional.reduce_scatter_tensor.default,
+    # for low precision training, it's useful to always save
+    # the result of max(abs(tensor))
+    torch.ops.aten.abs.default,
+    torch.ops.aten.max.default,
 }
```
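
For reference, here is a minimal sketch of how such an op save list plugs into
PyTorch's selective activation checkpointing API
(`torch.utils.checkpoint.create_selective_checkpoint_contexts`). The
`_save_list` contents and the `checkpoint_block` helper below are illustrative,
not necessarily torchtitan's exact wiring in this file.

```python
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Hypothetical save list mirroring the idea above: outputs of these ops are
# kept during the forward pass instead of being recomputed in the backward pass.
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten.abs.default,
    torch.ops.aten.max.default,
}

def _policy(ctx, op, *args, **kwargs):
    # Save outputs of listed ops; everything else is recomputed during backward.
    if op in _save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def checkpoint_block(block, *inputs):
    # Wrap a module/function with selective op-based activation checkpointing.
    context_fn = partial(create_selective_checkpoint_contexts, _policy)
    return checkpoint(block, *inputs, use_reentrant=False, context_fn=context_fn)
```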


