vllm.model_executor.layers.quantization.online.turboquant
TurboQuant online weight quantization for vLLM.
3- and 4-bit weight compression via WHT rotation + Lloyd-Max codebook. Load any BF16 checkpoint, compress its weights at startup, and serve with roughly 4x less GPU memory for the weights. No calibration data is needed.
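The "roughly 4x" figure can be sanity-checked with byte arithmetic. This is a minimal sketch under assumptions not stated on this page: the input dimension is zero-padded to a whole number of groups, codes are packed densely at `bits` bits per weight, and each (output row, group) pair stores one 2-byte norm. The helper names are hypothetical, and only the weight matrix is counted (KV cache and activations are unaffected).

```python
def bf16_weight_bytes(out_features: int, in_features: int) -> int:
    # Dense bf16 weight: 2 bytes per element.
    return out_features * in_features * 2


def tq_weight_bytes(out_features: int, in_features: int, bits: int,
                    group_size: int = 128, norm_bytes: int = 2) -> int:
    """Approximate packed size of one weight matrix under the assumed layout."""
    # Assumption: pad the input dimension up to a whole number of groups.
    n_groups = -(-in_features // group_size)
    padded_in = n_groups * group_size
    # Packed codes: `bits` bits per weight, rounded up to whole bytes per row.
    code_bytes = out_features * (-(-padded_in * bits // 8))
    # Assumption: one fp16/bf16 norm per (output row, group).
    norm_total = out_features * n_groups * norm_bytes
    return code_bytes + norm_total


if __name__ == "__main__":
    dense = bf16_weight_bytes(4096, 4096)
    for bits in (4, 3):
        print(f"TQ{bits}: {dense / tq_weight_bytes(4096, 4096, bits):.2f}x smaller")
```

Under these assumptions a 4096 x 4096 layer shrinks about 3.88x at 4 bits and 5.12x at 3 bits; the per-group norms are why TQ4 lands slightly under 4x.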
Algorithm: scalar case of HIGGS (Malinovskii et al., NAACL 2025, aclanthology.org/2025.naacl-long.543; preprint arXiv:2411.17525) — Random Hadamard Transform + MSE-optimal Lloyd-Max grid + per-group normalization. The implementation was originally based on TurboQuant (Zandieh et al., ICLR 2026, arXiv:2504.19874), which targets online KV-cache and ANN vector search; engineering simplifications (scalar over vector, WHT over general random rotations) converged the weight path onto the HIGGS scalar case. The turboquant name is kept for API and plugin-package compatibility.
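To make "MSE-optimal Lloyd-Max grid" concrete: for a standard normal source, Lloyd's fixed-point iteration alternates between placing decision boundaries at midpoints of neighbouring levels and moving each level to the conditional mean of its cell. The sketch below is pure Python and only illustrates the technique; the function name, initial grid, and iteration count are arbitrary choices, and the module may instead ship precomputed levels. The N(0, 1) target assumes that normalized, rotated group coordinates are rescaled to roughly unit variance before quantization.

```python
import math


def _pdf(x: float) -> float:
    # Standard normal density; evaluates to 0.0 at +/- inf.
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)


def _cdf(x: float) -> float:
    # Standard normal CDF via the error function.
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def lloyd_max_gaussian(bits: int, iters: int = 500) -> list[float]:
    """MSE-optimal scalar levels for N(0, 1) via Lloyd's fixed-point iteration."""
    levels = 2 ** bits
    # Start from an evenly spaced grid over [-3, 3].
    c = [-3.0 + 6.0 * (i + 0.5) / levels for i in range(levels)]
    for _ in range(iters):
        # Nearest-neighbour cells: boundaries are midpoints between levels.
        b = [-math.inf] + [(c[i] + c[i + 1]) / 2 for i in range(levels - 1)] + [math.inf]
        new_c = []
        for lo, hi in zip(b[:-1], b[1:]):
            mass = _cdf(hi) - _cdf(lo)
            # Conditional mean of N(0, 1) on [lo, hi] is (pdf(lo) - pdf(hi)) / mass.
            new_c.append((_pdf(lo) - _pdf(hi)) / mass)
        c = new_c
    return c
```

For 1 bit this gives the textbook levels +/-sqrt(2/pi), about +/-0.7979; for 2 bits it converges to about +/-0.4528 and +/-1.5104.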
Usage
vllm serve
TurboQuantOnlineLinearMethod
Bases: LinearMethodBase
Online TQ3/TQ4 weight compression for Linear layers.
Allocates the bf16 weight on the meta device, so no GPU memory is used at init. Once weight loading has materialized the bf16 weight on GPU, it is compressed to the TQ packed format. The forward pass uses Triton dequant-GEMM kernels.
Source code in vllm/model_executor/layers/quantization/online/turboquant.py
_PolarQuant
WHT rotation + Gaussian Lloyd-Max codebook quantizer.
dequantize
Dequantize grouped indices. indices: (n_groups, group_size). Returns (n_groups, group_size).
quantize
Quantize grouped vectors. x: (n_groups, group_size). Returns (indices, norms).
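A quantize/dequantize round trip of the kind `_PolarQuant` describes can be sketched in NumPy. Everything here is a hypothetical stand-in, not the vLLM class: `ToyPolarQuant`, `hadamard`, and `gaussian_codebook` are illustrative names; the codebook is fitted empirically (1-D k-means on Gaussian samples); each group is assumed to be normalized to unit L2 norm and rescaled by sqrt(group_size) after rotation; and `dequantize` takes `norms` explicitly. The forward rotation is the transpose of the documented inverse `W_rot`.

```python
import numpy as np


def hadamard(n: int) -> np.ndarray:
    """Sylvester-ordered Hadamard matrix; n must be a power of two."""
    h = np.ones((1, 1))
    while h.shape[0] < n:
        h = np.block([[h, h], [h, -h]])
    return h


def gaussian_codebook(bits: int, n_samples: int = 100_000, iters: int = 30) -> np.ndarray:
    """Empirical Lloyd-Max levels for N(0, 1): 1-D k-means on samples."""
    x = np.random.default_rng(0).standard_normal(n_samples)
    levels = 2 ** bits
    c = np.quantile(x, (np.arange(levels) + 0.5) / levels)
    for _ in range(iters):
        idx = np.digitize(x, (c[1:] + c[:-1]) / 2)  # nearest-level cells
        c = np.bincount(idx, weights=x, minlength=levels) / np.bincount(idx, minlength=levels)
    return c


class ToyPolarQuant:
    """Hypothetical stand-in for _PolarQuant (not the vLLM implementation)."""

    def __init__(self, group_size: int = 128, bits: int = 4, seed: int = 0):
        rng = np.random.default_rng(seed)
        d1, d2 = rng.choice([-1.0, 1.0], size=(2, group_size))
        # Inverse rotation as documented: W_rot = H @ D2 @ D1 / sqrt(n).
        self.w_rot = hadamard(group_size) * (d2 * d1)[None, :] / np.sqrt(group_size)
        # Assumption: rotated unit-vector coords are ~N(0, 1/n); rescale to ~N(0, 1).
        self.scale = np.sqrt(group_size)
        self.centroids = gaussian_codebook(bits)

    def quantize(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        norms = np.linalg.norm(x, axis=1, keepdims=True)  # per-group L2 norm
        unit = x / np.maximum(norms, 1e-12)
        y = (unit @ self.w_rot) * self.scale  # row form of W_rot^T @ u
        idx = np.abs(y[..., None] - self.centroids).argmin(axis=-1)
        return idx, norms

    def dequantize(self, idx: np.ndarray, norms: np.ndarray) -> np.ndarray:
        y = self.centroids[idx] / self.scale
        return (y @ self.w_rot.T) * norms  # row form of W_rot @ y, then rescale
```

With Gaussian test data the 4-bit relative reconstruction error comes out near 0.1, consistent with the roughly 0.0095 normalized MSE of a 16-level Gaussian Lloyd-Max quantizer.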
_build_rotation_matrix
Pre-compute the inverse rotation matrix W_rot = H @ D2 @ D1 / sqrt(n), where H is the n x n Hadamard matrix and D1, D2 are diagonal random-sign matrices.
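The formula can be evaluated without any explicit diagonal matrices: right-multiplying by a diagonal matrix scales columns, and because D2 @ D1 is itself a diagonal ±1 matrix the result is a column-signed Hadamard matrix, which is orthogonal after the 1/sqrt(n) factor. A minimal NumPy sketch, assuming `signs1`/`signs2` are ±1 vectors (as the launcher arguments suggest) and Sylvester ordering for H; the function names are hypothetical.

```python
import numpy as np


def sylvester_hadamard(n: int) -> np.ndarray:
    assert n > 0 and n & (n - 1) == 0, "n must be a power of two"
    h = np.ones((1, 1))
    while h.shape[0] < n:
        h = np.block([[h, h], [h, -h]])
    return h


def build_rotation_matrix(signs1: np.ndarray, signs2: np.ndarray) -> np.ndarray:
    """W_rot = H @ D2 @ D1 / sqrt(n) with D1 = diag(signs1), D2 = diag(signs2)."""
    n = signs1.shape[0]
    # M @ diag(d) == M * d[None, :], so both sign flips collapse into one column scaling.
    return sylvester_hadamard(n) * (signs2 * signs1)[None, :] / np.sqrt(n)
```

Since W_rot is orthogonal, the forward rotation used at quantize time is simply its transpose.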
_fast_wht_batch
Batched fast Walsh-Hadamard transform (WHT). x: (batch, n), where n is a power of 2.
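A fast WHT replaces the n x n matrix product with log2(n) butterfly stages of n add/subtract operations each. The vectorized NumPy sketch below computes the unnormalized Sylvester-ordered transform along the last axis; any 1/sqrt(n) normalization or sign flips would be applied separately. The function name is illustrative, not the module's.

```python
import numpy as np


def fast_wht_batch(x: np.ndarray) -> np.ndarray:
    """Unnormalized Walsh-Hadamard transform of each row (Sylvester order)."""
    batch, n = x.shape
    assert n > 0 and n & (n - 1) == 0, "n must be a power of two"
    y = x.astype(np.float64, copy=True)
    h = 1
    while h < n:
        # View each row as blocks of size 2h split into two halves of size h.
        y = y.reshape(batch, n // (2 * h), 2, h)
        a, b = y[:, :, 0, :], y[:, :, 1, :]
        # Butterfly: first half gets a + b, second half gets a - b.
        y = np.stack((a + b, a - b), axis=2).reshape(batch, n)
        h *= 2
    return y
```

Applying the transform twice returns n times the input, since H @ H = n * I for a Sylvester Hadamard matrix.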
_get_cached_rotation_matrix
Return the cached rotation matrix, building it on first use.
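One way to get "build once, reuse everywhere" semantics is `functools.lru_cache` keyed on the matrix parameters. This is a hypothetical sketch: the cache key (size and seed here), the seeding scheme, and the NumPy backing are assumptions; a real implementation may key on device, dtype, or the sign vectors instead. Marking the array read-only guards the shared object against accidental in-place edits.

```python
from functools import lru_cache

import numpy as np


@lru_cache(maxsize=None)
def get_cached_rotation_matrix(n: int, seed: int) -> np.ndarray:
    """Hypothetical cache: one W_rot per (n, seed), shared by every caller."""
    rng = np.random.default_rng(seed)
    signs = rng.choice(np.array([-1.0, 1.0]), size=(2, n))
    h = np.ones((1, 1))
    while h.shape[0] < n:
        h = np.block([[h, h], [h, -h]])  # Sylvester Hadamard
    # W_rot = H @ D2 @ D1 / sqrt(n), written as a column scaling.
    w_rot = h * (signs[1] * signs[0])[None, :] / np.sqrt(n)
    w_rot.setflags(write=False)  # cached object is shared: make it read-only
    return w_rot
```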
_pack_indices
Pack quantization indices into uint8.
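The packed byte layout is not documented on this page. A minimal sketch, assuming an LSB-first layout: two 4-bit indices per byte (low nibble first), and eight 3-bit indices (24 bits) per 3 bytes. The real kernels may use a different bit order; `pack_indices` is a hypothetical name.

```python
import numpy as np


def pack_indices(idx: np.ndarray, bits: int) -> np.ndarray:
    """Pack the last axis of `idx` (values < 2**bits) into uint8, LSB first."""
    assert bits in (3, 4)
    idx = idx.astype(np.uint8)
    *lead, k = idx.shape
    if bits == 4:
        assert k % 2 == 0
        pairs = idx.reshape(*lead, k // 2, 2)
        # Low nibble holds the even index, high nibble the odd one.
        return pairs[..., 0] | (pairs[..., 1] << 4)
    # bits == 3: every 8 indices (24 bits) become 3 bytes.
    assert k % 8 == 0
    g = idx.reshape(*lead, k // 8, 8).astype(np.uint32)
    word = np.zeros(g.shape[:-1], dtype=np.uint32)
    for i in range(8):
        word |= g[..., i] << (3 * i)
    out = np.stack([(word >> s) & 0xFF for s in (0, 8, 16)], axis=-1)
    return out.reshape(*lead, k // 8 * 3).astype(np.uint8)
```

Under this layout, the 3-bit indices 0 through 7 pack to the bytes 0x88, 0xC6, 0xFA.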
_padded_size
Return (padded_dim, n_groups) for group quantization.
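A plausible reading of `(padded_dim, n_groups)` is "round the dimension up to a whole number of groups". The sketch below assumes exactly that, plus zero-padding. Zero padding is safe for the GEMM because padded input and weight entries contribute 0 to every dot product, and rotating a zero-padded group with an orthogonal matrix preserves dot products. Both function names are hypothetical.

```python
import numpy as np


def padded_size(dim: int, group_size: int = 128) -> tuple[int, int]:
    """Hypothetical: round `dim` up to a multiple of `group_size`."""
    n_groups = -(-dim // group_size)  # integer ceil division
    return n_groups * group_size, n_groups


def pad_last_dim(x: np.ndarray, group_size: int = 128) -> np.ndarray:
    """Zero-pad the last axis to the padded size."""
    padded, _ = padded_size(x.shape[-1], group_size)
    pad = [(0, 0)] * (x.ndim - 1) + [(0, padded - x.shape[-1])]
    return np.pad(x, pad)  # default mode is constant zeros
```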
_polar_fused_gemm_kernel
_polar_fused_gemm_kernel(
x_rot_ptr,
stride_xm,
stride_xk,
codes_ptr,
stride_cn,
stride_ck,
norms_ptr,
stride_nn,
stride_ng,
ct_ptr,
out_ptr,
stride_om,
stride_on,
bias_ptr,
batch_size,
out_f,
in_f_padded,
n_groups,
BLOCK_K: constexpr,
BITS: constexpr,
BLOCK_M: constexpr,
BLOCK_N: constexpr,
)
FWHT-on-input: codebook dot product with pre-rotated input.
Note: 3-bit unpacking logic is duplicated in _tq_fused_gemm_kernel (Triton JIT kernels cannot share helper functions).
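The FWHT-on-input trick rests on one identity: if a dequantized weight group is w_g = norm_g * W_rot @ c_g (with c_g the codebook values selected by the codes), then x_g · w_g = norm_g * (W_rot^T x_g) · c_g. The input can therefore be rotated once and dotted directly with codebook values, with no weight-side rotation. The NumPy reference below checks that identity. The names are hypothetical, and any codebook rescaling is assumed to be folded into `table` (the kernel's `ct_ptr`).

```python
import numpy as np


def polar_gemm_reference(x_rot, codes, norms, table, bias=None):
    """out[m, n] = sum_g norms[n, g] * <x_rot[m, g, :], table[codes[n, g, :]]>."""
    vals = table[codes]  # (N, n_groups, G) codebook values
    out = np.einsum("mgk,ngk,ng->mn", x_rot, vals, norms)
    return out if bias is None else out + bias


def dense_reference(x, codes, norms, table, w_rot, bias=None):
    """Same result via an explicitly dequantized weight w_g = norm * W_rot @ values."""
    n, n_groups, g = codes.shape
    w = (table[codes] @ w_rot.T) * norms[..., None]  # rotate back, then rescale
    out = x @ w.reshape(n, n_groups * g).T
    return out if bias is None else out + bias
```

Usage: rotate each input group once with `x.reshape(M, n_groups, G) @ w_rot` (the row form of W_rot^T x_g), then pass the result as `x_rot`.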
_rotate_input
Apply forward rotation to input, grouped by group_size.
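Grouped rotation amounts to a reshape to (rows, n_groups, group_size), a batched matrix product, and a reshape back. The sketch assumes the forward rotation is the transpose of the documented inverse `W_rot` (which holds for any orthogonal matrix) and that groups are contiguous along the input dimension. `rotate_input` is a hypothetical name.

```python
import numpy as np


def rotate_input(x: np.ndarray, w_rot: np.ndarray) -> np.ndarray:
    """Apply W_rot^T to every contiguous group of x's last axis."""
    m, k = x.shape
    g = w_rot.shape[0]
    assert k % g == 0, "pad x to a multiple of group_size first"
    # Row-vector form: (W_rot^T @ x_g)^T == x_g^T @ W_rot.
    return (x.reshape(m, k // g, g) @ w_rot).reshape(m, k)
```

Because the rotation is orthogonal, each group keeps its L2 norm, and applying `rotate_input` again with `w_rot.T` undoes it.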
_tq_fused_gemm_kernel
_tq_fused_gemm_kernel(
a_ptr,
stride_am,
stride_ak,
packed_ptr,
norms_ptr,
stride_packed_n,
stride_packed_k,
stride_norms_n,
stride_norms_g,
w_rot_ptr,
centroids_ptr,
c_ptr,
stride_cm,
stride_cn,
bias_ptr,
M,
N,
K,
n_groups,
GROUP_SIZE: constexpr,
BITS: constexpr,
BLOCK_M: constexpr,
BLOCK_N: constexpr,
)
Fused TQ dequant-GEMM: unpack + codebook + rotate + scale + accumulate.
Note: 3-bit unpacking logic is duplicated in _polar_fused_gemm_kernel (Triton JIT kernels cannot share helper functions).
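The fused kernel's pipeline (codebook lookup, rotate, scale, accumulate) can be mirrored by a plain NumPy loop over K in group-sized tiles. This reference assumes the indices are already unpacked, that a dequantized group is norm * W_rot @ values, and that the loop order matches the kernel only conceptually (a real Triton kernel also tiles over M and N with BLOCK_M/BLOCK_N). The function name is hypothetical.

```python
import numpy as np


def tq_fused_gemm_reference(a, codes, norms, w_rot, centroids, bias=None):
    """Tile over K by group: codebook -> rotate -> scale -> accumulate."""
    m, k = a.shape
    n, n_groups, g = codes.shape
    assert k == n_groups * g
    acc = np.zeros((m, n))
    for gi in range(n_groups):
        vals = centroids[codes[:, gi, :]]                  # codebook lookup, (N, G)
        w_tile = (vals @ w_rot.T) * norms[:, gi:gi + 1]    # inverse rotation + per-group scale
        acc += a[:, gi * g:(gi + 1) * g] @ w_tile.T        # partial GEMM for this K tile
    return acc if bias is None else acc + bias
```

The result equals a dense GEMM against the fully dequantized weight, and also the rotated-input formulation used by the FWHT-on-input path.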
_tq_fused_gemm_launcher
_tq_fused_gemm_launcher(
x: Tensor,
packed_weight: Tensor,
norms: Tensor,
signs1: Tensor,
signs2: Tensor,
centroids: Tensor,
group_size: int = 128,
bits: int = 4,
bias: Tensor | None = None,
) -> Tensor
Fused TQ dequant + GEMM launcher.
_tq_fwht_input_gemm_launcher
_tq_fwht_input_gemm_launcher(
x: Tensor,
packed_weight: Tensor,
norms: Tensor,
signs1: Tensor,
signs2: Tensor,
centroids: Tensor,
group_size: int = 128,
bits: int = 4,
bias: Tensor | None = None,
) -> Tensor
FWHT-on-input GEMM launcher. Rotates input once, then codebook dot.
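Rough operation counts suggest why rotating the input once can pay off. This is back-of-the-envelope arithmetic under stated assumptions, not a measured or documented rationale. It assumes the weight-side path applies a G x G rotation to every weight group once per M tile (here M = 1, so one tile), and that the input-side FWHT costs log2(G) butterfly stages of G add/subs per input group. The function name is hypothetical.

```python
def rotation_costs(m: int, n: int, k: int, group_size: int = 128) -> dict[str, int]:
    """Rough operation counts for one (m x k) @ (k x n) GEMM, under stated assumptions."""
    n_groups = k // group_size
    log2_g = group_size.bit_length() - 1  # group_size is a power of two
    return {
        # The GEMM itself: one multiply-add per (row, output, input) triple.
        "gemm_macs": m * n * k,
        # Weight side: a G x G matrix-vector product per weight group (one M tile).
        "weight_rotation_macs": n * n_groups * group_size * group_size,
        # Input side: FWHT butterflies, G add/subs per stage per input group.
        "input_fwht_addsubs": m * n_groups * group_size * log2_g,
    }
```

For a 4096 x 4096 layer at batch size 1 with group_size 128, this gives 16,777,216 GEMM multiply-adds, 2,147,483,648 for rotating every weight group, and only 28,672 add/subs for the input FWHT.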
_unpack_indices
Unpack uint8 packed indices back to int64.
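Unpacking reverses whatever layout `_pack_indices` uses. The sketch below inverts the assumed LSB-first layout: low nibble first for 4 bits, and one little-endian 24-bit word per 3 bytes for 3 bits. It returns int64, as the docstring states. The layout and the name `unpack_indices` are assumptions.

```python
import numpy as np


def unpack_indices(packed: np.ndarray, bits: int) -> np.ndarray:
    """Inverse of an assumed LSB-first uint8 packing; returns int64 indices."""
    assert bits in (3, 4)
    *lead, nbytes = packed.shape
    p = packed.astype(np.int64)
    if bits == 4:
        # Low nibble first, then high nibble.
        out = np.stack((p & 0xF, p >> 4), axis=-1)
        return out.reshape(*lead, nbytes * 2)
    assert nbytes % 3 == 0
    t = p.reshape(*lead, nbytes // 3, 3)
    # Reassemble each 24-bit little-endian word, then slice out eight 3-bit fields.
    word = t[..., 0] | (t[..., 1] << 8) | (t[..., 2] << 16)
    out = np.stack([(word >> (3 * i)) & 0x7 for i in range(8)], axis=-1)
    return out.reshape(*lead, nbytes // 3 * 8)
```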