Integrate DeepSeek Sparse Attention with Tokamax Flash Attention #3087

copybara-service[bot] merged 1 commit into main from
Conversation
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This pull request successfully integrates DeepSeek Sparse Attention with Tokamax Flash Attention, which provides a notable performance improvement. The changes are well-structured, and the logic for handling the dynamic sparse mask within the flash attention kernel appears correct.
🔍 General Feedback
- The unit tests have been significantly improved to cover both the dot_product and flash attention paths, ensuring parity between the two implementations. The parameterization of the tests is a good addition.
- The movement of the index_mask reshaping logic into the apply_attention_dot function is a clean refactoring.
- The updates to the configuration validation and compile tests are thorough and ensure the new attention mechanism is properly supported.
Overall, this is a high-quality contribution that enhances the performance of sparse attention. The suggestions provided are minor and aimed at improving code clarity and test readability.
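As an aside on the mechanism under review: the sketch below shows how per-query top-k indices can be densified into a 3D boolean mask of the kind a sparse flash-attention kernel could consume. The shapes and helper name are illustrative only, not the MaxText/Tokamax API.

```python
import numpy as np

def topk_indices_to_mask(indices: np.ndarray, kv_len: int) -> np.ndarray:
  """Densify per-query top-k indices into a [batch, q_len, kv_len] bool mask."""
  batch, q_len, _ = indices.shape
  mask = np.zeros((batch, q_len, kv_len), dtype=bool)
  b = np.arange(batch)[:, None, None]  # broadcasts over (batch, q_len, top_k)
  q = np.arange(q_len)[None, :, None]
  mask[b, q, indices] = True
  return mask

# batch=1, q_len=2, top_k=2: each query attends to 2 selected keys out of 4.
idx = np.array([[[0, 2], [1, 3]]])
mask = topk_indices_to_mask(idx, kv_len=4)
assert mask.shape == (1, 2, 4)
assert int(mask.sum()) == 4  # exactly top_k entries set per query row
```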
RissyRan left a comment
@gemini-cli /review
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This pull request successfully integrates DeepSeek Sparse Attention with Tokamax Flash Attention, which is a valuable performance enhancement. The implementation is solid, and the added tests are comprehensive, ensuring the correctness of the new attention path.
🔍 General Feedback
- The changes are well-structured and easy to follow.
- The extension of the test suite to cover both dot_product and flash attention with various configurations is a great addition and significantly improves the robustness of the implementation.
- The PR description is clear and provides good context, including performance numbers.
Thank you for the tokamax flash integration! Excited to see throughput improving! Overall looks good.
I observed that both dot_product and flash attention require normalization to pass at longer sequence lengths, e.g., seq=128.
Yes, normalization is needed with larger seq. In the test, the torch params are manually initialized with N(0, 1); without init they are all zeros.
maxtext/tests/unit/deepseek32_vs_reference_test.py
Lines 708 to 715 in 95ef3e1
Because of std=1, the logits are much larger than what we usually have, e.g., 2000 (log). I also see that normalization is used for the gpt-oss attention tests.
maxtext/tests/unit/gpt_vs_reference_test.py
Line 445 in 95ef3e1
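A quick way to see why N(0, 1) parameters blow up the logits, and why normalizing queries and keys tames them. This is a toy sketch, not the actual test setup:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 128  # head dimension, illustrative

# With N(0, 1) entries, the dot product q . k has std ~ sqrt(d), so raw
# logits are far larger than with properly scaled weights.
q = rng.standard_normal(d)
k = rng.standard_normal(d)
raw_logit = q @ k

# Normalizing q and k bounds the logit: |q_hat . k_hat| <= 1.
qn = q / np.linalg.norm(q)
kn = k / np.linalg.norm(k)
norm_logit = qn @ kn
assert abs(norm_logit) <= 1.0
```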
I performed a manual diff of the final three entries of the last batch and did observe 1-2 outliers with index_topk=4 between the dot_product and flash_attention implementations.
(Optional) Do you think we can add a test that directly compares flash against dot_product, like what we have in attention_test.py?
maxtext/tests/unit/attention_test.py
Lines 478 to 479 in 95ef3e1
either inside deepseek32_vs_reference_test.py (e.g., this can pass) or attention_test.
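Such a parity check could look roughly like this. The helper name and tolerances are hypothetical; the real test would run both attention paths on the same inputs:

```python
import numpy as np

def count_outliers(out_ref, out_test, rtol=1e-2, atol=1e-2):
  """Count elements where the two attention outputs disagree beyond tolerance."""
  close = np.isclose(out_ref, out_test, rtol=rtol, atol=atol)
  return int((~close).sum())

# Stand-ins for dot_product vs. flash_attention outputs on the same inputs.
out_dot = np.ones((2, 4, 8))
out_flash = out_dot.copy()
out_flash[0, 0, 0] += 0.5  # inject one synthetic outlier
assert count_outliers(out_dot, out_dot) == 0
assert count_outliers(out_dot, out_flash) == 1
```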
Thanks for giving that a shot, Shuning! If this normalization holds up, I'd actually prefer to compare it directly against the reference implementation. Since the codebase is evolving so fast, there's a risk of regressions in dot_product as well. Comparing against the reference feels like the safest bet; what do you think?
gagika left a comment
Thanks! One comment, which can be a follow-up fix as well.
Description
Integrate DSA with Tokamax Flash Attention
make_dynamic_splash_mha to support 3D dynamic index masking within the Flash Attention

Tests

python3 -m unittest tests.unit.deepseek32_vs_reference_test - logs: link

Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.