This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Changes to enable fp8 on multi devices #149

Closed
wants to merge 5 commits
Changes from 2 commits
46 changes: 37 additions & 9 deletions float8_experimental/float8_linear.py
@@ -14,6 +14,8 @@

import dataclasses

+from typing import Optional

import torch

from float8_experimental.float8_linear_utils import (
@@ -92,6 +94,8 @@ def __init__(self, *args, **kwargs):
delayed_scaling_recipe = kwargs.pop(
"delayed_scaling_recipe", DelayedScalingRecipe()
)
+        # Amax scales should always be kept as float32.
+        self.always_float32_buffers = set()
super().__init__(*args, **kwargs)

# TODO(future): have a unique recipe per buffer instead of one per
@@ -100,15 +104,23 @@ def __init__(self, *args, **kwargs):
self.recipe = delayed_scaling_recipe
history_len = self.recipe.history_len

self.register_buffer("fp8_amax_x", torch.tensor(E4M3_MAX_POS))
self.register_buffer("fp8_amax_history_x", torch.zeros(history_len))
self.register_buffer("fp8_scale_x", torch.tensor(1.0))
self.register_buffer("fp8_amax_w", torch.tensor(E4M3_MAX_POS))
self.register_buffer("fp8_amax_history_w", torch.zeros(history_len))
self.register_buffer("fp8_scale_w", torch.tensor(1.0))
self.register_buffer("fp8_amax_dL_dY", torch.tensor(E5M2_MAX_POS))
self.register_buffer("fp8_amax_history_dL_dY", torch.zeros(history_len))
self.register_buffer("fp8_scale_dL_dY", torch.tensor(1.0))
self.register_always_float32_buffer("fp8_amax_x", torch.tensor(E4M3_MAX_POS))
self.register_always_float32_buffer(
"fp8_amax_history_x", torch.zeros(history_len)
)
self.register_always_float32_buffer("fp8_scale_x", torch.tensor(1.0))
self.register_always_float32_buffer("fp8_amax_w", torch.tensor(E4M3_MAX_POS))
self.register_always_float32_buffer(
"fp8_amax_history_w", torch.zeros(history_len)
)
self.register_always_float32_buffer("fp8_scale_w", torch.tensor(1.0))
self.register_always_float32_buffer(
"fp8_amax_dL_dY", torch.tensor(E5M2_MAX_POS)
)
self.register_always_float32_buffer(
"fp8_amax_history_dL_dY", torch.zeros(history_len)
)
self.register_always_float32_buffer("fp8_scale_dL_dY", torch.tensor(1.0))
# Whether to emulate the fp8 matmul logic in float32
self.emulate = False

@@ -136,6 +148,22 @@ def __init__(self, *args, **kwargs):
# will access the scale when it has ensured that it is on GPU.
self._float8_tensor_ctor = lambda *args, **kwargs: Float8Tensor(*args, **kwargs)

+    def register_always_float32_buffer(
+        self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True
+    ) -> None:
+        self.register_buffer(name=name, tensor=tensor, persistent=persistent)
+        self.always_float32_buffers.add(name)
+
+    def _apply(self, fn, recurse=True):
+        ret = super()._apply(fn, recurse)
+        self.convert_amax_buffer_to_float32()
+        return ret
+
+    def convert_amax_buffer_to_float32(self):
+        for key in self.always_float32_buffers:
+            if self._buffers[key] is not None:
+                self._buffers[key] = self._buffers[key].to(torch.float32)

def cast_x_to_float8(
self, x: torch.Tensor, is_amax_initialized: bool
) -> torch.Tensor:
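Taken together, the new buffers and the _apply override mean that a whole-module cast such as module.to(torch.bfloat16) or module.cuda() still runs through nn.Module._apply, but the amax/scale bookkeeping buffers are converted back to float32 afterwards. A minimal sketch of that behavior; the direct Float8Linear(16, 16) construction is assumed here for illustration, since the diff does not show the constructor signature:

import torch
from float8_experimental.float8_linear import Float8Linear

# Sketch only: assumes Float8Linear accepts the usual nn.Linear constructor
# arguments (in_features, out_features).
m = Float8Linear(16, 16)

# Cast the whole module, as a mixed-precision or multi-device setup might do.
# nn.Module.to() routes through _apply, which this diff overrides.
m = m.to(torch.bfloat16)

# Parameters follow the requested dtype ...
assert m.weight.dtype == torch.bfloat16

# ... but buffers registered via register_always_float32_buffer are converted
# back by convert_amax_buffer_to_float32, so the amax/scale bookkeeping stays
# in full precision.
assert m.fp8_amax_x.dtype == torch.float32
assert m.fp8_scale_w.dtype == torch.float32
assert m.fp8_amax_history_dL_dY.dtype == torch.float32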
11 changes: 8 additions & 3 deletions float8_experimental/float8_linear_utils.py
@@ -117,7 +117,9 @@ def swap_linear_with_float8_linear(
swap_linear_with_float8_linear(child, module, emulate)


-def sync_float8_amax_and_scale_history(model: torch.nn.Module) -> None:
+def sync_float8_amax_and_scale_history(
+    model: torch.nn.Module, fp8_classes=None
+) -> None:
"""
Manages the float8 amax and scale bookkeeping. In detail, it does the
following:
@@ -138,10 +140,13 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module) -> None:
# the reductions into one and probably make the history update faster.
# Lazy import to avoid circular dependency

-    from float8_experimental.float8_linear import Float8Linear
+    if fp8_classes is None:
+        from float8_experimental.float8_linear import Float8Linear
+
+        fp8_classes = [Float8Linear]

for name, child in model.named_modules():
-        if not isinstance(child, (Float8Linear)):
+        if not any(isinstance(child, a) for a in fp8_classes):
Contributor

Since we have removed the NoTs class, I think this is likely a rebase artifact.

Contributor Author

In multi-GPU cases, we use Float8ColumnParallelLinear and Float8RowParallelLinear (which depend on external distributed-training code) as the fp8 classes, so I modified this to pass the class types into sync_float8_amax_and_scale_history.

Contributor

Ahh, I see, I was misreading this. But if we made fp8_classes a Tuple of types, couldn't we keep the check as is? This is a nit anyway; both approaches accomplish the same thing.

continue

#
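As explained in the review thread, the new fp8_classes argument exists so that callers whose fp8 modules are defined outside this repo (the author's examples are Float8ColumnParallelLinear and Float8RowParallelLinear from external distributed-training code) can opt those module types into the amax/scale sync. A minimal usage sketch follows; the external classes are left commented out because their import path depends on the caller's codebase, and the direct Float8Linear construction is assumed for illustration:

import torch
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history

# Toy model whose linears are already fp8 modules; in practice this is usually
# produced by swap_linear_with_float8_linear.
model = torch.nn.Sequential(Float8Linear(16, 16), Float8Linear(16, 4))

# Single-device default: fp8_classes is None, so only Float8Linear is matched.
sync_float8_amax_and_scale_history(model)

# Multi-device case from the review thread: pass the parallel fp8 classes
# explicitly. The import below is illustrative only; these classes live in
# external distributed-training code, not in this repo.
# from my_parallel_layers import Float8ColumnParallelLinear, Float8RowParallelLinear
# sync_float8_amax_and_scale_history(
#     model, fp8_classes=[Float8ColumnParallelLinear, Float8RowParallelLinear]
# )

On the reviewer's nit: if fp8_classes were passed as a tuple of classes rather than a list, the original check could stay isinstance(child, fp8_classes), since isinstance accepts a tuple of types; the any(...) form used in the diff behaves the same.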