This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

rename top level UX to convert_to_float8_training (#329)
Summary:
Pull Request resolved: #329

Old name: `swap_linear_with_float8_linear`
New name: `convert_to_float8_training`

Choosing a more generic name, with the following improvements over the
old name:
1. doesn't mention module swaps, which is an implementation detail
2. doesn't mention `Float8Linear`, which is an implementation detail
3. clarifies that this is for training, not to be confused with
   inference APIs
4. doesn't mention `linear`, which gives more freedom to add other
   modules later

```
find . -name '*.py' -print0 | xargs -0 sed -i 's/swap_linear_with_float8_linear/convert_to_float8_training/g'
```
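
For illustration, a minimal sketch of a caller migrating to the new name; the toy model below is hypothetical and not part of this commit:

```python
import torch.nn as nn

from float8_experimental.float8_linear_utils import convert_to_float8_training

# toy model, for illustration only
m = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 16))

# previously: swap_linear_with_float8_linear(m)
# swaps the `torch.nn.Linear` children with `Float8Linear` for float8 training
m = convert_to_float8_training(m)
```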

Reviewed By: weifengpy

Differential Revision: D60195665

fbshipit-source-id: 8157b3d6f5db36c33370014135cbadfd192ac5b4
vkuzo authored and facebook-github-bot committed Jul 25, 2024
1 parent e1c5fe1 commit da487a3
Showing 14 changed files with 52 additions and 52 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -36,7 +36,7 @@ This is the most accurate recipe as every tensor is scaled dynamically.

```python
from float8_experimental.float8_linear_utils import (
- swap_linear_with_float8_linear,
+ convert_to_float8_training,
)
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

@@ -55,7 +55,7 @@ def module_filter_fn(mod: torch.nn.Module, fqn: str):
return True

# convert all `torch.nn.Linear` modules to `Float8Linear`
- swap_linear_with_float8_linear(m, module_filter_fn=module_filter_fn)
+ convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# optional: use FSDP
model = FSDP(model, use_orig_params=True)
@@ -83,7 +83,7 @@ This is theoretically the most performant recipe as it minimizes memory reads.

```python
from float8_experimental.float8_linear_utils import (
- swap_linear_with_float8_linear,
+ convert_to_float8_training,
sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_linear import TensorScalingType
@@ -106,7 +106,7 @@ config = Float8LinearConfig(

# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
# type
- swap_linear_with_float8_linear(
+ convert_to_float8_training(
m,
config=config,
)
4 changes: 2 additions & 2 deletions benchmarks/bench_multi_gpu.py
@@ -20,7 +20,7 @@
TensorScalingType,
)
from float8_experimental.float8_linear_utils import (
- swap_linear_with_float8_linear,
+ convert_to_float8_training,
sync_float8_amax_and_scale_history,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -77,7 +77,7 @@ def get_model(K, N, is_fp8, base_dtype=torch.float32):
modules.append(nn.ReLU())
m = nn.Sequential(*modules)
if is_fp8:
- swap_linear_with_float8_linear(
+ convert_to_float8_training(
m,
config=config,
)
4 changes: 2 additions & 2 deletions benchmarks/profile_linear_float8.py
@@ -24,8 +24,8 @@
TensorScalingType,
)
from float8_experimental.float8_linear_utils import (
+ convert_to_float8_training,
linear_requires_sync,
- swap_linear_with_float8_linear,
sync_float8_amax_and_scale_history,
)
from torch.profiler import profile, ProfilerActivity, record_function
@@ -268,7 +268,7 @@ def main(
m_ref = m_ref.to(device).to(ref_dtype)

m_float8 = copy.deepcopy(m_ref)
- swap_linear_with_float8_linear(m_float8, config=config)
+ convert_to_float8_training(m_float8, config=config)

def ref_forw_backward(x):
out = m_ref(x)
4 changes: 2 additions & 2 deletions float8_experimental/__init__.py
@@ -10,7 +10,7 @@
TensorScalingType,
)
from float8_experimental.float8_linear import Float8Linear
- from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
+ from float8_experimental.float8_linear_utils import convert_to_float8_training
from float8_experimental.float8_tensor import (
Float8Tensor,
GemmInputRole,
@@ -29,7 +29,7 @@
"Float8LinearConfig",
"Float8TensorCastConfig",
# top level UX
- "swap_linear_with_float8_linear",
+ "convert_to_float8_training",
# TODO(future): remove Float8Tensor and Float8Linear from public API
"Float8Tensor",
"Float8Linear",
6 changes: 3 additions & 3 deletions float8_experimental/float8_linear_utils.py
@@ -61,7 +61,7 @@ def swap_linear_layers(
from_float_func: Callable[[nn.Linear], nn.Linear],
*,
module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
- ) -> Optional[nn.Module]:
+ ) -> nn.Module:
"""
Generic function to swap linear layers in a module with a new type of linear layer.
@@ -122,12 +122,12 @@ def post_order_traversal(
return root_module


- def swap_linear_with_float8_linear(
+ def convert_to_float8_training(
module: nn.Module,
*,
module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
config: Float8LinearConfig = None,
- ) -> Optional[nn.Module]:
+ ) -> nn.Module:
"""
Swaps `torch.nn.Linear` in `module` with `Float8Linear`.
2 changes: 1 addition & 1 deletion float8_experimental/inference.py
@@ -215,7 +215,7 @@ def quantize_to_float8(
*,
module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
use_fast_accum: bool = True,
- ) -> Optional[nn.Module]:
+ ) -> nn.Module:
"""
Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.
12 changes: 6 additions & 6 deletions test/test_base.py
@@ -23,8 +23,8 @@
)
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
+ convert_to_float8_training,
linear_requires_sync,
- swap_linear_with_float8_linear,
sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_python_api import addmm_float8_unwrapped
@@ -604,7 +604,7 @@ def test_swap_root_linear(self):
for emulate in [True, False]:
module = nn.Linear(3, 3)
config = Float8LinearConfig(emulate=emulate)
- module = swap_linear_with_float8_linear(module, config=config)
+ module = convert_to_float8_training(module, config=config)
self.assertIsInstance(module, Float8Linear)
self.assertEqual(module.linear_mm_config.y.emulate, emulate)
self.assertEqual(module.linear_mm_config.y.emulate, emulate)
@@ -618,7 +618,7 @@ def test_swap_root_linear_with_children_raises(self):
AssertionError,
"Does not support a root nn.Linear with children",
):
- swap_linear_with_float8_linear(module, config=config)
+ convert_to_float8_training(module, config=config)

def test_swap_submodule_linears(self):
class MLP(nn.Module):
@@ -630,7 +630,7 @@ def __init__(self, dim: int):
for emulate in [True, False]:
model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
config = Float8LinearConfig(emulate=emulate)
- model = swap_linear_with_float8_linear(model, config=config)
+ model = convert_to_float8_training(model, config=config)
self.assertIsInstance(model[0].lin1, Float8Linear)
self.assertIsInstance(model[0].lin2, Float8Linear)
self.assertIsInstance(model[1], Float8Linear)
@@ -658,7 +658,7 @@ def module_filter_fn(mod, fqn):
)

config = Float8LinearConfig(emulate=True)
- model = swap_linear_with_float8_linear(
+ model = convert_to_float8_training(
model,
config=config,
module_filter_fn=module_filter_fn,
@@ -687,7 +687,7 @@ def __init__(self, dim: int):
"2.lin1",
]
config = Float8LinearConfig(emulate=True)
- model = swap_linear_with_float8_linear(
+ model = convert_to_float8_training(
model,
config=config,
module_filter_fn=module_filter_fn,
6 changes: 3 additions & 3 deletions test/test_compile.py
@@ -20,8 +20,8 @@
)
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
+ convert_to_float8_training,
get_float8_layers,
- swap_linear_with_float8_linear,
sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_tensor import Float8Tensor, LinearMMConfig
@@ -280,7 +280,7 @@ def test_sync_amax_func():
scaling_type=TensorScalingType.DELAYED
),
)
- float8_mod = swap_linear_with_float8_linear(
+ float8_mod = convert_to_float8_training(
module,
config=config,
)
@@ -324,7 +324,7 @@ def test_sync_amax_func_cuda_graph_success():
scaling_type=TensorScalingType.DELAYED
),
)
- swap_linear_with_float8_linear(
+ convert_to_float8_training(
my_module,
config=config,
)
10 changes: 5 additions & 5 deletions test/test_dtensor.py
@@ -16,7 +16,7 @@
from float8_experimental import Float8LinearConfig

from float8_experimental.float8_dynamic_utils import NoopFwToFloat8E5M2Bw
- from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
+ from float8_experimental.float8_linear_utils import convert_to_float8_training
from float8_experimental.float8_tensor import (
Float8Tensor,
GemmInputRole,
@@ -187,12 +187,12 @@ def _test_fp8_mlp_tensor_parallelism_base(
config = Float8LinearConfig(emulate=True)

toy_model = ToyModel().to(device)
- toy_model_fp8 = swap_linear_with_float8_linear(toy_model, config=config)
+ toy_model_fp8 = convert_to_float8_training(toy_model, config=config)

tp_model = copy.deepcopy(toy_model)
- tp_model = swap_linear_with_float8_linear(tp_model, config=config)
+ tp_model = convert_to_float8_training(tp_model, config=config)
sp_model = copy.deepcopy(toy_model)
- sp_model = swap_linear_with_float8_linear(sp_model, config=config)
+ sp_model = convert_to_float8_training(sp_model, config=config)

# vanilla TP
tp_model = parallelize_module(
@@ -223,7 +223,7 @@ def _test_fp8_mlp_tensor_parallelism_base(

# PrepareFloat8ModuleInput with specific submodule fqn
sp_model2 = copy.deepcopy(toy_model)
- sp_model2 = swap_linear_with_float8_linear(sp_model2, config=config)
+ sp_model2 = convert_to_float8_training(sp_model2, config=config)

sp_model2 = parallelize_module(
sp_model2,
4 changes: 2 additions & 2 deletions test/test_fsdp.py
@@ -27,8 +27,8 @@
TensorScalingType,
)
from float8_experimental.float8_linear_utils import (
+ convert_to_float8_training,
linear_requires_sync,
- swap_linear_with_float8_linear,
sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_utils import compute_error
@@ -90,7 +90,7 @@ def fsdp_main(rank, world_size, args):

# Note: we only iterate over `scaling_type_weight` because FSDP only interacts
# with weights.
- swap_linear_with_float8_linear(
+ convert_to_float8_training(
model_fp8,
config=config,
)
30 changes: 15 additions & 15 deletions test/test_fsdp2/test_fsdp2.py
@@ -13,7 +13,7 @@
Float8TensorCastConfig,
TensorScalingType,
)
- from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
+ from float8_experimental.float8_linear_utils import convert_to_float8_training
from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor
from test_fsdp2_common import check_parity_bf16_mp, check_parity_no_mp
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
@@ -116,7 +116,7 @@ def _test_transformer_parity(
float8_linear_config1 = Float8LinearConfig(
cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
)
- swap_linear_with_float8_linear(
+ convert_to_float8_training(
ref_module,
config=float8_linear_config1,
)
@@ -128,7 +128,7 @@
enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
)
- swap_linear_with_float8_linear(
+ convert_to_float8_training(
module,
config=float8_linear_config2,
)
@@ -187,7 +187,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
emulate=True,
)
- swap_linear_with_float8_linear(model, config=float8_linear_config)
+ convert_to_float8_training(model, config=float8_linear_config)
model_unsharded_numel = sum(p.numel() for p in model.parameters())
model_sharded_numel = (model_unsharded_numel + 1) // 2
block_lin_weight_numel = 0
@@ -297,7 +297,7 @@ def test_weight_subclass_dynamic(self):
enable_fsdp_fp8_all_gather=True,
emulate=True,
)
- module = swap_linear_with_float8_linear(
+ module = convert_to_float8_training(
module_fp32,
config=float8_linear_config,
)
@@ -310,7 +310,7 @@

# Check for multiple FSDP paramter groups
module = self.init_multi_module()
- module = swap_linear_with_float8_linear(
+ module = convert_to_float8_training(
module,
config=float8_linear_config,
)
@@ -362,7 +362,7 @@ def get_expected_all_gather_size(module: nn.Module):
float8_linear_config = Float8LinearConfig(
enable_fsdp_fp8_all_gather=True,
)
- module_fp32 = swap_linear_with_float8_linear(
+ module_fp32 = convert_to_float8_training(
module_fp32, config=float8_linear_config
)
module = module_fp32
@@ -392,7 +392,7 @@ def get_expected_all_gather_size(module: nn.Module):
# - Check for multiple FSDP parameter groups
module = self.init_multi_module()
ref_module = copy.deepcopy(module)
- module = swap_linear_with_float8_linear(module, config=float8_linear_config)
+ module = convert_to_float8_training(module, config=float8_linear_config)
for submodule in module:
fully_shard(submodule)
fully_shard(module)
@@ -433,12 +433,12 @@ def test_fp32_fp8_single_module_parity(self):
)
module_fp32 = self.init_single_module()
ref_module = copy.deepcopy(module_fp32)
- ref_module = swap_linear_with_float8_linear(
+ ref_module = convert_to_float8_training(
ref_module,
config=float8_linear_config1,
)
ref_module = ref_module.cuda()
- module = swap_linear_with_float8_linear(
+ module = convert_to_float8_training(
module_fp32,
config=float8_linear_config2,
)
@@ -481,11 +481,11 @@ def test_fp32_fp8_multi_module_parity(self):
)
module = self.init_multi_module().cuda()
ref_module = copy.deepcopy(module)
- ref_module = swap_linear_with_float8_linear(
+ ref_module = convert_to_float8_training(
ref_module,
config=float8_linear_config1,
)
- module = swap_linear_with_float8_linear(
+ module = convert_to_float8_training(
module,
config=float8_linear_config2,
)
@@ -518,12 +518,12 @@ def test_bf16_mp_fp8_dynamic_multi_parity(self):
module = self.init_multi_module()
ref_module_bf16 = copy.deepcopy(module).to(torch.bfloat16)
float8_config = Float8LinearConfig(emulate=True)
- ref_module_bf16 = swap_linear_with_float8_linear(
+ ref_module_bf16 = convert_to_float8_training(
ref_module_bf16,
config=float8_config,
)
ref_module_fp32 = copy.deepcopy(module).cuda()
- module = swap_linear_with_float8_linear(module, config=float8_config)
+ module = convert_to_float8_training(module, config=float8_config)
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
for mlp in module:
fully_shard(mlp, mp_policy=mp_policy)
@@ -550,7 +550,7 @@ def test_delayed_scaling_inplace_update(self):
scaling_type=TensorScalingType.DELAYED
),
)
- m_fp8 = swap_linear_with_float8_linear(
+ m_fp8 = convert_to_float8_training(
module,
config=float8_linear_config,
)
(Diffs for the remaining 3 changed files are not shown.)
