Skip to content

Commit

Permalink
use helpers for selective per layer ac
Browse files Browse the repository at this point in the history
  • Loading branch information
danielvegamyhre committed Jan 16, 2025
1 parent 8b29798 commit 4ab683c
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion torchtitan/float8.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def convert_to_float8_training(self, model: nn.Module):
logger.info("Using float8nocompile prototype")
from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
convert_to_float8_nocompile_training,
no_precompute_for_backward_every_nth_layer,
)

# for full AC or no AC
Expand Down Expand Up @@ -175,3 +174,20 @@ def sync_float8_amax_and_scale_history(
models = [model] if isinstance(model, nn.Module) else model
for m in models:
self._sync_float8_amax_and_scale_history(m)


def no_precompute_for_backward_every_nth_layer(model: nn.Module, n: int):
"""Set no_precompute_for_backward to True for every nth layer in the model."""
for layer_idx, (layer_id, layer) in enumerate(model.layers.named_children()):
if layer_idx % n == 0:
logger.info(f"Enabling no_precompute_for_backward for layer {layer_id}")
_enable_no_precompute_for_backward(layer)


def _enable_no_precompute_for_backward(model: nn.Module):
"""Recursively set no_precompute_for_backward to True for all linear layers in the given model."""
for child_layer in model.children():
if isinstance(child_layer, nn.Linear):
child_layer.no_precompute_for_backward = True
else:
_enable_no_precompute_for_backward(child_layer)

0 comments on commit 4ab683c

Please sign in to comment.