
[CUDA] New 4bit GEMM kernels for inference #1949

Open · matthewdouglas wants to merge 7 commits into main from gemm4bit

[CUDA] New 4bit GEMM kernels for inference#1949
matthewdouglas wants to merge 7 commits into
mainfrom
gemm4bit

Conversation

@matthewdouglas (Member) commented on May 15, 2026

TL;DR

4-bit inference is up to 4x faster at typical serving batch sizes (2-64) across GPUs from Turing through Blackwell. Gains are largest at smaller batches, and in many cases the new kernels are faster at batch size 1 as well. We also see improvements when using nested quantization, where the fusion in the new kernels provides an additional advantage.

Summary

This PR adds new custom fused 4-bit dequantize+GEMM kernels. They are intended to replace both the existing batch-size-1 path (the GEMV kernel) and the dequantize + F.linear path for small-to-medium batch sizes. The new kernels additionally fuse the nested absmax dequantization and the optional bias addition.
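Conceptually, the unfused path being replaced looks roughly like the sketch below. This is a simplified illustration only: the function name, the codebook handling, and the nested-quantization details are assumptions for clarity, not the library internals.

```python
import torch
import torch.nn.functional as F

# A minimal, simplified sketch (not the library internals) of the unfused path
# that the new kernels collapse into a single launch: nested absmax dequant ->
# 4-bit weight dequant -> GEMM -> optional bias.
def unfused_linear_nf4(x, packed_w, code, absmax_q, absmax_code, absmax_scale,
                       out_features, in_features, blocksize=64, bias=None):
    # Step 1: nested quantization -- the per-block absmax values are themselves
    # stored in a quantized 8-bit form; dequantize them first.
    absmax = absmax_code[absmax_q.long()] * absmax_scale

    # Step 2: unpack two 4-bit values per byte, map them through the 16-entry
    # NF4 codebook, and rescale each block of `blocksize` weights by its absmax.
    nibbles = torch.stack(((packed_w >> 4) & 0xF, packed_w & 0xF), dim=-1).view(-1)
    w = code[nibbles.long()].view(-1, blocksize) * absmax.view(-1, 1)
    w = w.view(out_features, in_features).to(x.dtype)

    # Step 3: plain GEMM plus optional bias (this is the F.linear fallback).
    return F.linear(x, w, bias)
```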

New kernels

There are three new kernel tiers, selected automatically at runtime via heuristics calibrated with benchmarks on T4, A10, A100, RTX 4090, L4, L40S, H100, H200, B200, and RTX PRO 6000. The heuristics are expected to generalize well across most hardware on sm75+. Note that while the SIMT kernel does compile and should work on Pascal and Volta, we did not perform benchmarks there and will dispatch to it more conservatively.

Claude Code was used to create a simulation tool to help calibrate these heuristics. We lean toward dispatching the new kernels conservatively to avoid performance regressions, so across many shapes we may miss some wins by falling back earlier than strictly necessary. Across all GPU testing, spanning over 40 shapes and batch sizes from 1 through 2048, we make a regressive decision less than 1% of the time and the optimal decision over 80% of the time. For the remaining decisions we simply leave performance on the table, either by falling back too early or by selecting a kernel when a faster alternative was available.

| Kernel | Target | dtypes |
| --- | --- | --- |
| SIMT | sm60+ (all GPUs) | bf16, fp16, fp32 |
| MMA m16n8k8 | sm75 (Turing, e.g. T4) | fp16 only |
| MMA m16n8k16 | sm80+ (Ampere, Ada, Hopper, Blackwell) | bf16, fp16 |

The dispatch system considers the problem dimensions, GPU architecture, and SM count when deciding which kernel to launch. Within each of the MMA kernels there are several tiling configurations for the dispatcher to choose from. For larger batch sizes we automatically fall back to the existing dequantize + F.linear path. A rough sketch of this decision shape is shown below.
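The sketch below is illustrative only: none of the thresholds, tile names, or cutoffs are the real calibrated values; they stand in for the heuristics described above.

```python
# Illustrative-only sketch of the dispatch logic. All thresholds and tile
# choices here are hypothetical placeholders, not the calibrated heuristics.
def select_kernel(batch: int, out_features: int, in_features: int,
                  cc: tuple, sm_count: int, dtype: str) -> str:
    LARGE_BATCH_FALLBACK = 256  # hypothetical cutoff

    if batch >= LARGE_BATCH_FALLBACK:
        return "dequantize+F.linear"  # existing path for large batches

    if cc >= (8, 0) and dtype in ("bf16", "fp16"):
        # Ampere/Ada/Hopper/Blackwell: m16n8k16 MMA tier; the tile shape would
        # depend on the problem size and how well it saturates the SMs.
        tile = "large" if batch * out_features > 64 * sm_count else "small"
        return f"mma_m16n8k16_{tile}"

    if cc == (7, 5) and dtype == "fp16":
        return "mma_m16n8k8"

    # Older architectures or fp32 activations: dispatch the SIMT kernel, and
    # do so more conservatively on hardware that was not benchmarked.
    return "simt"
```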

New custom op layer

Added a new custom operator, bitsandbytes::gemm_4bit, to abstract the new kernels. It is backwards compatible with other backends, i.e. they will continue to launch their existing GEMV implementations. Note that this operator is not intended as a public API and may change in the future. It serves as an extension point for adding implementations on additional hardware.
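As a rough illustration of the extension-point pattern via torch.library: the namespace ("bnb_example"), argument names, and schema below are placeholders, not the actual bitsandbytes::gemm_4bit definition.

```python
import torch
import torch.nn.functional as F

# Hedged sketch of wiring an extension point through torch.library. All names
# and the schema are illustrative assumptions, not the real operator.

def _dequant_4bit_reference(packed_w, absmax, code, out_features, in_features, blocksize=64):
    # Unpack two nibbles per byte, map through the codebook, rescale per block.
    nib = torch.stack(((packed_w >> 4) & 0xF, packed_w & 0xF), dim=-1).view(-1)
    w = code[nib.long()].view(-1, blocksize) * absmax.view(-1, 1)
    return w.view(out_features, in_features)

@torch.library.custom_op("bnb_example::gemm_4bit", mutates_args=())
def gemm_4bit(x: torch.Tensor, packed_w: torch.Tensor, absmax: torch.Tensor,
              code: torch.Tensor, out_features: int, in_features: int) -> torch.Tensor:
    # Default implementation: dequantize, then fall back to F.linear.
    w = _dequant_4bit_reference(packed_w, absmax, code, out_features, in_features)
    return F.linear(x, w.to(x.dtype))

@gemm_4bit.register_fake
def _(x, packed_w, absmax, code, out_features, in_features):
    # Shape/dtype propagation for meta tensors and torch.compile.
    return x.new_empty(*x.shape[:-1], out_features)

# A backend-specific fused kernel would then be registered for its device:
# @gemm_4bit.register_kernel("cuda")
# def _(x, packed_w, absmax, code, out_features, in_features): ...
```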

Additional cleanup

bitsandbytes.matmul_4bit now normalizes packed weights to a canonical shape internally, so callers no longer need to pass the quantized weight tensor in any particular orientation. For weights quantized in the [out_features, in_features] orientation, passing a .t() of the quantized weight tensor remains fine (although it is no longer required) and produces correct results. However, weights that were quantized in the transposed [in_features, out_features] orientation will now emit a DeprecationWarning; support for this is likely to be dropped in the future, as it is not an expected or typical use case.
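From the caller's perspective, a small usage sketch (assuming the existing matmul_4bit signature of (A, B, quant_state, out=None, bias=None); the shapes and model here are placeholders):

```python
import torch
import bitsandbytes as bnb
from bitsandbytes.functional import quantize_4bit

# Quantize a weight stored in the usual [out_features, in_features] orientation.
w = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
x = torch.randn(2, 4096, device="cuda", dtype=torch.float16)
qweight, qstate = quantize_4bit(w, quant_type="nf4")

# Both calls now produce the same result; passing the transposed packed tensor
# still works, but is no longer required.
out_a = bnb.matmul_4bit(x, qweight.t(), qstate)
out_b = bnb.matmul_4bit(x, qweight, qstate)
```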

End-to-End Benchmark Results

All runs: NF4, input_len=128, output_len=128, bitsandbytes 0.49.2 (stable) vs this PR (new).
PyTorch: 2.10.0+cu130 in eager mode. Transformers: 5.7.0.
All hardware except RTX 4090 was hosted on Modal.com.

TPOT = time per output token (decode step), lower is better.
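For reference, one common way to compute this metric is sketched below; this is an assumption about the measurement, since the PR does not show the exact benchmark harness used for the charts.

```python
def tpot_ms(total_latency_ms: float, ttft_ms: float, output_tokens: int) -> float:
    # Time per output token over the decode steps only, i.e. excluding the
    # prefill / time-to-first-token. One common definition, not taken from the PR.
    return (total_latency_ms - ttft_ms) / max(output_tokens - 1, 1)
```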

(TPOT charts: tpot_4090_qwen3_8b, tpot_4090_qwen36_27b, tpot_a100_qwen36_27b, tpot_rtxpro6000_qwen36_27b, tpot_t4_qwen35_9b)

@matthewdouglas added this to the v0.50.0 milestone on May 15, 2026
@matthewdouglas added the CUDA label (Issues and PRs related to the CUDA backend, excluding installation/support help) on May 15, 2026
@github-actions

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
