Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2025-04-16 03:26:08 +00:00
graph : normalize Q, K, V shapes + sync cross attention (#12449)
* graph : normalize Q, K, V shapes and add comments

ggml-ci

* context : synchronize before getting cross attention data

* model : fix command-r attention norm check
parent bb115d2bf7
commit 75422e8bc4
@@ -1143,6 +1143,8 @@ int llama_context::encode(llama_batch & inp_batch) {
     if (model.arch == LLM_ARCH_T5 && t_embd) {
         //cross.t_embd = t_embd;
 
+        synchronize();
+
         cross.n_embd = t_embd->ne[0];
         cross.n_enc  = t_embd->ne[1];
         cross.v_embd.resize(cross.n_embd*cross.n_enc);
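The new synchronize() call ensures the encoder graph has actually finished running on the backend before the host reads t_embd's shape and copies its data into cross.v_embd. Below is a standalone sketch of the same wait-before-readback pattern using the public ggml-backend API; the tensor names and sizes are illustrative and this is not the llama_context code itself, only a minimal demonstration of the ordering.

    // Hedged sketch: run a tiny graph on a ggml backend, then wait for the
    // device work to finish before copying the result back to host memory.
    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"
    #include "ggml-cpu.h"
    #include <vector>

    int main() {
        ggml_backend_t backend = ggml_backend_cpu_init();

        struct ggml_init_params params = {
            /*.mem_size   =*/ ggml_tensor_overhead()*8 + ggml_graph_overhead(),
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true, // tensor data lives in a backend buffer
        };
        struct ggml_context * ctx = ggml_init(params);

        const int64_t n_embd = 4, n_enc = 3; // illustrative sizes

        struct ggml_tensor * inp    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_enc);
        struct ggml_tensor * t_embd = ggml_scale(ctx, inp, 2.0f); // stand-in "encoder output"

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, t_embd);

        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

        std::vector<float> ones(n_embd*n_enc, 1.0f);
        ggml_backend_tensor_set(inp, ones.data(), 0, ggml_nbytes(inp));

        ggml_backend_graph_compute(backend, gf);

        // On an asynchronous backend the compute above may still be in flight:
        // wait before reading the result from the host side.
        ggml_backend_synchronize(backend);

        std::vector<float> v_embd(t_embd->ne[0]*t_embd->ne[1]);
        ggml_backend_tensor_get(t_embd, v_embd.data(), 0, ggml_nbytes(t_embd));

        ggml_backend_buffer_free(buf);
        ggml_backend_free(backend);
        ggml_free(ctx);
        return 0;
    }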
@@ -1378,7 +1378,7 @@ ggml_tensor * llm_graph_context::build_attn(
         // note: storing RoPE-ed version of K in the KV cache
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
 
         assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
         v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
 
         ggml_tensor * v_cache_view = nullptr;
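With the normalized shapes, V reaches the cache-store path as a per-head tensor and is flattened to the n_embd_v_gqa rows the cache view holds, which is what the assert and ggml_reshape_2d above tie together. A standalone sketch of that flattening follows; the sizes and the exact ordering of check and reshape here are illustrative, not the build_attn code path.

    // Hedged sketch: flatten a per-head V tensor [n_embd_head_v, n_head_v, n_tokens]
    // into [n_embd_v_gqa, n_tokens] with n_embd_v_gqa = n_embd_head_v*n_head_v.
    // The reshape is metadata-only; no data is moved.
    #include "ggml.h"
    #include <cassert>

    int main() {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        const int64_t n_embd_head_v = 64, n_head_v = 8, n_tokens = 5; // illustrative
        const int64_t n_embd_v_gqa  = n_embd_head_v*n_head_v;

        struct ggml_tensor * v_cur =
            ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head_v, n_head_v, n_tokens);

        // collapse the head dimensions into one row per token
        v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);

        assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);

        ggml_free(ctx);
        return 0;
    }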
@@ -487,9 +487,9 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn_mha(
             ggml_cgraph * gf,
-            ggml_tensor * q,
-            ggml_tensor * k,
-            ggml_tensor * v,
+            ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
+            ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
+            ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
             ggml_tensor * kq_b,
             ggml_tensor * kq_mask,
                     bool v_trans,
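The comments added to build_attn_mha document why the normalized layout puts n_tokens in dimension 1 and the heads in dimension 2: with matching head counts in dim 2, a single ggml_mul_mat yields one [n_kv, n_tokens] score matrix per head. A minimal sketch of that shape arithmetic follows; the sizes are illustrative and nothing is computed here, only the resulting shape is checked.

    // Hedged sketch: with Q = [n_embd_head, n_tokens, n_head] and
    // K = [n_embd_head, n_kv, n_head], ggml_mul_mat(K, Q) yields attention
    // scores KQ = [n_kv, n_tokens, n_head], i.e. one matrix per head.
    #include "ggml.h"
    #include <cassert>

    int main() {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 64*1024*1024,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        const int64_t n_embd_head = 64, n_head = 8, n_tokens = 5, n_kv = 32; // illustrative

        struct ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_tokens, n_head);
        struct ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_kv,     n_head);

        struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);

        assert(kq->ne[0] == n_kv && kq->ne[1] == n_tokens && kq->ne[2] == n_head);

        ggml_free(ctx);
        return 0;
    }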
@@ -502,9 +502,9 @@ struct llm_graph_context {
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
-            ggml_tensor * q_cur,
-            ggml_tensor * k_cur,
-            ggml_tensor * v_cur,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
                    float kq_scale,
                      int il) const;
@@ -516,9 +516,9 @@ struct llm_graph_context {
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
-            ggml_tensor * q_cur,
-            ggml_tensor * k_cur,
-            ggml_tensor * v_cur,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
                    float kq_scale,
                      int il) const;
@@ -530,9 +530,9 @@ struct llm_graph_context {
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
-            ggml_tensor * q_cur,
-            ggml_tensor * k_cur,
-            ggml_tensor * v_cur,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
                    float kq_scale,
                      int il) const;
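Taken together, the comments in these declarations pin down two layouts: the build_attn overloads receive Q/K/V as [n_embd_head, n_head, n_tokens], while build_attn_mha (above) consumes [n_embd_head, n_tokens, n_head], so a permute of dimensions 1 and 2 maps one layout onto the other. A minimal sketch of that mapping with ggml_permute follows; the sizes are illustrative and this is not the llama.cpp graph code itself.

    // Hedged sketch: swap dims 1 and 2 to go from the build_attn layout
    // [n_embd_head, n_head, n_tokens] to the build_attn_mha layout
    // [n_embd_head, n_tokens, n_head]. ggml_permute returns a view.
    #include "ggml.h"
    #include <cassert>

    int main() {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        const int64_t n_embd_head = 64, n_head = 8, n_tokens = 5; // illustrative

        // build_attn layout: [n_embd_head, n_head, n_tokens]
        struct ggml_tensor * q_cur =
            ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, n_tokens);

        // axis i of q_cur becomes axis perm[i] of the result: (0, 2, 1, 3)
        struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);

        assert(q->ne[0] == n_embd_head && q->ne[1] == n_tokens && q->ne[2] == n_head);

        ggml_free(ctx);
        return 0;
    }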
File diff suppressed because it is too large