llama : fix Gemma3 SWA KV cache shift (#12373)

* llama : fix Gemma3 SWA KV cache shift

ggml-ci

* hparams : add comment [no ci]
This commit is contained in:
Georgi Gerganov 2025-03-13 19:08:07 +02:00 committed by GitHub
parent be7c303410
commit 84d5475541
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 37 additions and 43 deletions

View File

@ -442,10 +442,10 @@ ggml_tensor * llama_context::build_rope_shift(
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale,
ggml_backend_buffer * bbuf) const {
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
const auto & freq_base = cparams.rope_freq_base;
const auto & freq_scale = cparams.rope_freq_scale;
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
const auto & yarn_attn_factor = cparams.yarn_attn_factor;
@ -537,6 +537,17 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
float freq_base_l = cparams.rope_freq_base;
float freq_scale_l = cparams.rope_freq_scale;
// TODO: improve
if (model.arch == LLM_ARCH_GEMMA3) {
const bool is_sliding = hparams.is_sliding(il);
freq_base_l = is_sliding ? 10000.0f : cparams.rope_freq_base;
freq_scale_l = is_sliding ? 1.0f : cparams.rope_freq_scale;
}
ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
ggml_tensor * k =
@ -546,7 +557,7 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
0);
ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, kv_self->k_l[il]->buffer);
ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
ggml_build_forward_expand(gf, cur);
}

View File

@ -168,6 +168,8 @@ private:
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale,
ggml_backend_buffer * bbuf) const;
llm_graph_result_ptr build_kv_self_shift(

View File

@ -1403,34 +1403,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
}
// TODO: improve
bool is_sliding = false;
switch (arch) {
case LLM_ARCH_COHERE2:
{
const int32_t sliding_window_pattern = 4;
is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
} break;
case LLM_ARCH_GEMMA2:
{
const int32_t sliding_window_pattern = 2;
is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
} break;
case LLM_ARCH_GEMMA3:
{
const int32_t sliding_window_pattern = 6;
is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
} break;
case LLM_ARCH_PHI3:
{
is_sliding = hparams.n_swa > 0;
} break;
default:
{
is_sliding = false;
}
};
const bool is_sliding = hparams.is_sliding(il);
const auto & kq_mask = is_sliding ? inp->get_kq_mask_swa() : inp->get_kq_mask();

View File

@ -69,3 +69,11 @@ uint32_t llama_hparams::n_embd_v_s() const {
// corresponds to Mamba's ssm_states size
return ssm_d_state * ssm_d_inner;
}
bool llama_hparams::is_sliding(uint32_t il) const {
if (il < n_layer) {
return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
}
GGML_ABORT("fatal error");
}

View File

@ -36,6 +36,7 @@ struct llama_hparams {
uint32_t n_layer;
uint32_t n_rot;
uint32_t n_swa = 0; // sliding window attention (SWA)
uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
uint32_t n_expert = 0;
@ -133,6 +134,8 @@ struct llama_hparams {
// dimension of the recurrent state embeddings
uint32_t n_embd_v_s() const;
bool is_sliding(uint32_t il) const;
};
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");

View File

@ -858,11 +858,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case LLM_ARCH_GEMMA2:
{
hparams.n_swa = 4096; // default value of gemma 2
hparams.n_swa_pattern = 2;
hparams.attn_soft_cap = true;
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
hparams.attn_soft_cap = true;
switch (hparams.n_layer) {
case 26: type = LLM_TYPE_2B; break;
@ -873,6 +875,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
} break;
case LLM_ARCH_GEMMA3:
{
hparams.n_swa_pattern = 6;
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@ -952,6 +956,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
} break;
case LLM_ARCH_COHERE2:
{
hparams.n_swa_pattern = 4;
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@ -7374,12 +7380,8 @@ struct llm_build_gemma3 : public llm_graph_context {
// TODO: is causal == true correct? might need some changes
auto * inp_attn = build_attn_inp_kv_unified(true, true);
// "5-to-1 interleaved attention"
// 5 layers of local attention followed by 1 layer of global attention
static const int sliding_window_pattern = 6;
for (int il = 0; il < n_layer; ++il) {
const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
const bool is_sliding = hparams.is_sliding(il);
const float freq_base_l = is_sliding ? 10000.0f : freq_base;
const float freq_scale_l = is_sliding ? 1.0f : freq_scale;
@ -7970,13 +7972,8 @@ struct llm_build_cohere2 : public llm_graph_context {
auto * inp_attn = build_attn_inp_kv_unified(true, true);
// sliding window switch pattern
const int32_t sliding_window_pattern = 4;
for (int il = 0; il < n_layer; ++il) {
// three layers sliding window attention (window size 4096) and ROPE
// fourth layer uses global attention without positional embeddings
const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
const bool is_sliding = hparams.is_sliding(il);
// norm
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM, il);