
[XPU] decouple split_kv_cache and block_attn #6489

Open
RuohengMa wants to merge 1 commit into PaddlePaddle:develop from RuohengMa:decouple

Conversation

@RuohengMa
Contributor

Motivation

decouple split_kv_cache and block_attn

Modifications

decouple split_kv_cache and block_attn

Usage or Command

decouple split_kv_cache and block_attn

Accuracy Tests

None

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests. If there are no unit tests, explain why in this PR.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot

paddle-bot bot commented Feb 24, 2026

Thanks for your contribution!

@paddle-bot paddle-bot bot added the XPU label Feb 24, 2026
@codecov-commenter

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@60e75ea).

Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #6489   +/-   ##
==========================================
  Coverage           ?   68.75%           
==========================================
  Files              ?      391           
  Lines              ?    52809           
  Branches           ?     8225           
==========================================
  Hits               ?    36308           
  Misses             ?    13856           
  Partials           ?     2645           
Flag Coverage Δ
GPU 68.75% <ø> (?)

Flags with carried forward coverage won't be shown.


Comment on lines +41 to +55
template <typename TC, typename TS>
struct SplitRopeTypeTrait {
  using E_Scale = TS;
  using D_Scale = TS;
};
template <>
struct SplitRopeTypeTrait<bfloat16, bfloat16> {
  using E_Scale = bfloat16;
  using D_Scale = float;
};
template <>
struct SplitRopeTypeTrait<int8_t, bfloat16> {
  using E_Scale = bfloat16;
  using D_Scale = bfloat16;
};


Can this code be deleted now?
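For context, the trait in question maps a (cache dtype, scale dtype) pair to the encoder/decoder scale types, and the specializations can be verified at compile time. A minimal standalone sketch, assuming a stand-in `bfloat16` struct (the real kernel uses the framework's bf16 type):

```cpp
#include <cstdint>
#include <type_traits>

// Stand-in for the framework's bfloat16 type (assumption for illustration only).
struct bfloat16 {};

// Primary template: both scale types default to the scale dtype TS.
template <typename TC, typename TS>
struct SplitRopeTypeTrait {
  using E_Scale = TS;
  using D_Scale = TS;
};

// bf16 cache with bf16 scales: the decoder scale widens to float.
template <>
struct SplitRopeTypeTrait<bfloat16, bfloat16> {
  using E_Scale = bfloat16;
  using D_Scale = float;
};

// int8 cache with bf16 scales: both scales stay bf16.
template <>
struct SplitRopeTypeTrait<int8_t, bfloat16> {
  using E_Scale = bfloat16;
  using D_Scale = bfloat16;
};

// Compile-time checks that each specialization resolves as intended.
static_assert(std::is_same_v<SplitRopeTypeTrait<bfloat16, bfloat16>::D_Scale, float>);
static_assert(std::is_same_v<SplitRopeTypeTrait<int8_t, bfloat16>::D_Scale, bfloat16>);
static_assert(std::is_same_v<SplitRopeTypeTrait<float, float>::E_Scale, float>);
```

If the trait truly has no remaining call sites after the decoupling, such static checks make it easy to confirm nothing still depends on it before deleting.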

self.rope_3d,
)


Extra whitespace.

Comment on lines +252 to +281
'''
# q = q * k_scales_inv
if is_cache_int8 and has_zp:
    if enc_batch > 0 and is_prefix_cache:
        origin_shape = q_enc.shape
        q_enc_reshaped = paddle.view(
            q_enc,
            [total_enc_len, kv_num_heads, num_heads // kv_num_heads, head_dim])
        q_enc_reshaped = q_enc_reshaped * paddle.view(k_scales_inv, [1, kv_num_heads, 1, head_dim])
        q_enc = paddle.view(q_enc_reshaped, origin_shape)

        # q_enc_reshaped = paddle.reshape(
        #     q_enc,
        #     [total_enc_len, kv_num_heads, num_heads // kv_num_heads, head_dim])
        # q_enc_reshaped = q_enc_reshaped * paddle.reshape(k_scales_inv, [1, kv_num_heads, 1, head_dim])
        # q_enc = paddle.reshape(q_enc_reshaped, q_enc.shape)
    if dec_batch > 0:
        origin_shape = q_dec.shape
        q_dec_reshaped = paddle.view(
            q_dec,
            [total_dec_len, kv_num_heads, num_heads // kv_num_heads, head_dim])
        q_dec_reshaped = q_dec_reshaped * paddle.view(k_scales_inv, [1, kv_num_heads, 1, head_dim])
        q_dec = paddle.view(q_dec_reshaped, origin_shape)

        # q_dec_reshaped = paddle.reshape(
        #     q_dec,
        #     [total_dec_len, kv_num_heads, num_heads // kv_num_heads, head_dim])
        # q_dec_reshaped = q_dec_reshaped * paddle.reshape(k_scales_inv, [1, kv_num_heads, 1, head_dim])
        # q_dec = paddle.reshape(q_dec_reshaped, q_dec.shape)
'''


This looks like the logic for handling asymmetric cache quantization? It should probably be kept.
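The commented-out Python above broadcasts `k_scales_inv` (shape `[kv_num_heads, head_dim]`) over the groups of query heads that share a kv head. A hypothetical scalar C++ sketch of the same multiply, with all names and the row-major layout assumed for illustration:

```cpp
#include <cstddef>
#include <vector>

// Scale q in place by per-kv-head inverse key scales. q is logically
// [total_len, kv_num_heads, num_heads / kv_num_heads, head_dim] stored
// contiguously; k_scales_inv is [kv_num_heads, head_dim]. Every group of
// (num_heads / kv_num_heads) query heads shares its kv head's scale.
void scale_q_by_kv_inv(std::vector<float>& q,
                       const std::vector<float>& k_scales_inv,
                       std::size_t total_len, std::size_t num_heads,
                       std::size_t kv_num_heads, std::size_t head_dim) {
  const std::size_t group = num_heads / kv_num_heads;
  for (std::size_t t = 0; t < total_len; ++t)
    for (std::size_t kv = 0; kv < kv_num_heads; ++kv)
      for (std::size_t g = 0; g < group; ++g)
        for (std::size_t d = 0; d < head_dim; ++d)
          q[((t * kv_num_heads + kv) * group + g) * head_dim + d] *=
              k_scales_inv[kv * head_dim + d];
}
```

This is only a readability sketch of the broadcast semantics; the PR's actual behavior depends on whether this path moved into the fused kernel or was dropped.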

# if shift:
#     out[total_enc_len:, :] = out[total_enc_len:, :] + shift
# if smooth:
#     out[total_enc_len:, :] = out[total_enc_len:, :] * smooth


Same logic as the above: wasn't all of this originally in the C code?
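The commented-out snippet applies an optional per-channel additive `shift` and then a multiplicative `smooth` to the decoder rows of `out` (rows from `total_enc_len` on). A hypothetical host-side sketch of that post-processing, names assumed:

```cpp
#include <cstddef>
#include <vector>

// Apply optional shift (add) then smooth (multiply), both of length `dim`,
// to the decoder portion of `out`: rows [total_enc_len, rows). Passing
// nullptr skips that transform, mirroring the `if shift:` / `if smooth:`
// guards in the commented-out Python.
void apply_shift_smooth(std::vector<float>& out, std::size_t dim,
                        std::size_t total_enc_len,
                        const std::vector<float>* shift,
                        const std::vector<float>* smooth) {
  const std::size_t rows = out.size() / dim;
  for (std::size_t r = total_enc_len; r < rows; ++r)
    for (std::size_t d = 0; d < dim; ++d) {
      float& v = out[r * dim + d];
      if (shift) v += (*shift)[d];
      if (smooth) v *= (*smooth)[d];
    }
}
```

If this logic did previously live in the C kernel, keeping an equivalent (fused or host-side) path seems necessary for smooth-quant style checkpoints; worth confirming in the review.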


3 participants