diff --git a/msamp/common/tensor/meta.py b/msamp/common/tensor/meta.py
index 37ff6958..cc41884d 100644
--- a/msamp/common/tensor/meta.py
+++ b/msamp/common/tensor/meta.py
@@ -4,6 +4,7 @@
 """MS-AMP ScalingMeta."""
 
 import copy
+from typing import Optional
 
 import torch
 from msamp.common.dtype import Floating, Dtypes
@@ -13,7 +14,7 @@ class ScalingMeta:
     """The meta data for scaling tensor."""
     in_time_scaling: bool = True
 
-    def __init__(self, qtype, scale=None, scale_inv=None, amax=None, window_size=1, group=None):
+    def __init__(self, qtype, scale=None, scale_inv=None, amax=None, window_size=1, pre_scale=None, group=None):
         """Constructor.
 
         Args:
@@ -22,11 +23,13 @@ def __init__(self, qtype, scale=None, scale_inv=None, amax=None, window_size=1,
             scale_inv (torch.Tensor, optional): The reciprocal of scaling tensor, defaults to None.
             amax (torch.Tensor, optional): Absolute maximum tensor, defaults to None.
             window_size (int, optional): Window size, defaults to 1.
+            pre_scale (torch.Tensor, optional): A pre-scale factor applied to the computed scaling factor, defaults to None.
             group (torch.distributed.ProcessGroup, optional): Distributed group, defaults to None.
         """
         self.qtype = qtype
         self.scale = scale if scale is not None else torch.ones((), device='cuda')
         self.scale_inv = scale_inv if scale_inv is not None else torch.ones((), device='cuda')
+        self.pre_scale = pre_scale if pre_scale is not None else torch.ones((), device='cuda')
         self.amax = amax if amax is not None else torch.zeros((window_size, ), device='cuda')
         self.amax_counter = torch.zeros((), dtype=torch.int32)
         self.window_size = window_size
@@ -36,7 +39,7 @@ def __init__(self, qtype, scale=None, scale_inv=None, amax=None, window_size=1,
 
     @staticmethod
     @torch.jit.script
-    def compute_scaling_factor(amax, scale, fp_max: float, margin: int):
+    def compute_scaling_factor(amax, scale, fp_max: float, margin: int, pre_scale: Optional[torch.Tensor] = None):
         """A function to compute scaling factor.
 
         Args:
@@ -44,12 +47,15 @@ def compute_scaling_factor(amax, scale, fp_max: float, margin: int):
             scale (torch.Tensor): Scale tensor.
             fp_max (float): The maximum value of float point.
             margin (int): Margin value.
+            pre_scale (torch.Tensor, optional): A pre-scale factor multiplied into the new scaling factor, defaults to None.
 
         Returns:
            return new scaling tensor.
""" exp = torch.floor(torch.log2(fp_max / amax)) - margin sf = torch.round(torch.pow(2, torch.abs(exp))) + if pre_scale is not None: + sf.mul_(pre_scale) sf = torch.where(amax > 0.0, sf, scale) sf = torch.where(torch.isfinite(amax), sf, scale) sf = torch.where(exp < 0, 1 / sf, sf) @@ -108,7 +114,7 @@ def reset_scaling_factor(self, qtype=None): self.scale.fill_(1) else: fp_max = Floating.qfp_max[qtype] - sf = ScalingMeta.compute_scaling_factor(self.amax[0], self.scale, fp_max, 0) + sf = ScalingMeta.compute_scaling_factor(self.amax[0], self.scale, fp_max, 0, pre_scale=self.pre_scale) self.scale.copy_(sf) def copy_(self, src): @@ -122,6 +128,7 @@ def copy_(self, src): self.scale_inv.copy_(src.scale_inv) self.amax.copy_(src.amax) self.amax_counter.copy_(src.amax_counter) + self.pre_scale.copy_(src.pre_scale) self.window_size = src.window_size def clone(self): @@ -156,4 +163,5 @@ def __repr__(self): """ return f'ScalingMeta(qtype={self.qtype}, '\ f'scale={self.scale.data:g}, scale_inv={self.scale_inv.data:g}, '\ + f'pre_scale={self.pre_scale.data:g}, '\ f'amax={self.amax.max():g}, window_size={self.window_size})' diff --git a/msamp/megatron/distributed.py b/msamp/megatron/distributed.py index a2079f99..dd08357d 100644 --- a/msamp/megatron/distributed.py +++ b/msamp/megatron/distributed.py @@ -141,7 +141,7 @@ def _get_buffer_type(param): start = pi * max_fp8_mems for p in fp8_partitions[pi]: meta = ScalingMeta(self.wgrad_qtype, scale=scales[t], scale_inv=scale_invs[t], amax=amaxs[t]) - meta.pre_scale = pre_scale + meta.pre_scale.fill_(pre_scale) t += 1 p.main_grad = ScalingTensor(self._grad_buffers[self.wgrad_dtype].get(p.shape, start), meta) self._grad_buffer_param_index_map[self.wgrad_dtype][p] = (start, start + p.numel()) diff --git a/msamp/megatron/optimizer/distrib_optimizer.py b/msamp/megatron/optimizer/distrib_optimizer.py index 3b8f8148..be14e0ec 100644 --- a/msamp/megatron/optimizer/distrib_optimizer.py +++ b/msamp/megatron/optimizer/distrib_optimizer.py @@ -538,6 +538,47 @@ def reduce_model_grads(self, args, timers): # noqa: C901 timers('grads-reduce-scatter').stop() + if args.wgrad_auto_scaling: + # Weight Gradient Auto Scaling + if args.curr_iteration % args.wgrad_auto_scaling_freq == 0: + timers('wgrad-auto-scaling', log_level=1).start(barrier=args.barrier_with_L1_time) + + # update pre_scale in this partition + for model_group in self.model_fp8_groups: + for p in model_group: + g = p.main_grad + if g is not None and not torch.is_tensor(g): + if g.qtype != Dtypes.kfloat8_e4m3: + raise TypeError('g.qtype != Dtypes.kfloat8_e4m3: {}'.format(g.qtype)) + # stat overflow ratio + num_infs = torch.count_nonzero((g.value & 0x7f) == 126) + overflow_ratio = num_infs / g.numel() + if overflow_ratio > args.wgrad_auto_scaling_ratio: + g.meta.pre_scale.div_(2.0) + else: + g.meta.pre_scale.mul_(2.0**(1.0 / args.wgrad_auto_scaling_window)) + + # synchonize pre_scale in all partitions + for model_id, model in enumerate(self.models): + # all fp8 gradients + partitions = self.model_gbuf_ranges[model_id][torch.uint8]['partitions'] + fp8_grads = [[p.main_grad for p in part.keys()] for part in partitions] + # pre_scales in the partition `data_parallel_rank` + pre_scales = [g.meta.pre_scale for g in fp8_grads[data_parallel_rank]] + max_elems_per_rank = max(model._grad_buffer_num_params) + pre_scales = torch.stack(pre_scales) + # padding to max_elems_per_rank + pad = max_elems_per_rank - pre_scales.numel() + pre_scales = F.pad(pre_scales, (0, pad)) + output_pre_scales = 
+                    output_pre_scales = pre_scales.new_empty((data_parallel_world_size, max_elems_per_rank))
+                    torch.distributed._all_gather_base(output_pre_scales, pre_scales, group=data_parallel_group)
+                    # assign pre_scale to all fp8 gradients
+                    for grads, pre_scales in zip(fp8_grads, output_pre_scales):
+                        for g, pre_scale in zip(grads, pre_scales):
+                            g.meta.pre_scale.copy_(pre_scale)
+
+                timers('wgrad-auto-scaling').stop()
+
     def gather_model_params(self, args, timers):    # noqa: C901
         """All-gather updated model params.
diff --git a/msamp/operators/arithmetic/arithmetic.cu b/msamp/operators/arithmetic/arithmetic.cu
index ba07ab18..0ab87532 100644
--- a/msamp/operators/arithmetic/arithmetic.cu
+++ b/msamp/operators/arithmetic/arithmetic.cu
@@ -14,21 +14,23 @@ void add_to_fp8(at::Tensor fp8_tensor,
                 at::Tensor scale,
                 at::Tensor scale_inv,
                 at::Tensor amax,
+                at::Tensor pre_scale,
                 const at::Tensor& other,
                 bool is_e4m3) {
     const size_t N = other.numel();
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
     TORCH_DTYPE_SWITCH_INPUT(other.scalar_type(), IType,
         SELECT_FP8_TYPE(is_e4m3, OType,
-
+
            constexpr int nvec = 32 / sizeof(IType);
-
+
            VectorizedAddToFp8KernelLauncher<nvec>(
                reinterpret_cast<IType*>(other.data_ptr()),
                reinterpret_cast<OType*>(fp8_tensor.data_ptr()),
                reinterpret_cast<fp32*>(scale.data_ptr()),
                reinterpret_cast<fp32*>(scale_inv.data_ptr()),
                reinterpret_cast<fp32*>(amax.data_ptr()),
+               reinterpret_cast<fp32*>(pre_scale.data_ptr()),
                N,
                stream
            );
diff --git a/msamp/operators/arithmetic/arithmetic.py b/msamp/operators/arithmetic/arithmetic.py
index 1c5ed748..00d5bb65 100644
--- a/msamp/operators/arithmetic/arithmetic.py
+++ b/msamp/operators/arithmetic/arithmetic.py
@@ -32,4 +32,6 @@ def add_to_fp8(fp8_tensor, meta, other):
         is_e4m3 = meta.qtype == Dtypes.kfloat8_e4m3
 
-        msamp_arithmetic.add_to_fp8(fp8_tensor, meta.scale, meta.scale_inv, meta.amax[0], other, is_e4m3)
+        msamp_arithmetic.add_to_fp8(
+            fp8_tensor, meta.scale, meta.scale_inv, meta.amax[0], meta.pre_scale, other, is_e4m3
+        )
diff --git a/msamp/operators/arithmetic/vectorized_pointwise.h b/msamp/operators/arithmetic/vectorized_pointwise.h
index bd765637..852c2dd6 100644
--- a/msamp/operators/arithmetic/vectorized_pointwise.h
+++ b/msamp/operators/arithmetic/vectorized_pointwise.h
@@ -191,6 +191,7 @@ __global__ void add_to_fp8_kernel(InputType *input,
                                   ComputeType *scale,
                                   ComputeType *scale_inv,
                                   ComputeType *amax,
+                                  ComputeType *pre_scale,
                                   const size_t N,
                                   const size_t num_aligned_elements) {
   if (threadIdx.x == 0 && blockIdx.x == 0) {
@@ -262,12 +263,14 @@
   ComputeType exp = floorf(log2f(fp_max/(amax_value)));
   ComputeType sf = roundf(powf(2, fabsf(exp)));
 
+  sf *= *pre_scale;
+
   if (amax_value <= 0 || !isfinite(amax_value)) {
     sf = *scale;
   }
 
   if (exp < 0) {
-    sf = 1 / sf;
+    sf = 1.0f / sf;
   }
 
   // using new scaling factor to quantize the input
@@ -280,9 +283,9 @@
   for (int i = 0; i < nvec; ++i) {
     const InputType val1 = static_cast<InputType>(input_storer.separate()[i]);
     const ComputeType val2 = static_cast<ComputeType>(output_storer.separate()[i]);
-
+
     InputType temp1 = static_cast<InputType>(val2 * s);
-
+
     if constexpr (is_half<InputType>::value) {
       temp1 = static_cast<InputType>(__hadd(temp1, val1));
     } else {
@@ -296,7 +299,7 @@
 
   if (threadIdx.x == 0 && blockIdx.x == 0) {
     *scale = sf;
-    *scale_inv = 1.0 / sf;
+    *scale_inv = 1.0f / sf;
   }
 }
@@ -363,6 +366,7 @@ void VectorizedAddToFp8KernelLauncher(InputType *input,
                                       fp32 *scale,
                                       fp32 *scale_inv,
                                       fp32 *amax,
+                                      fp32 *pre_scale,
                                       const size_t N,
                                       cudaStream_t stream) {
   if (N != 0) {
@@ -373,9 +377,9 @@ void VectorizedAddToFp8KernelLauncher(InputType *input,
     constexpr size_t threads = unary_kernel_threads;
     size_t num_blocks = DIVUP(num_aligned_elements, threads);
 
-    // We use DeviceSyncer to sync the amax value between blocks, the block number should be less than
-    // (SMCount*MaxThreadsPerSM)/unary_kernel_threads, which is 132*2048/512 = 528 on H100 SXM. We set
-    // max_blocks to half of 528 to make sure it works on other H100 GPUs.
+    // We use DeviceSyncer to sync the amax value between blocks, the block number should be less than
+    // (SMCount*MaxThreadsPerSM)/unary_kernel_threads, which is 132*2048/512 = 528 on H100 SXM. We set
+    // max_blocks to half of 528 to make sure it works on other H100 GPUs.
     // constexpr size_t max_blocks = 65535;
     constexpr size_t max_blocks = 264;
     num_blocks = std::min(num_blocks, max_blocks);
@@ -383,16 +387,16 @@ void VectorizedAddToFp8KernelLauncher(InputType *input,
     switch (align) {
       case Alignment::SAME_ALIGNED:
         add_to_fp8_kernel<nvec, true, fp32><<<num_blocks, threads, 0, stream>>>(
-          input, output, scale, scale_inv, amax, N, num_aligned_elements);
+          input, output, scale, scale_inv, amax, pre_scale, N, num_aligned_elements);
         break;
       case Alignment::SAME_UNALIGNED:
         add_to_fp8_kernel<nvec, false, fp32><<<num_blocks, threads, 0, stream>>>(
-          input, output, scale, scale_inv, amax, N, num_aligned_elements);
+          input, output, scale, scale_inv, amax, pre_scale, N, num_aligned_elements);
        break;
      case Alignment::DIFFERENT: {
        // If the pointers are aligned differently we cannot vectorize
        add_to_fp8_kernel<1, true, fp32><<<num_blocks, threads, 0, stream>>>(
-          input, output, scale, scale_inv, amax, N, num_aligned_elements);
+          input, output, scale, scale_inv, amax, pre_scale, N, num_aligned_elements);
        break;
      }
    }
@@ -401,4 +405,4 @@ void VectorizedAddToFp8KernelLauncher(InputType *input,
 
 }  // namespace msamp
 
-#endif  // MSAMP_VECTORIZED_POINTWISE_H
\ No newline at end of file
+#endif  // MSAMP_VECTORIZED_POINTWISE_H
diff --git a/tests/common/tensor/test_meta.py b/tests/common/tensor/test_meta.py
index 09e3cd8b..a64944d3 100644
--- a/tests/common/tensor/test_meta.py
+++ b/tests/common/tensor/test_meta.py
@@ -59,3 +59,18 @@ def test_disable_in_time_scaling(self):
         meta = ScalingMeta(Dtypes.kfloat8_e4m3)
         self.assertFalse(meta.is_in_time_scaling())
         ScalingMeta.in_time_scaling = bak
+
+    def test_pre_scale(self):
+        """Test pre_scale in ScalingMeta."""
+        x = torch.randn((4, 4), device='cuda')
+        meta = ScalingMeta(Dtypes.kfloat8_e4m3)
+        qtype = Dtypes.kfloat8_e4m3
+        q1 = x.cast(qtype, meta)
+
+        r = 0.5
+        meta2 = ScalingMeta(Dtypes.kfloat8_e4m3)
+        meta2.pre_scale.fill_(r)
+        q2 = x.cast(qtype, meta2)
+        self.assertTrue(torch.allclose(q1.float(), q2.float(), atol=5e-4))
+        self.assertTrue(torch.allclose(q1.meta.scale * r, q2.meta.scale))
+        self.assertTrue(torch.allclose(q1.meta.scale_inv / r, q2.meta.scale_inv))
diff --git a/tests/operators/test_arithmetic.py b/tests/operators/test_arithmetic.py
index 23386991..e3fa0775 100644
--- a/tests/operators/test_arithmetic.py
+++ b/tests/operators/test_arithmetic.py
@@ -16,10 +16,11 @@ class ArithmeticTestCase(unittest.TestCase):
     """A class for Arithmetic test cases."""
 
     def _check_scaling_tensor(self, scaling_tensor1, scaling_tensor2):
-        self.assertTrue(torch.all(torch.eq(scaling_tensor1.value, scaling_tensor2.value)))
-        self.assertTrue(torch.all(torch.eq(scaling_tensor1.meta.scale, scaling_tensor2.meta.scale)))
-        self.assertTrue(torch.all(torch.eq(scaling_tensor1.meta.scale_inv, scaling_tensor2.meta.scale_inv)))
-        self.assertTrue(torch.all(torch.eq(scaling_tensor1.meta.amax, scaling_tensor2.meta.amax)))
+        atol = 1e-6
+        self.assertTrue(torch.allclose(scaling_tensor1.value, scaling_tensor2.value, atol=3))
+        self.assertTrue(torch.allclose(scaling_tensor1.meta.scale, scaling_tensor2.meta.scale, atol=atol))
+        self.assertTrue(torch.allclose(scaling_tensor1.meta.scale_inv, scaling_tensor2.meta.scale_inv, atol=atol))
+        self.assertTrue(torch.allclose(scaling_tensor1.meta.amax, scaling_tensor2.meta.amax, atol=atol))
 
     @decorator.cuda_test
     def test_add_to_fp8(self):
@@ -31,10 +32,26 @@ def test_add_to_fp8(self):
         for i, j, dtype, qtype, in itertools.product(sizes, sizes, dtypes, qtypes):
             size = (i, j)
             input1 = torch.rand(size, dtype=dtype, device='cuda')
+
+            # w/o pre_scale
+            scaling_tensor1 = input1.cast(qtype)
+            scaling_tensor2 = input1.cast(qtype)
+
+            for i in range(10):
+                input2 = torch.rand(size, dtype=dtype, device='cuda')
+                meta = scaling_tensor1.meta
+                Arithmetic.add_to_fp8(scaling_tensor1.value, meta, input2)
+                scaling_tensor2.copy_((scaling_tensor2.to(dtype) + input2).cast(qtype, meta=scaling_tensor2.meta))
+                self._check_scaling_tensor(scaling_tensor1, scaling_tensor2)
+
+            # w/ pre_scale
             scaling_tensor1 = input1.cast(qtype)
             scaling_tensor2 = input1.cast(qtype)
 
             for i in range(10):
+                pre_scale = torch.rand(1).item()
+                scaling_tensor1.meta.pre_scale.fill_(pre_scale)
+                scaling_tensor2.meta.pre_scale.fill_(pre_scale)
                 input2 = torch.rand(size, dtype=dtype, device='cuda')
                 meta = scaling_tensor1.meta
                 Arithmetic.add_to_fp8(scaling_tensor1.value, meta, input2)
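
Note for reviewers: the snippet below is a minimal, standalone sketch (not part of the patch) of the scaling-factor math that `compute_scaling_factor` implements after this change, with a worked example showing how `pre_scale` shrinks the resulting factor. The tensor values and the E4M3 `fp_max = 448.0` are chosen purely for illustration.

```python
import torch


def compute_scaling_factor_sketch(amax, scale, fp_max, margin, pre_scale=None):
    """Illustrative mirror of the patched ScalingMeta.compute_scaling_factor."""
    exp = torch.floor(torch.log2(fp_max / amax)) - margin
    sf = torch.round(torch.pow(2, torch.abs(exp)))
    if pre_scale is not None:
        sf = sf * pre_scale                            # fold the pre-scale into the new factor
    sf = torch.where(amax > 0.0, sf, scale)            # keep the old scale when amax is zero
    sf = torch.where(torch.isfinite(amax), sf, scale)  # ... or non-finite
    sf = torch.where(exp < 0, 1 / sf, sf)              # negative exponent: scale down instead
    return sf


amax, scale = torch.tensor(3.5), torch.ones(())
print(compute_scaling_factor_sketch(amax, scale, 448.0, 0))                     # tensor(128.)
print(compute_scaling_factor_sketch(amax, scale, 448.0, 0, torch.tensor(0.5)))  # tensor(64.)
```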
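Likewise, a rough sketch of the wgrad auto-scaling policy added to `reduce_model_grads`: a gradient whose FP8 values saturate too often gets its `pre_scale` halved, otherwise `pre_scale` recovers by 2x over a window of updates. The helper name and the default threshold/window values here are placeholders, not part of the patch; the real code reads them from `args.wgrad_auto_scaling_ratio` and `args.wgrad_auto_scaling_window`.

```python
import torch


def update_pre_scale(fp8_value, pre_scale, ratio_threshold=0.001, window=100):
    """Hypothetical helper mirroring the pre_scale update policy in the patch.

    fp8_value is the raw uint8 storage of an E4M3 gradient; 0x7e / 0xfe is the
    largest-magnitude finite code, so hitting it is treated as saturation.
    """
    num_saturated = torch.count_nonzero((fp8_value & 0x7f) == 126)
    overflow_ratio = num_saturated / fp8_value.numel()
    if overflow_ratio > ratio_threshold:
        pre_scale.div_(2.0)                  # too many saturated values: back off quickly
    else:
        pre_scale.mul_(2.0**(1.0 / window))  # otherwise recover gradually over `window` steps
    return pre_scale


value = torch.randint(0, 256, (1024, ), dtype=torch.uint8)
pre_scale = torch.ones(())
print(update_pre_scale(value, pre_scale))
```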
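Finally, a short usage sketch of the new knob from Python, mirroring the API exercised by the tests above. It assumes a CUDA build of MS-AMP with this patch applied; the import paths are inferred from the package layout in the diff and may need adjusting, and `import msamp` is assumed to register the `torch.Tensor.cast` extension.

```python
import torch
import msamp  # noqa: F401  (assumed to register torch.Tensor.cast)
from msamp.common.dtype import Dtypes
from msamp.operators.arithmetic import Arithmetic

x = torch.randn((4, 4), device='cuda', dtype=torch.float16)
t = x.cast(Dtypes.kfloat8_e4m3)      # ScalingTensor with a fresh ScalingMeta
t.meta.pre_scale.fill_(0.5)          # halve every scaling factor computed from now on
grad = torch.randn((4, 4), device='cuda', dtype=torch.float16)
Arithmetic.add_to_fp8(t.value, t.meta, grad)  # in-place fp8 accumulation honoring pre_scale
```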