Simplified is_mla branch in llm_build_deepseek2()

This commit is contained in:
juk 2025-04-13 12:41:33 +01:00
parent a5df71ec9c
commit 925af997e8

View File

@ -10147,27 +10147,27 @@ struct llm_build_deepseek2 : public llm_graph_context {
q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
cb(q_nope, "q_nope_perm", il);
// {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head}
ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope);
cb(q_nope_absorbed, "q_nope_absorbed", il);
// {n_embd_head_qk_rope, n_tokens, n_head}
q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3);
cb(q_pe, "q_pe_perm", il);
// {kv_lora_rank, n_head, n_tokens}
q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
cb(q_nope_absorbed, "q_nope_absorbed_perm", il);
// {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
// note: rope must go first for in-place context shifting in build_rope_shift()
ggml_tensor * q_states = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
cb(q_states, "q_states", il);
// {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
q_states = ggml_permute(ctx0, q_states, 0, 2, 1, 3);
cb(q_states, "q_states_perm", il);
k_pe = ggml_reshape_2d(ctx0, k_pe, n_embd_head_qk_rope, n_tokens);
cb(k_pe, "k_pe_reshape", il);
kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
cb(kv_cmpr, "kv_cmpr_reshape", il);
// {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
ggml_tensor * k_states = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
cb(k_states, "k_states", il);
// {kv_lora_rank, 1, n_tokens}
ggml_tensor * v_states = kv_cmpr;
cb(v_states, "v_states", il);