
@Abdennacer-Badaoui
Contributor

Description:

Summary

  • Add a kQuantizeBlockwise64 kernel that supports blocksize=64 with 4-bit quantization (FP4/NF4) on both warp32 (RDNA) and warp64 (CDNA) hardware.
  • Previously, blocksize=64 for 4-bit quantization was only supported on consumer RDNA GPUs (warp size 32). Data-center CDNA GPUs (MI300, MI325) could not use it because the existing kernel requires threads == blocksize/2 = 32, which underutilizes the 64-wide wavefront.
  • The new kernel processes 2 quantization blocks of 64 values per thread block using 64 threads, with logical warps of 32 (WarpReduce<float, 32>) performing an independent reduction per quantization block.

Quick comparison

Test configuration:

Device: AMD Instinct MI325X VF
PyTorch: 2.8.0+rocm7.1.0.git7a520360
HIP: 7.1.25424-4179531dcd

FP4 Quantization Error (Mean Absolute Error)

| Shape | Blocksize=128 | Blocksize=64 | Error Reduction |
|---|---|---|---|
| 1K x 1K | 0.102941 | 0.096551 | +6.2% |
| 2K x 2K | 0.102949 | 0.096549 | +6.2% |
| 4K x 4K | 0.102950 | 0.096545 | +6.2% |
| 8K x 4K | 0.102948 | 0.096545 | +6.2% |
| 4K x 11K (LLaMA FFN) | 0.102948 | 0.096545 | +6.2% |
| 4K x 14K (LLaMA2 FFN) | 0.102946 | 0.096545 | +6.2% |

NF4 Quantization Error (Mean Absolute Error)

| Shape | Blocksize=128 | Blocksize=64 | Error Reduction |
|---|---|---|---|
| 1K x 1K | 0.076826 | 0.072796 | +5.2% |
| 2K x 2K | 0.076834 | 0.072794 | +5.3% |
| 4K x 4K | 0.076836 | 0.072794 | +5.3% |
| 8K x 4K | 0.076836 | 0.072796 | +5.3% |
| 4K x 11K (LLaMA FFN) | 0.076835 | 0.072796 | +5.3% |
| 4K x 14K (LLaMA2 FFN) | 0.076835 | 0.072796 | +5.3% |

@github-actions

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
