
Issue: Performance degradation with int8 quantization in multi-batch scenarios #218

kakarotzzz opened this issue Jan 14, 2025 · 2 comments


@kakarotzzz

When using int8 weight-only quantization, there is a significant performance drop in multi-batch inference compared to single-batch inference: single-batch performance is good, but throughput does not scale well as the batch size increases.

import torch
import torch.nn.functional as F


class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # int8 weights plus per-output-channel bf16 scales (weight-only quantization)
        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # dequantize on the fly: cast the int8 weight to the activation dtype, matmul, then rescale
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales

Current Behavior

  1. The explicit .to(dtype=input.dtype) creates a separate type-conversion kernel.
  2. In the single-batch case, Inductor can successfully fuse this conversion with the GEMM.
  3. In the multi-batch case, the fusion fails and we get:
    • one kernel for the int8->fp16 conversion,
    • another kernel for the GEMM computation,
    • which leads to extra memory traffic and lower performance (see the profiling sketch below).
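A quick way to confirm the kernel split is to profile one compiled call at each batch size and inspect the launched CUDA kernels. A minimal sketch (the 4096x4096 shape is a placeholder, and it times only the quantized linear in isolation, not the full model):

import torch
from torch.profiler import profile, ProfilerActivity

lin = WeightOnlyInt8Linear(4096, 4096).cuda()  # placeholder shape
compiled = torch.compile(lin)

for bs in (1, 2):
    x = torch.randn(bs, 4096, dtype=torch.bfloat16, device="cuda")
    compiled(x)  # warm-up: trigger compilation for this shape
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        compiled(x)
    print(f"batch_size={bs}")
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
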
@cpuhrsch
Contributor

cpuhrsch commented Jan 15, 2025

@kakarotzzz You might be able to fuse this using torch._inductor.config.use_mixed_mm = True depending on the PyTorch version you are using. On that note, which version of PyTorch are you using?
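For reference, a minimal sketch of applying that flag (placeholder shape; the flag has to be set before torch.compile traces the model, and mode="max-autotune" is optional):

import torch

torch._inductor.config.use_mixed_mm = True  # ask Inductor to try a fused mixed-dtype matmul

lin = WeightOnlyInt8Linear(4096, 4096).cuda()  # placeholder shape
compiled = torch.compile(lin, mode="max-autotune")

x = torch.randn(2, 4096, dtype=torch.bfloat16, device="cuda")
out = compiled(x)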

@kakarotzzz
Author

kakarotzzz commented Jan 15, 2025

I'm using PyTorch 2.5.0 and have enabled these optimization configurations:

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True 
torch._inductor.config.use_mixed_mm = True

Adding use_mixed_mm = True didn't bring any performance improvement, and performance still degrades significantly even with batch_size = 2, testing on an RTX 3090.
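A rough way to reproduce that comparison (placeholder 4096x4096 shape; this times only the compiled quantized linear rather than the full model) is CUDA-event timing at a few batch sizes:

import torch

def bench_ms(fn, x, iters=100):
    # average milliseconds per call for an already-warmed-up callable
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

lin = torch.compile(WeightOnlyInt8Linear(4096, 4096).cuda())  # placeholder shape
for bs in (1, 2, 4):
    x = torch.randn(bs, 4096, dtype=torch.bfloat16, device="cuda")
    lin(x)  # warm up / compile for this batch size
    print(f"batch_size={bs}: {bench_ms(lin, x):.3f} ms/iter")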
