
Used reshape in llm_graph_context::build_attn_mha()

Author: juk
Date:   2025-04-12 19:32:19 +01:00
parent e2153236ce
commit 815f4f9ecf

@@ -1175,10 +1175,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         // for MQA (ie: GQA with 1 group) we don't need to use a batched matrix multiply
         if (n_head_kv == 1) {
-            q = ggml_view_2d(ctx0, q,
-                    n_embd, n_tokens*n_head,
-                    ggml_row_size(q->type, n_embd),
-                    0);
+            q = ggml_reshape_2d(ctx0, q, n_embd, n_tokens*n_head);
         }
 
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
@@ -1188,11 +1185,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
 
         if (n_head_kv == 1) {
-            kq = ggml_view_3d(ctx0, kq,
-                    n_kv, n_tokens, n_head,
-                    ggml_row_size(kq->type, n_kv),
-                    ggml_row_size(kq->type, n_kv)*n_tokens,
-                    0);
+            kq = ggml_reshape_3d(ctx0, kq, n_kv, n_tokens, n_head);
        }
 
        if (arch == LLM_ARCH_GROK) {
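
For context, here is a minimal standalone sketch (not part of the commit; the context size and the n_embd/n_tokens/n_head values are illustrative) of why the reshape can replace the manual view here: ggml_reshape_2d is valid exactly when the source tensor is contiguous, and in that case it describes the same memory layout that the explicit ggml_view_2d call was constructing by hand.

    #include <assert.h>
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx0 = ggml_init(params);

        // placeholder sizes for the example
        const int64_t n_embd = 64, n_tokens = 8, n_head = 4;

        // contiguous [n_embd, n_tokens, n_head] tensor, as q is at this point
        struct ggml_tensor * q = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, n_tokens, n_head);
        assert(ggml_is_contiguous(q));

        // old form: manual 2-d view with an explicit row stride and offset
        struct ggml_tensor * v = ggml_view_2d(ctx0, q,
                n_embd, n_tokens*n_head,
                ggml_row_size(q->type, n_embd),
                0);

        // new form: a reshape, valid precisely because q is contiguous
        // (ggml_reshape_2d asserts contiguity internally)
        struct ggml_tensor * r = ggml_reshape_2d(ctx0, q, n_embd, n_tokens*n_head);

        // both describe the same [n_embd, n_tokens*n_head] layout over q's data
        assert(v->ne[0] == r->ne[0] && v->ne[1] == r->ne[1]);
        assert(v->nb[1] == r->nb[1]);

        ggml_free(ctx0);
        return 0;
    }

The 3-d case in the second hunk is analogous: kq comes straight out of ggml_mul_mat and is contiguous, so ggml_reshape_3d expresses the intent (reinterpret the shape, same data) more directly than a view with hand-computed strides.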