[Feature] Auto scaling factor tuning for FP8 collective communication #140

Open · wants to merge 15 commits into base: main
14 changes: 11 additions & 3 deletions msamp/common/tensor/meta.py
@@ -4,6 +4,7 @@
"""MS-AMP ScalingMeta."""

import copy
from typing import Optional
import torch

from msamp.common.dtype import Floating, Dtypes
@@ -13,7 +14,7 @@ class ScalingMeta:
"""The meta data for scaling tensor."""
in_time_scaling: bool = True

def __init__(self, qtype, scale=None, scale_inv=None, amax=None, window_size=1, group=None):
def __init__(self, qtype, scale=None, scale_inv=None, amax=None, window_size=1, pre_scale=None, group=None):
"""Constructor.

Args:
@@ -22,11 +23,13 @@ def __init__(self, qtype, scale=None, scale_inv=None, amax=None, window_size=1,
scale_inv (torch.Tensor, optional): The reciprocal of scaling tensor, defaults to None.
amax (torch.Tensor, optional): Absolute maximum tensor, defaults to None.
window_size (int, optional): Window size, defaults to 1.
pre_scale (torch.Tensor, optional): A pre-scale factor multiplied into the computed scaling factor, defaults to None.
group (torch.distributed.ProcessGroup, optional): Distributed group, defaults to None.
"""
self.qtype = qtype
self.scale = scale if scale is not None else torch.ones((), device='cuda')
self.scale_inv = scale_inv if scale_inv is not None else torch.ones((), device='cuda')
self.pre_scale = pre_scale if pre_scale is not None else torch.ones((), device='cuda')
self.amax = amax if amax is not None else torch.zeros((window_size, ), device='cuda')
self.amax_counter = torch.zeros((), dtype=torch.int32)
self.window_size = window_size
@@ -36,20 +39,23 @@ def __init__(self, qtype, scale=None, scale_inv=None, amax=None, window_size=1,

@staticmethod
@torch.jit.script
def compute_scaling_factor(amax, scale, fp_max: float, margin: int):
def compute_scaling_factor(amax, scale, fp_max: float, margin: int, pre_scale: Optional[torch.Tensor] = None):
"""A function to compute scaling factor.

Args:
amax (torch.Tensor): Absolute maximum tensor.
scale (torch.Tensor): Scale tensor.
fp_max (float): The maximum value of float point.
margin (int): Margin value.
pre_scale (torch.Tensor, optional): A pre-scale factor multiplied into the computed scaling factor, defaults to None.

Returns:
torch.Tensor: The new scaling factor tensor.
"""
exp = torch.floor(torch.log2(fp_max / amax)) - margin
sf = torch.round(torch.pow(2, torch.abs(exp)))
if pre_scale is not None:
sf.mul_(pre_scale)
sf = torch.where(amax > 0.0, sf, scale)
sf = torch.where(torch.isfinite(amax), sf, scale)
sf = torch.where(exp < 0, 1 / sf, sf)
@@ -108,7 +114,7 @@ def reset_scaling_factor(self, qtype=None):
self.scale.fill_(1)
else:
fp_max = Floating.qfp_max[qtype]
sf = ScalingMeta.compute_scaling_factor(self.amax[0], self.scale, fp_max, 0)
sf = ScalingMeta.compute_scaling_factor(self.amax[0], self.scale, fp_max, 0, pre_scale=self.pre_scale)
self.scale.copy_(sf)

def copy_(self, src):
@@ -122,6 +128,7 @@ def copy_(self, src):
self.scale_inv.copy_(src.scale_inv)
self.amax.copy_(src.amax)
self.amax_counter.copy_(src.amax_counter)
self.pre_scale.copy_(src.pre_scale)
self.window_size = src.window_size

def clone(self):
@@ -156,4 +163,5 @@ def __repr__(self):
"""
return f'ScalingMeta(qtype={self.qtype}, '\
f'scale={self.scale.data:g}, scale_inv={self.scale_inv.data:g}, '\
f'pre_scale={self.pre_scale.data:g}, '\
f'amax={self.amax.max():g}, window_size={self.window_size})'
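
Putting the meta.py hunks together, the updated compute_scaling_factor reads roughly as follows (a sketch assembled from the diff above; the @torch.jit.script decorator and the enclosing class are omitted, and the trailing return is implied by the call sites):

    from typing import Optional
    import torch

    def compute_scaling_factor(amax, scale, fp_max: float, margin: int,
                               pre_scale: Optional[torch.Tensor] = None):
        # Exponent that maps amax into the representable FP8 range, minus a safety margin.
        exp = torch.floor(torch.log2(fp_max / amax)) - margin
        sf = torch.round(torch.pow(2, torch.abs(exp)))
        if pre_scale is not None:
            sf.mul_(pre_scale)  # fold the extra pre-scale into the factor
        sf = torch.where(amax > 0.0, sf, scale)             # keep the old scale for zero amax
        sf = torch.where(torch.isfinite(amax), sf, scale)   # ... and for inf/nan amax
        sf = torch.where(exp < 0, 1 / sf, sf)               # amax above fp_max: scale down
        return sf
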
2 changes: 1 addition & 1 deletion msamp/megatron/distributed.py
@@ -141,7 +141,7 @@ def _get_buffer_type(param):
start = pi * max_fp8_mems
for p in fp8_partitions[pi]:
meta = ScalingMeta(self.wgrad_qtype, scale=scales[t], scale_inv=scale_invs[t], amax=amaxs[t])
meta.pre_scale = pre_scale
meta.pre_scale.fill_(pre_scale)
t += 1
p.main_grad = ScalingTensor(self._grad_buffers[self.wgrad_dtype].get(p.shape, start), meta)
self._grad_buffer_param_index_map[self.wgrad_dtype][p] = (start, start + p.numel())
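
The switch from meta.pre_scale = pre_scale to meta.pre_scale.fill_(pre_scale) matters because ScalingMeta now allocates pre_scale as a CUDA scalar tensor in its constructor; writing in place presumably keeps every consumer of meta.pre_scale, including the fused CUDA kernel that dereferences its pointer, looking at the same storage. A minimal illustration (the 0.5 value is arbitrary):

    from msamp.common.dtype import Dtypes
    from msamp.common.tensor.meta import ScalingMeta

    meta = ScalingMeta(Dtypes.kfloat8_e4m3)
    meta.pre_scale.fill_(0.5)   # in-place update of the existing CUDA scalar tensor
    # meta.pre_scale = 0.5      # would rebind the attribute to a plain Python float
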
41 changes: 41 additions & 0 deletions msamp/megatron/optimizer/distrib_optimizer.py
@@ -538,6 +538,47 @@ def reduce_model_grads(self, args, timers): # noqa: C901

timers('grads-reduce-scatter').stop()

if args.wgrad_auto_scaling:
# Weight Gradient Auto Scaling
if args.curr_iteration % args.wgrad_auto_scaling_freq == 0:
timers('wgrad-auto-scaling', log_level=1).start(barrier=args.barrier_with_L1_time)

# update pre_scale in this partition
for model_group in self.model_fp8_groups:
for p in model_group:
g = p.main_grad
if g is not None and not torch.is_tensor(g):
if g.qtype != Dtypes.kfloat8_e4m3:
raise TypeError('g.qtype != Dtypes.kfloat8_e4m3: {}'.format(g.qtype))
# stat overflow ratio
num_infs = torch.count_nonzero((g.value & 0x7f) == 126)
overflow_ratio = num_infs / g.numel()
if overflow_ratio > args.wgrad_auto_scaling_ratio:
g.meta.pre_scale.div_(2.0)
else:
g.meta.pre_scale.mul_(2.0**(1.0 / args.wgrad_auto_scaling_window))

# synchronize pre_scale in all partitions
for model_id, model in enumerate(self.models):
# all fp8 gradients
partitions = self.model_gbuf_ranges[model_id][torch.uint8]['partitions']
fp8_grads = [[p.main_grad for p in part.keys()] for part in partitions]
# pre_scales in the partition `data_parallel_rank`
pre_scales = [g.meta.pre_scale for g in fp8_grads[data_parallel_rank]]
max_elems_per_rank = max(model._grad_buffer_num_params)
pre_scales = torch.stack(pre_scales)
# padding to max_elems_per_rank
pad = max_elems_per_rank - pre_scales.numel()
pre_scales = F.pad(pre_scales, (0, pad))
output_pre_scales = pre_scales.new_empty((data_parallel_world_size, max_elems_per_rank))
torch.distributed._all_gather_base(output_pre_scales, pre_scales, group=data_parallel_group)
# assign pre_scale to all fp8 gradients
for grads, pre_scales in zip(fp8_grads, output_pre_scales):
for g, pre_scale in zip(grads, pre_scales):
g.meta.pre_scale.copy_(pre_scale)

timers('wgrad-auto-scaling').stop()

def gather_model_params(self, args, timers): # noqa: C901
"""All-gather updated model params.

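
The wgrad auto-scaling block added above amounts to a per-gradient controller: halve pre_scale when too many FP8 values saturate, otherwise let it creep back up so it roughly doubles over one wgrad_auto_scaling_window of checks. A standalone sketch of that policy (argument names mirror the new wgrad_auto_scaling_* options; the 0x7E comparison assumes E4M3 storage, whose maximum magnitude is encoded as 0b1111110):

    import torch

    def update_wgrad_pre_scale(grad_value, pre_scale, ratio_threshold, window):
        """Sketch of the per-gradient pre_scale policy in the diff above.

        grad_value: uint8 storage of an E4M3 gradient; pre_scale: scalar CUDA tensor.
        """
        # Count elements saturated at the E4M3 maximum magnitude (ignore the sign bit).
        num_saturated = torch.count_nonzero((grad_value & 0x7F) == 0x7E)
        overflow_ratio = num_saturated / grad_value.numel()
        if overflow_ratio > ratio_threshold:
            pre_scale.div_(2.0)                    # back off quickly on overflow
        else:
            pre_scale.mul_(2.0 ** (1.0 / window))  # recover gradually over ~window checks
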
6 changes: 4 additions & 2 deletions msamp/operators/arithmetic/arithmetic.cu
@@ -14,21 +14,23 @@ void add_to_fp8(at::Tensor fp8_tensor,
at::Tensor scale,
at::Tensor scale_inv,
at::Tensor amax,
at::Tensor pre_scale,
const at::Tensor& other,
bool is_e4m3) {
const size_t N = other.numel();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_DTYPE_SWITCH_INPUT(other.scalar_type(), IType,
SELECT_FP8_TYPE(is_e4m3, OType,

constexpr int nvec = 32 / sizeof(IType);

VectorizedAddToFp8KernelLauncher<nvec>(
reinterpret_cast<IType*>(other.data_ptr()),
reinterpret_cast<OType*>(fp8_tensor.data_ptr()),
reinterpret_cast<fp32*>(scale.data_ptr()),
reinterpret_cast<fp32*>(scale_inv.data_ptr()),
reinterpret_cast<fp32*>(amax.data_ptr()),
reinterpret_cast<fp32*>(pre_scale.data_ptr()),
N,
stream
);
4 changes: 3 additions & 1 deletion msamp/operators/arithmetic/arithmetic.py
@@ -32,4 +32,6 @@ def add_to_fp8(fp8_tensor, meta, other):

is_e4m3 = meta.qtype == Dtypes.kfloat8_e4m3

msamp_arithmetic.add_to_fp8(fp8_tensor, meta.scale, meta.scale_inv, meta.amax[0], other, is_e4m3)
msamp_arithmetic.add_to_fp8(
fp8_tensor, meta.scale, meta.scale_inv, meta.amax[0], meta.pre_scale, other, is_e4m3
)
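
From the caller's side the wrapper keeps its signature; only the extension call gains the extra meta.pre_scale argument. A usage sketch (hypothetical shapes and values, assuming MS-AMP's tensor .cast() extension and a package-level Arithmetic export are available):

    import torch
    from msamp.common.dtype import Dtypes
    from msamp.operators.arithmetic import Arithmetic

    x = torch.randn((4, 4), device='cuda', dtype=torch.float16)
    grad = torch.randn((4, 4), device='cuda', dtype=torch.float16)

    t = x.cast(Dtypes.kfloat8_e4m3)      # ScalingTensor with its own ScalingMeta
    t.meta.pre_scale.fill_(0.5)          # leave extra headroom for the accumulation
    Arithmetic.add_to_fp8(t.value, t.meta, grad)   # fused in-place: t += grad, then requantize
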
26 changes: 15 additions & 11 deletions msamp/operators/arithmetic/vectorized_pointwise.h
@@ -191,6 +191,7 @@ __global__ void add_to_fp8_kernel(InputType *input,
ComputeType *scale,
ComputeType *scale_inv,
ComputeType *amax,
ComputeType *pre_scale,
const size_t N,
const size_t num_aligned_elements) {
if (threadIdx.x == 0 && blockIdx.x == 0) {
@@ -262,12 +263,14 @@ __global__ void add_to_fp8_kernel(InputType *input,
ComputeType exp = floorf(log2f(fp_max/(amax_value)));
ComputeType sf = roundf(powf(2, fabsf(exp)));

sf *= *pre_scale;

if (amax_value <= 0 || !isfinite(amax_value)) {
sf = *scale;
}

if (exp < 0) {
sf = 1 / sf;
sf = 1.0f / sf;
}

// using new scaling factor to quantize the input
@@ -280,9 +283,9 @@ __global__ void add_to_fp8_kernel(InputType *input,
for (int i = 0; i < nvec; ++i) {
const InputType val1 = static_cast<InputType>(input_storer.separate()[i]);
const ComputeType val2 = static_cast<ComputeType>(output_storer.separate()[i]);

InputType temp1 = static_cast<InputType>(val2 * s);

if constexpr (is_half<InputType>::value) {
temp1 = static_cast<ComputeType>(__hadd(temp1, val1));
} else {
@@ -296,7 +299,7 @@ __global__ void add_to_fp8_kernel(InputType *input,

if (threadIdx.x == 0 && blockIdx.x == 0) {
*scale = sf;
*scale_inv = 1.0 / sf;
*scale_inv = 1.0f / sf;
}
}

@@ -363,6 +366,7 @@ void VectorizedAddToFp8KernelLauncher(InputType *input,
fp32 *scale,
fp32 *scale_inv,
fp32 *amax,
fp32 *pre_scale,
const size_t N,
cudaStream_t stream) {
if (N != 0) {
@@ -373,26 +377,26 @@ void VectorizedAddToFp8KernelLauncher(InputType *input,
constexpr size_t threads = unary_kernel_threads;
size_t num_blocks = DIVUP(num_aligned_elements, threads);

// We use DeviceSyncer to sync the amax value between blocks, the block number should be less than
// (SMCount*MaxThreadsPerSM)/unary_kernel_threads, which is 132*2048/512 = 528 on H100 SXM. We set
// max_blocks to half of 528 to make sure it works on other H100 GPUs.
// constexpr size_t max_blocks = 65535;
constexpr size_t max_blocks = 264;
num_blocks = std::min(num_blocks, max_blocks);

switch (align) {
case Alignment::SAME_ALIGNED:
add_to_fp8_kernel<nvec, true, fp32><<<num_blocks, threads, 0, stream>>>(
input, output, scale, scale_inv, amax, N, num_aligned_elements);
input, output, scale, scale_inv, amax, pre_scale, N, num_aligned_elements);
break;
case Alignment::SAME_UNALIGNED:
add_to_fp8_kernel<nvec, false, fp32><<<num_blocks, threads, 0, stream>>>(
input, output, scale, scale_inv, amax, N, num_aligned_elements);
input, output, scale, scale_inv, amax, pre_scale, N, num_aligned_elements);
break;
case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize
add_to_fp8_kernel<1, true, fp32><<<num_blocks, threads, 0, stream>>>(
input, output, scale, scale_inv, amax, N, num_aligned_elements);
input, output, scale, scale_inv, amax, pre_scale, N, num_aligned_elements);
break;
}
}
Expand All @@ -401,4 +405,4 @@ void VectorizedAddToFp8KernelLauncher(InputType *input,

} // namespace msamp

#endif  // MSAMP_VECTORIZED_POINTWISE_H
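
In PyTorch terms, the fused add_to_fp8 path above appears to compute roughly the following (a simplified reference, ignoring vectorization, the block-level amax reduction, and DeviceSyncer; fp8_value stands in as a float tensor and clamping stands in for the actual FP8 cast):

    import torch

    def add_to_fp8_reference(fp8_value, scale, scale_inv, amax, pre_scale, other,
                             fp_max, margin=0):
        # Dequantize the current FP8 contents and add the incoming tensor.
        total = fp8_value.float() * scale_inv + other.float()
        # Recompute amax and a new scaling factor, folding in pre_scale.
        amax.copy_(total.abs().max())
        exp = torch.floor(torch.log2(fp_max / amax)) - margin
        sf = torch.round(torch.pow(2, torch.abs(exp))) * pre_scale
        sf = torch.where((amax > 0) & torch.isfinite(amax), sf, scale)
        sf = torch.where(exp < 0, 1 / sf, sf)
        scale.copy_(sf)
        scale_inv.copy_(1.0 / sf)
        # Requantize in place with the new scale.
        fp8_value.copy_((total * scale).clamp(-fp_max, fp_max))
        return fp8_value
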
15 changes: 15 additions & 0 deletions tests/common/tensor/test_meta.py
@@ -59,3 +59,18 @@ def test_disable_in_time_scaling(self):
meta = ScalingMeta(Dtypes.kfloat8_e4m3)
self.assertFalse(meta.is_in_time_scaling())
ScalingMeta.in_time_scaling = bak

def test_pre_scale(self):
"""Test pre_scale in ScalingMeta."""
x = torch.randn((4, 4), device='cuda')
meta = ScalingMeta(Dtypes.kfloat8_e4m3)
qtype = Dtypes.kfloat8_e4m3
q1 = x.cast(qtype, meta)

r = 0.5
meta2 = ScalingMeta(Dtypes.kfloat8_e4m3)
meta2.pre_scale.fill_(r)
q2 = x.cast(qtype, meta2)
self.assertTrue(torch.allclose(q1.float(), q2.float(), atol=5e-4))
self.assertTrue(torch.allclose(q1.meta.scale * r, q2.meta.scale))
self.assertTrue(torch.allclose(q1.meta.scale_inv / r, q2.meta.scale_inv))
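
The assertions encode the intended invariant: pre_scale multiplies the computed scale and divides scale_inv, while the quantized payload is unchanged up to FP8 rounding. Informally, with meta2.pre_scale set to r and everything else equal:

    # meta2.scale     ~= r * meta1.scale
    # meta2.scale_inv ~= meta1.scale_inv / r
    # so dequantization (value * scale_inv) recovers approximately the same tensor.
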
25 changes: 21 additions & 4 deletions tests/operators/test_arithmetic.py
@@ -16,10 +16,11 @@
class ArithmeticTestCase(unittest.TestCase):
"""A class for Arithmetic test cases."""
def _check_scaling_tensor(self, scaling_tensor1, scaling_tensor2):
self.assertTrue(torch.all(torch.eq(scaling_tensor1.value, scaling_tensor2.value)))
self.assertTrue(torch.all(torch.eq(scaling_tensor1.meta.scale, scaling_tensor2.meta.scale)))
self.assertTrue(torch.all(torch.eq(scaling_tensor1.meta.scale_inv, scaling_tensor2.meta.scale_inv)))
self.assertTrue(torch.all(torch.eq(scaling_tensor1.meta.amax, scaling_tensor2.meta.amax)))
atol = 1e-6
self.assertTrue(torch.allclose(scaling_tensor1.value, scaling_tensor2.value, atol=3))
self.assertTrue(torch.allclose(scaling_tensor1.meta.scale, scaling_tensor2.meta.scale, atol=atol))
self.assertTrue(torch.allclose(scaling_tensor1.meta.scale_inv, scaling_tensor2.meta.scale_inv, atol=atol))
self.assertTrue(torch.allclose(scaling_tensor1.meta.amax, scaling_tensor2.meta.amax, atol=atol))

@decorator.cuda_test
def test_add_to_fp8(self):
@@ -31,10 +32,26 @@ def test_add_to_fp8(self):
for i, j, dtype, qtype, in itertools.product(sizes, sizes, dtypes, qtypes):
size = (i, j)
input1 = torch.rand(size, dtype=dtype, device='cuda')

# w/o pre_scale
scaling_tensor1 = input1.cast(qtype)
scaling_tensor2 = input1.cast(qtype)

for i in range(10):
input2 = torch.rand(size, dtype=dtype, device='cuda')
meta = scaling_tensor1.meta
Arithmetic.add_to_fp8(scaling_tensor1.value, meta, input2)
scaling_tensor2.copy_((scaling_tensor2.to(dtype) + input2).cast(qtype, meta=scaling_tensor2.meta))
self._check_scaling_tensor(scaling_tensor1, scaling_tensor2)

# w/ pre_scale
scaling_tensor1 = input1.cast(qtype)
scaling_tensor2 = input1.cast(qtype)

for i in range(10):
pre_scale = torch.rand(1).item()
scaling_tensor1.meta.pre_scale.fill_(pre_scale)
scaling_tensor2.meta.pre_scale.fill_(pre_scale)
input2 = torch.rand(size, dtype=dtype, device='cuda')
meta = scaling_tensor1.meta
Arithmetic.add_to_fp8(scaling_tensor1.value, meta, input2)