Removed MQA optimisation from build_attn_mha() as no gains now

This commit is contained in:
juk 2025-04-13 12:40:31 +01:00
parent 638b092d7a
commit a5df71ec9c

View File

@ -1200,18 +1200,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
//const auto & n_embd_head_k = hparams.n_embd_head_k;
//const auto & n_embd_head_v = hparams.n_embd_head_v;
const auto n_embd = q->ne[0];
const auto n_tokens = q->ne[1];
const auto n_head = q->ne[2];
const auto n_kv = k->ne[1];
const auto n_head_kv = k->ne[2];
// note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];
GGML_ASSERT(k->ne[0] == q->ne[0] && "K and Q embedding size mismatch");
GGML_ASSERT(k->ne[2] == v->ne[2] && "K and V number of heads mismatch");
const auto n_tokens = q->ne[1];
const auto n_head = q->ne[2];
const auto n_kv = k->ne[1];
ggml_tensor * cur;
@ -1239,22 +1233,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
} else {
// for MQA (ie: GQA with 1 group) we don't need to use a batched matrix multiply
if (n_head_kv == 1) {
q = ggml_reshape_2d(ctx0, q, n_embd, n_tokens*n_head);
}
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
// note: this op tends to require high floating point range
// while for some models F16 is enough, for others it is not, so we default to F32 here
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
if (n_head_kv == 1) {
kq = ggml_reshape_3d(ctx0, kq, n_kv, n_tokens, n_head);
}
if (arch == LLM_ARCH_GROK) {
// need to do the following:
// multiply by attn_output_multiplyer of 0.08838834764831845