diff --git a/generate.py b/generate.py
index b7a4c113..980fd2c9 100644
--- a/generate.py
+++ b/generate.py
@@ -216,6 +216,14 @@ def encode_tokens(tokenizer, string, bos=True, device=default_device):
         tokens = [tokenizer.bos_id()] + tokens
     return torch.tensor(tokens, dtype=torch.int, device=device)
 
+def _convert_weight(model):
+    from quantize import WeightOnlyInt4Linear
+    for fqn, mod in model.named_modules():
+        if isinstance(mod, WeightOnlyInt4Linear):
+            weight = mod.weight.data
+            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight, mod.inner_k_tiles)
+            mod.weight = weight_int4pack
+
 def _load_model(checkpoint_path, device, precision, use_tp):
     use_cuda = 'cuda' in device
     with torch.device('meta'):
@@ -240,12 +248,15 @@ def _load_model(checkpoint_path, device, precision, use_tp):
         checkpoint = checkpoint["model"]
     model.load_state_dict(checkpoint, assign=True)
 
+    model = model.to(device=device, dtype=precision)
+    # int4 packed weight needs to be converted after model loading to the specific device
+    if "int4" in str(checkpoint_path):
+        _convert_weight(model)
+
     if use_tp:
         from tp import apply_tp
         print("Applying tensor parallel to model ...")
         apply_tp(model)
-
-    model = model.to(device=device, dtype=precision)
     return model.eval()
 
 def _get_model_size(model):
diff --git a/mixtral-moe/generate.py b/mixtral-moe/generate.py
index 9aa076b6..1807379f 100644
--- a/mixtral-moe/generate.py
+++ b/mixtral-moe/generate.py
@@ -12,6 +12,7 @@
 import torch
 import torch._dynamo.config
 import torch._inductor.config
+torch._inductor.config.cpp.enable_kernel_profile = True
 
 def device_sync(device):
     if "cuda" in device:
@@ -132,7 +133,7 @@ def encode_tokens(tokenizer, string, bos=True, device='cuda'):
     tokens = tokenizer.encode(string)
     if bos:
         tokens = [tokenizer.bos_id()] + tokens
-    return torch.tensor(tokens, dtype=torch.int, device=device)
+    return torch.tensor(tokens, dtype=torch.int, device=args.device)
 
 def _load_model(checkpoint_path, device, precision, use_tp):
     with torch.device('meta'):
@@ -248,8 +249,13 @@ def callback(x):
         if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
             prof = contextlib.nullcontext()
         else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
+            if device == 'cuda':
+                torch.profiler._utils._init_for_cuda_graphs()
+                prof = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], use_cuda=True)
+                profile_sort = 'self_cuda_time_total'
+            elif device == 'cpu':
+                prof = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU])
+                profile_sort = 'self_cpu_time_total'
         with prof:
             y = generate(
                 model,
@@ -263,6 +269,8 @@ def callback(x):
         if i == -1:
             print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
             continue
+        if hasattr(prof, "key_averages"):
+            print(prof.key_averages().table(sort_by=profile_sort, row_limit=-1))
         if hasattr(prof, "export_chrome_trace"):
             if use_tp:
                 prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
diff --git a/mixtral-moe/quantize.py b/mixtral-moe/quantize.py
index 6312863c..f4857907 100644
--- a/mixtral-moe/quantize.py
+++ b/mixtral-moe/quantize.py
@@ -98,6 +98,20 @@ def convert_for_runtime(self):
         return self.mod
 
 
+# TODO: This is a workaround to speedup int8 woq performance. Will remove this when
+# https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
+def linear_forward_int8(x, weight_int8pack, scales, out_features):
+    if x.is_cuda:
+        return F.linear(x, weight_int8pack.to(dtype=x.dtype)) * scales
+
+    origin_x_size = x.size()
+    x = x.reshape(-1, origin_x_size[-1])
+    c = torch.ops.aten._weight_int8pack_mm(x, weight_int8pack, scales)
+    new_shape = origin_x_size[:-1] + (out_features,)
+    c = c.reshape(new_shape)
+    return c
+
+
 class WeightOnlyBit8Linear(torch.nn.Module):
     __constants__ = ['in_features', 'out_features']
     in_features: int
@@ -115,7 +129,12 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
         self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+        # return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+        # TODO: This is a workaround to speedup int8 woq performance. Will remove this when
+        # https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
+        return linear_forward_int8(
+            input,
+            self.weight, self.scales, self.out_features)
 
 
 class ConditionalFeedForwardBit8(nn.Module):
diff --git a/quantize.py b/quantize.py
index fb566421..4de61b8d 100644
--- a/quantize.py
+++ b/quantize.py
@@ -124,8 +124,8 @@ def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128)
         .to(torch.int32)
         .reshape_as(w)
     )
-
-    return w_int32
+    w_uint8 = (w_int32[::,::2] << 4 | w_int32[::,1::2]).to(torch.uint8)
+    return w_uint8
 
 
 def group_quantize_tensor(w, n_bit=4, groupsize=128):
@@ -335,6 +335,18 @@ def convert_for_runtime(self):
         replace_linear_weight_only_int8_per_channel(self.mod)
         return self.mod
 
+# TODO: This is a workaround to speedup int8 woq performance. Will remove this when
+# https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
+def linear_forward_int8(x, weight_int8pack, scales, out_features):
+    if x.is_cuda:
+        return F.linear(x, weight_int8pack.to(dtype=x.dtype)) * scales
+
+    origin_x_size = x.size()
+    x = x.reshape(-1, origin_x_size[-1])
+    c = torch.ops.aten._weight_int8pack_mm(x, weight_int8pack, scales)
+    new_shape = origin_x_size[:-1] + (out_features,)
+    c = c.reshape(new_shape)
+    return c
 
 class WeightOnlyInt8Linear(torch.nn.Module):
     __constants__ = ['in_features', 'out_features']
@@ -352,15 +364,19 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
         self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+        # return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+        # TODO: This is a workaround to speedup int8 woq performance. Will remove this when
+        # https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
+        return linear_forward_int8(
+            input,
+            self.weight, self.scales, self.out_features)
 
 ##### weight only int4 per channel groupwise quantized code ######
 
 def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
-    weight_int32, scales_and_zeros = group_quantize_tensor(
+    weight_int4pack, scales_and_zeros = group_quantize_tensor(
         weight_bf16, n_bit=4, groupsize=groupsize
     )
-    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
     return weight_int4pack, scales_and_zeros
 
 
@@ -404,7 +420,7 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
 
     @torch.no_grad()
     def create_quantized_state_dict(self, use_cuda = True):
-        if use_cuda:
+        if use_cuda and torch.cuda.is_available():
             device="cuda"
         else:
             device="cpu"
@@ -507,7 +523,7 @@ def __init__(
         assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
         self.register_buffer(
             "weight",
-            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
+            torch.empty((out_features, in_features // 2), dtype=torch.uint8)
        )
         self.register_buffer(
             "scales_and_zeros",
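Note on the int4 packing layout in the quantize.py hunk above (illustration, not part of the patch): group_quantize_tensor_from_qparams now packs each pair of adjacent 4-bit values along the last dimension into a single uint8, with the even column in the high nibble and the odd column in the low nibble, which is what the new (out_features, in_features // 2) uint8 weight buffer stores; the aten._convert_weight_to_int4pack call is then deferred until after the model has been moved to the target device. The sketch below is a minimal standalone example of that packing; the unpack_int4_from_uint8 helper is hypothetical and only included so the layout can be verified end to end.

import torch

def pack_int4_to_uint8(w_int32: torch.Tensor) -> torch.Tensor:
    # Same packing as the new group_quantize_tensor_from_qparams return value:
    # even columns -> high nibble, odd columns -> low nibble.
    return (w_int32[:, ::2] << 4 | w_int32[:, 1::2]).to(torch.uint8)

def unpack_int4_from_uint8(w_uint8: torch.Tensor) -> torch.Tensor:
    # Hypothetical inverse, for illustration only (not part of the patch).
    out = torch.empty(w_uint8.shape[0], w_uint8.shape[1] * 2, dtype=torch.int32)
    out[:, ::2] = (w_uint8 >> 4).to(torch.int32)
    out[:, 1::2] = (w_uint8 & 0xF).to(torch.int32)
    return out

# Quantized 4-bit values lie in [0, 15] after group quantization.
w = torch.randint(0, 16, (4, 8), dtype=torch.int32)
packed = pack_int4_to_uint8(w)   # shape (4, 4), dtype torch.uint8
assert torch.equal(unpack_int4_from_uint8(packed), w)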