Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Int8 Woq for CPU #161

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions mixtral-moe/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch
import torch._dynamo.config
import torch._inductor.config
torch._inductor.config.cpp.enable_kernel_profile = True

def device_sync(device):
if "cuda" in device:
Expand Down Expand Up @@ -132,7 +133,7 @@ def encode_tokens(tokenizer, string, bos=True, device='cuda'):
tokens = tokenizer.encode(string)
if bos:
tokens = [tokenizer.bos_id()] + tokens
return torch.tensor(tokens, dtype=torch.int, device=device)
return torch.tensor(tokens, dtype=torch.int, device=args.device)

def _load_model(checkpoint_path, device, precision, use_tp):
with torch.device('meta'):
Expand Down Expand Up @@ -248,8 +249,13 @@ def callback(x):
if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
prof = contextlib.nullcontext()
else:
torch.profiler._utils._init_for_cuda_graphs()
prof = torch.profiler.profile()
if device == 'cuda':
torch.profiler._utils._init_for_cuda_graphs()
prof = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], use_cuda=True)
profile_sort = 'self_cuda_time_total'
elif device == 'cpu':
prof = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU])
profile_sort = 'self_cpu_time_total'
with prof:
y = generate(
model,
Expand All @@ -263,6 +269,8 @@ def callback(x):
if i == -1:
print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
continue
if hasattr(prof, "key_averages"):
print(prof.key_averages().table(sort_by=profile_sort, row_limit=-1))
if hasattr(prof, "export_chrome_trace"):
if use_tp:
prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
Expand Down
21 changes: 20 additions & 1 deletion mixtral-moe/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,20 @@ def convert_for_runtime(self):
return self.mod


# TODO: This is a workaround to speedup int8 woq performance. Will remove this when
# https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
def linear_forward_int8(x, weight_int8pack, scales, out_features):
if x.is_cuda:
return F.linear(x, weight_int8pack.to(dtype=x.dtype)) * scales

origin_x_size = x.size()
x = x.reshape(-1, origin_x_size[-1])
c = torch.ops.aten._weight_int8pack_mm(x, weight_int8pack, scales)
new_shape = origin_x_size[:-1] + (out_features,)
c = c.reshape(new_shape)
return c


class WeightOnlyBit8Linear(torch.nn.Module):
__constants__ = ['in_features', 'out_features']
in_features: int
Expand All @@ -115,7 +129,12 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
# return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
# TODO: This is a workaround to speedup int8 woq performance. Will remove this when
# https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
return linear_forward_int8(
input,
self.weight, self.scales, self.out_features)


class ConditionalFeedForwardBit8(nn.Module):
Expand Down
33 changes: 22 additions & 11 deletions quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,18 @@ def convert_for_runtime(self):
replace_linear_weight_only_int8_per_channel(self.mod)
return self.mod

# TODO: This is a workaround to speedup int8 woq performance. Will remove this when
# https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
def linear_forward_int8(x, weight_int8pack, scales, out_features):
if x.is_cuda:
return F.linear(x, weight_int8pack.to(dtype=x.dtype)) * scales

origin_x_size = x.size()
x = x.reshape(-1, origin_x_size[-1])
c = torch.ops.aten._weight_int8pack_mm(x, weight_int8pack, scales)
new_shape = origin_x_size[:-1] + (out_features,)
c = c.reshape(new_shape)
return c

class WeightOnlyInt8Linear(torch.nn.Module):
__constants__ = ['in_features', 'out_features']
Expand All @@ -354,7 +366,12 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
# return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
# TODO: This is a workaround to speedup int8 woq performance. Will remove this when
# https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
return linear_forward_int8(
input,
self.weight, self.scales, self.out_features)

##### weight only int4 per channel groupwise quantized code ######

Expand Down Expand Up @@ -502,16 +519,10 @@ def __init__(

assert out_features % 8 == 0, "require out_features % 8 == 0"
assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
if use_cuda:
self.register_buffer(
"weight",
torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
)
else:
self.register_buffer(
"weight",
torch.empty((out_features, in_features // 2), dtype=torch.uint8)
)
self.register_buffer(
"weight",
torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
)
self.register_buffer(
"scales_and_zeros",
torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
Expand Down