This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Changes to enable fp8 on multi devices #149

Closed · wants to merge 5 commits
46 changes: 37 additions & 9 deletions float8_experimental/float8_linear.py
@@ -14,6 +14,8 @@

import dataclasses

from typing import Optional

import torch

from float8_experimental.float8_linear_utils import (
@@ -92,6 +94,8 @@ def __init__(self, *args, **kwargs):
delayed_scaling_recipe = kwargs.pop(
"delayed_scaling_recipe", DelayedScalingRecipe()
)
# Amax and scale buffers should always be kept as float32.
self.always_float32_buffers = set()
super().__init__(*args, **kwargs)

# TODO(future): have a unique recipe per buffer instead of one per
@@ -100,15 +104,23 @@ def __init__(self, *args, **kwargs):
self.recipe = delayed_scaling_recipe
history_len = self.recipe.history_len

self.register_buffer("fp8_amax_x", torch.tensor(E4M3_MAX_POS))
self.register_buffer("fp8_amax_history_x", torch.zeros(history_len))
self.register_buffer("fp8_scale_x", torch.tensor(1.0))
self.register_buffer("fp8_amax_w", torch.tensor(E4M3_MAX_POS))
self.register_buffer("fp8_amax_history_w", torch.zeros(history_len))
self.register_buffer("fp8_scale_w", torch.tensor(1.0))
self.register_buffer("fp8_amax_dL_dY", torch.tensor(E5M2_MAX_POS))
self.register_buffer("fp8_amax_history_dL_dY", torch.zeros(history_len))
self.register_buffer("fp8_scale_dL_dY", torch.tensor(1.0))
self.register_always_float32_buffer("fp8_amax_x", torch.tensor(E4M3_MAX_POS))
self.register_always_float32_buffer(
"fp8_amax_history_x", torch.zeros(history_len)
)
self.register_always_float32_buffer("fp8_scale_x", torch.tensor(1.0))
self.register_always_float32_buffer("fp8_amax_w", torch.tensor(E4M3_MAX_POS))
self.register_always_float32_buffer(
"fp8_amax_history_w", torch.zeros(history_len)
)
self.register_always_float32_buffer("fp8_scale_w", torch.tensor(1.0))
self.register_always_float32_buffer(
"fp8_amax_dL_dY", torch.tensor(E5M2_MAX_POS)
)
self.register_always_float32_buffer(
"fp8_amax_history_dL_dY", torch.zeros(history_len)
)
self.register_always_float32_buffer("fp8_scale_dL_dY", torch.tensor(1.0))
# Whether to emulate the fp8 matmul logic in float32
self.emulate = False

@@ -136,6 +148,22 @@ def __init__(self, *args, **kwargs):
# will access the scale when it has ensured that it is on GPU.
self._float8_tensor_ctor = lambda *args, **kwargs: Float8Tensor(*args, **kwargs)

def register_always_float32_buffer(
self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True
) -> None:
self.register_buffer(name=name, tensor=tensor, persistent=persistent)
self.always_float32_buffers.add(name)

def _apply(self, fn, recurse=True):
ret = super()._apply(fn, recurse)
self.convert_amax_buffer_to_float32()
return ret

def convert_amax_buffer_to_float32(self):
for key in self.always_float32_buffers:
if self._buffers[key] is not None:
self._buffers[key] = self._buffers[key].to(torch.float32)

def cast_x_to_float8(
self, x: torch.Tensor, is_amax_initialized: bool
) -> torch.Tensor:
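
The change above registers the amax/scale bookkeeping tensors through register_always_float32_buffer and overrides nn.Module._apply, which .to(), .cuda(), and .half() all funnel through, so those buffers are cast back to float32 after any dtype or device move. Below is a minimal, self-contained sketch of the same pattern; the KeepFloat32Buffers class and its "amax" buffer are hypothetical stand-ins, not the library code.

import torch
import torch.nn as nn


class KeepFloat32Buffers(nn.Module):
    """Hypothetical sketch of the buffer-dtype pattern used by Float8Linear."""

    def __init__(self):
        super().__init__()
        self.always_float32_buffers = set()
        self.register_always_float32_buffer("amax", torch.tensor(1.0))

    def register_always_float32_buffer(self, name, tensor, persistent=True):
        self.register_buffer(name, tensor, persistent=persistent)
        self.always_float32_buffers.add(name)

    def _apply(self, fn, recurse=True):
        # _apply is the single funnel for .to()/.cuda()/.half(), so re-casting
        # here keeps the tracked buffers in float32 after any such move.
        ret = super()._apply(fn, recurse)
        for key in self.always_float32_buffers:
            if self._buffers[key] is not None:
                self._buffers[key] = self._buffers[key].to(torch.float32)
        return ret


m = KeepFloat32Buffers().to(torch.bfloat16)
assert m.amax.dtype == torch.float32  # the tracked buffer survives the bf16 cast
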
11 changes: 8 additions & 3 deletions float8_experimental/float8_linear_utils.py
@@ -117,7 +117,9 @@ def swap_linear_with_float8_linear(
swap_linear_with_float8_linear(child, module, emulate)


def sync_float8_amax_and_scale_history(model: torch.nn.Module) -> None:
def sync_float8_amax_and_scale_history(
model: torch.nn.Module, fp8_classes=None
) -> None:
"""
Manages the float8 amax and scale bookkeeping. In detail, it does the
following:
@@ -138,10 +140,13 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module) -> None:
# the reductions into one and probably make the history update faster.
# Lazy import to avoid circular dependency

from float8_experimental.float8_linear import Float8Linear
if fp8_classes is None:
from float8_experimental.float8_linear import Float8Linear

fp8_classes = Float8Linear

for name, child in model.named_modules():
if not isinstance(child, (Float8Linear)):
if not isinstance(child, fp8_classes):
continue

#
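
With the new optional fp8_classes argument, callers that subclass or wrap Float8Linear for multi-device setups can tell the sync helper which module classes to match; leaving it as None preserves the previous behavior of matching Float8Linear only. A hedged usage sketch under that assumption (the tiny CPU model and the emulate=True flag are illustrative choices, not requirements of the API):

import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history

# Build a small model whose linears are Float8Linear, as the tests below do.
model = nn.Sequential(
    Float8Linear.from_float(nn.Linear(16, 32), True),  # emulate=True
    Float8Linear.from_float(nn.Linear(32, 16), True),
)

# Default: only Float8Linear instances are synced.
sync_float8_amax_and_scale_history(model)

# New: pass a class, or a tuple of classes, since isinstance() accepts either form.
sync_float8_amax_and_scale_history(model, fp8_classes=Float8Linear)
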
54 changes: 54 additions & 0 deletions test/test_base.py
@@ -177,6 +177,60 @@ def test_linear_float8_weight_tag(self):
m_fp8 = Float8Linear.from_float(copy.deepcopy(m_ref))
assert m_fp8.weight._is_fp8_weight

@pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
@pytest.mark.parametrize(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype):
emulate = (
not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0)
)
x_shape = (16, 16)

x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
self._test_linear_impl(x, m_ref, linear_type, emulate)

m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
m = Float8Linear.from_float(m, emulate)

# Cast the module to dtype
m = m.to(dtype=linear_dtype)
# Check amax buffer types
for key in [
"fp8_amax_x",
"fp8_amax_history_x",
"fp8_scale_x",
"fp8_amax_w",
"fp8_amax_history_w",
"fp8_scale_w",
"fp8_amax_dL_dY",
"fp8_amax_history_dL_dY",
"fp8_scale_dL_dY",
]:
assert (
m._buffers[key].dtype == torch.float32
), f"{key}.dtype is {m._buffers[key].dtype}, expected torch.float32"

# autocast off
Contributor review comment on the line above: Nit: this test does cover it, but could we also assert that the buffer types are still fp32?

x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
sync_float8_amax_and_scale_history(m)
y = m(x)
assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}"

# autocast on
with torch.autocast("cuda"):
sync_float8_amax_and_scale_history(m)
y = m(x)
assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}"

with torch.autocast("cuda", dtype=torch.bfloat16):
sync_float8_amax_and_scale_history(m)
y = m(x)
assert (
y.dtype == torch.bfloat16
), f"y.dtype is {y.dtype}, expected {torch.bfloat16}"


class TestScaledMM:
@unittest.skipIf(