mirror of https://github.com/ggerganov/llama.cpp.git
graph : make mla compatible with FA

parent 23c0090fa4
commit 0100feb33e
@@ -1200,9 +1200,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     //const auto & n_embd_head_k = hparams.n_embd_head_k;
     //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
-    // note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
-    const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];
-
     const auto n_tokens = q->ne[1];
     const auto n_head   = q->ne[2];
     const auto n_kv     = k->ne[1];
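The line removed above chose the per-head value size with a nested ternary: v->ne[1] when V is stored transposed, v->ne[0] otherwise, and v_mla->ne[1] whenever the MLA absorption projection is present (v_mla re-expands the compressed heads on output). A minimal standalone sketch of that selection logic; all sizes and the head_size helper are assumed for illustration, not taken from the commit:

    // sketch of the removed n_embd_head_v selection (sizes assumed)
    #include "ggml.h"
    #include <cstdio>

    // mirrors: v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1]
    static int64_t head_size(const ggml_tensor * v, bool v_trans, const ggml_tensor * v_mla) {
        return v_mla == nullptr ? (v_trans ? v->ne[1] : v->ne[0]) : v_mla->ne[1];
    }

    int main() {
        ggml_init_params ip = { /*mem_size=*/ 8*1024*1024, /*mem_buffer=*/ nullptr, /*no_alloc=*/ true };
        ggml_context * ctx = ggml_init(ip);

        const int64_t n_embd_head_v = 128, n_kv = 256, n_lora = 512; // assumed

        ggml_tensor * v     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_head_v, n_kv); // head size along ne[0]
        ggml_tensor * v_t   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_kv, n_embd_head_v); // transposed store
        ggml_tensor * v_lat = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_kv);        // latent-space V (MLA)
        ggml_tensor * v_mla = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_embd_head_v);

        printf("plain: %lld, transposed: %lld, mla: %lld\n",
               (long long) head_size(v,     false, nullptr),
               (long long) head_size(v_t,   true,  nullptr),
               (long long) head_size(v_lat, false, v_mla)); // all print 128
        ggml_free(ctx);
        return 0;
    }

After this commit the value is no longer needed up front: both remaining consumers (see the two hunks below) read the size off the tensor at hand via cur->ne[0].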
@@ -1231,7 +1228,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
-        cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
+        if (v_mla) {
+            cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
+            cur = ggml_mul_mat(ctx0, v_mla, cur);
+        }
+
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
 
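With flash attention plus MLA, the kernel's per-head output lives in the compressed latent space (v_mla->ne[0] wide), so the new block applies the out-projection right after the FA call: reshaping to [v_mla->ne[0], 1, n_head, n_tokens] turns each (head, token) slice into a one-column matrix, and ggml_mul_mat then broadcasts the 2-D v_mla across the head and token dimensions. A standalone sketch with sizes assumed for illustration (n_lora standing in for the latent width):

    // sketch of the new FA + v_mla path (sizes assumed, not from the commit)
    #include "ggml.h"
    #include <cstdio>

    int main() {
        ggml_init_params ip = { /*mem_size=*/ 8*1024*1024, /*mem_buffer=*/ nullptr, /*no_alloc=*/ true };
        ggml_context * ctx = ggml_init(ip);

        const int64_t n_lora = 512, n_embd_head = 128, n_head = 16, n_tokens = 4; // assumed

        // stand-in for the flash-attention output: one n_lora-wide vector per (head, token)
        ggml_tensor * cur   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_lora, n_head, n_tokens);
        ggml_tensor * v_mla = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_embd_head);

        cur = ggml_reshape_4d(ctx, cur, n_lora, 1, n_head, n_tokens); // one column per (head, token)
        cur = ggml_mul_mat(ctx, v_mla, cur);                          // -> [n_embd_head, 1, n_head, n_tokens]
        cur = ggml_reshape_2d(ctx, cur, n_embd_head*n_head, n_tokens);

        printf("projected: %lld x %lld\n", (long long) cur->ne[0], (long long) cur->ne[1]); // 2048 x 4
        ggml_free(ctx);
        return 0;
    }

The singleton dimension 1 is what makes the broadcast legal: ggml_mul_mat yields [v_mla->ne[1], 1, n_head, n_tokens], which the final reshape flattens to one row per token.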
@@ -1274,9 +1276,9 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             kqv = ggml_mul_mat(ctx0, v_mla, kqv);
         }
 
-        ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+        cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
-        cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
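The non-FA merge keeps its structure; it now just reuses cur and reads the row width from cur->ne[0] instead of the deleted n_embd_head_v. A standalone sketch of the permute-then-cont head merge, sizes assumed:

    // sketch of the non-FA head merge (sizes assumed, not from the commit)
    #include "ggml.h"
    #include <cstdio>

    int main() {
        ggml_init_params ip = { /*mem_size=*/ 8*1024*1024, /*mem_buffer=*/ nullptr, /*no_alloc=*/ true };
        ggml_context * ctx = ggml_init(ip);

        const int64_t n_embd_head = 128, n_tokens = 4, n_head = 16; // assumed

        // kqv as produced by ggml_mul_mat: [n_embd_head, n_tokens, n_head]
        ggml_tensor * kqv = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_tokens, n_head);

        // swap the token and head dims: -> [n_embd_head, n_head, n_tokens]
        ggml_tensor * cur = ggml_permute(ctx, kqv, 0, 2, 1, 3);

        // materialize contiguously, concatenating all heads of a token into one row;
        // the row size is taken from cur->ne[0], not from a precomputed head size
        cur = ggml_cont_2d(ctx, cur, cur->ne[0]*n_head, n_tokens);

        printf("merged: %lld x %lld\n", (long long) cur->ne[0], (long long) cur->ne[1]); // 2048 x 4
        ggml_free(ctx);
        return 0;
    }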