
Candle won't use half-gemm from cublas when doing fp16 matmul #2139

Closed

lucasavila00 opened this issue Apr 28, 2024 · 8 comments

@lucasavila00 (Contributor) commented Apr 28, 2024

This relates to #2136

Related to improving mistral.rs prompt processing speed EricLBuehler/mistral.rs#153

Why does candle use the turing_fp16_s1688gemm_fp16_256x128_ldg8_f2f_tn kernel for F16 matmuls?

Llama.cpp uses turing_h1688gemm_256x128_ldg8_tn for the same tensor.

[profiler screenshots comparing the kernels used by candle and llama.cpp]

If I understand the docs correctly (https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-gemm), h-gemm stands for half gemm, whereas s-gemm stands for standard F32 gemm.

So, is it possible that candle is not using the best kernel for some reason?

Is it possible that the candle version is doing the matmuls in F32, as the name would suggest, and is therefore slower than the other kernel?
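
For context, this is roughly the kind of F16 matmul being profiled (a minimal sketch with illustrative shapes, not the actual model code):

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::new_cuda(0)?;
    // Two F16 matrices on the GPU; matmul dispatches to cuBLAS under the hood.
    let a = Tensor::randn(0f32, 1.0, (4096, 4096), &dev)?.to_dtype(DType::F16)?;
    let b = Tensor::randn(0f32, 1.0, (4096, 4096), &dev)?.to_dtype(DType::F16)?;
    let c = a.matmul(&b)?;
    println!("{:?}", c.shape());
    Ok(())
}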

Our benchmarks are:

Llama.cpp: ~1500 t/s
mistral.rs: 1000 t/s

The major contributors are the kernels mentioned above, and the proportion of time spent in each kernel closely matches the observed slowdown. More info here: EricLBuehler/mistral.rs#153 (comment)

@LaurentMazare (Collaborator) commented

We're actually using this function, which calls the generic gemm variant with sys::cublasComputeType_t::CUBLAS_COMPUTE_32F a few lines below. So it is an f16 kernel (as per the actual kernel name), but the accumulation is done in f32, as detailed here. Using f16 accumulation would indeed be faster but with lower precision, so it's a bit unclear to me what the impact would end up being and whether it would be good enough for most use cases.
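
For anyone skimming: "accumulation" here refers to the precision of the running sum inside the gemm. Conceptually (a toy sketch of the idea, not the actual cuBLAS implementation):

// Inputs and outputs are f16, but the dot-product sum is carried in f32
// and only cast back down at the end.
fn dot_f16_with_f32_accumulation(a: &[half::f16], b: &[half::f16]) -> half::f16 {
    let mut acc = 0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        acc += x.to_f32() * y.to_f32();
    }
    half::f16::from_f32(acc)
}

With CUBLAS_COMPUTE_16F the running sum itself is kept in f16, which is faster but can lose precision on long reductions.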

@lucasavila00 (Contributor, Author) commented

Ah, I see, thanks.

We're currently trying to match llama.cpp speed using quantized models, so the loss of precision shouldn't matter for us.

But I can see how it matters for a regular F16 model...

@EricLBuehler (Member) commented

Perhaps I could add this to my fork so we can try it out, and then we can merge it if we find an elegant solution?

@lucasavila00 (Contributor, Author) commented

@EricLBuehler I'd be glad to benchmark it, profile it, etc. if you implement it.

@lucasavila00 (Contributor, Author) commented Apr 29, 2024

I forked candle locally and hacked in a call to the following function at https://github.com/huggingface/candle/blob/main/candle-core/src/cuda_backend/mod.rs#L1654

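// Same strided-batched f16 gemm call candle already makes, but with f16
// accumulation (CUBLAS_COMPUTE_16F) instead of the default f32 accumulation,
// matching what llama.cpp does.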
unsafe fn gemm_strided_batched<
    A: cudarc::driver::DevicePtr<half::f16>,
    B: cudarc::driver::DevicePtr<half::f16>,
    C: cudarc::driver::DevicePtrMut<half::f16>,
>(
    handle: sys::cublasHandle_t,
    cfg: StridedBatchedConfig<half::f16>,
    a: &A,
    b: &B,
    c: &mut C,
) -> std::result::Result<(), CublasError> {
    let alpha = cfg.gemm.alpha;
    let beta = cfg.gemm.beta;
    result::gemm_strided_batched_ex(
        handle,
        cfg.gemm.transa,
        cfg.gemm.transb,
        cfg.gemm.m,
        cfg.gemm.n,
        cfg.gemm.k,
        (&alpha) as *const half::f16 as *const _,
        *a.device_ptr() as *const _,
        sys::cudaDataType_t::CUDA_R_16F,
        cfg.gemm.lda,
        cfg.stride_a,
        *b.device_ptr() as *const _,
        sys::cudaDataType_t::CUDA_R_16F,
        cfg.gemm.ldb,
        cfg.stride_b,
        (&beta) as *const half::f16 as *const _,
        *c.device_ptr_mut() as *mut _,
        sys::cudaDataType_t::CUDA_R_16F,
        cfg.gemm.ldc,
        cfg.stride_c,
        cfg.batch_size,
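        // f16 accumulation; candle's default path uses CUBLAS_COMPUTE_32F here.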
        sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
        sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
    )
}

This matches the llama.cpp config, and the profiler now shows the same kernels being used. Llama.cpp is at the bottom.

[profiler screenshot: kernel trace after the change, with llama.cpp at the bottom]

I perceived no difference in output quality.

Using these settings improved throughput by 15%, getting mistral.rs to ~1150 t/s.

@LaurentMazare (Collaborator) commented

Great that it works well with the reduced precision. I've looked a bit at the PyTorch codebase and it seems that they use f32 accumulation by default. PyTorch provides an option to disable "reduced precision" here (which is turned on by default), but this only impacts the truncation setting in SetMathMode. See this issue: pytorch/pytorch#123157.

So to get around this, I've pushed #2141, which provides a toggle to flip between reduced-precision accumulation and f32 accumulation (f32 remains the default). It's a global flag, so not ideal, but it at least provides a way to test the reduced-precision accumulation. The quantized example has been adapted to use it and indeed benefits from the speedup when using the f16 matmul for prompt processing. Would that work for your use case?
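
Roughly, the intended usage looks like this (a sketch; the exact function name and module path are whatever #2141 ends up exposing, so double-check against the PR):

fn main() -> candle_core::Result<()> {
    // Global flag: opt into f16 (reduced-precision) accumulation for f16 matmuls.
    // The default stays at f32 accumulation. Name/path assumed here, see #2141.
    candle_core::cuda::set_gemm_reduced_precision_f16(true);

    // ... load the quantized model and run prompt processing as usual ...
    Ok(())
}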

When it comes to changing the default, it might be better to wait a bit and see what happens on the PyTorch side. If models are trained with f32 accumulation, it's a bit unclear to me what the impact will be if one runs inference with less precise accumulation.

@lucasavila00 (Contributor, Author) commented

I'm also not sure about making it the default.

The approach of #2141 fits our use case.

Thanks a lot!
