diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 0a43a3af8..89fb33cbc 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -442,10 +442,10 @@ ggml_tensor * llama_context::build_rope_shift(
         ggml_tensor * cur,
         ggml_tensor * shift,
         ggml_tensor * factors,
+              float   freq_base,
+              float   freq_scale,
         ggml_backend_buffer * bbuf) const {
     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
-    const auto & freq_base  = cparams.rope_freq_base;
-    const auto & freq_scale = cparams.rope_freq_scale;
 
     const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
     const auto & yarn_attn_factor = cparams.yarn_attn_factor;
@@ -537,6 +537,17 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
         const int64_t n_head_kv    = hparams.n_head_kv(il);
         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
+        float freq_base_l  = cparams.rope_freq_base;
+        float freq_scale_l = cparams.rope_freq_scale;
+
+        // TODO: improve
+        if (model.arch == LLM_ARCH_GEMMA3) {
+            const bool is_sliding = hparams.is_sliding(il);
+
+            freq_base_l  = is_sliding ? 10000.0f : cparams.rope_freq_base;
+            freq_scale_l = is_sliding ? 1.0f     : cparams.rope_freq_scale;
+        }
+
         ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
 
         ggml_tensor * k =
@@ -546,7 +557,7 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
                 ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
                 0);
 
-        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, kv_self->k_l[il]->buffer);
+        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
 
         ggml_build_forward_expand(gf, cur);
     }
diff --git a/src/llama-context.h b/src/llama-context.h
index 71d702e8b..88df8950e 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -168,6 +168,8 @@ private:
             ggml_tensor * cur,
             ggml_tensor * shift,
             ggml_tensor * factors,
+                  float   freq_base,
+                  float   freq_scale,
             ggml_backend_buffer * bbuf) const;
 
     llm_graph_result_ptr build_kv_self_shift(
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 1e3f2efc8..4a53e8392 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1403,34 +1403,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
     }
 
-    // TODO: improve
-    bool is_sliding = false;
-
-    switch (arch) {
-        case LLM_ARCH_COHERE2:
-            {
-                const int32_t sliding_window_pattern = 4;
-                is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
-            } break;
-        case LLM_ARCH_GEMMA2:
-            {
-                const int32_t sliding_window_pattern = 2;
-                is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
-            } break;
-        case LLM_ARCH_GEMMA3:
-            {
-                const int32_t sliding_window_pattern = 6;
-                is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
-            } break;
-        case LLM_ARCH_PHI3:
-            {
-                is_sliding = hparams.n_swa > 0;
-            } break;
-        default:
-            {
-                is_sliding = false;
-            }
-    };
+    const bool is_sliding = hparams.is_sliding(il);
 
     const auto & kq_mask = is_sliding ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index ea87b2953..58e98bf23 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -69,3 +69,11 @@ uint32_t llama_hparams::n_embd_v_s() const {
     // corresponds to Mamba's ssm_states size
     return ssm_d_state * ssm_d_inner;
 }
+
+bool llama_hparams::is_sliding(uint32_t il) const {
+    if (il < n_layer) {
+        return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
+    }
+
+    GGML_ABORT("fatal error");
+}
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 1fe454103..e3091c812 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -36,6 +36,7 @@ struct llama_hparams {
     uint32_t n_layer;
     uint32_t n_rot;
     uint32_t n_swa = 0; // sliding window attention (SWA)
+    uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_expert = 0;
@@ -133,6 +134,8 @@ struct llama_hparams {
 
     // dimension of the recurrent state embeddings
     uint32_t n_embd_v_s() const;
+
+    bool is_sliding(uint32_t il) const;
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 522219c01..5647d2ad6 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -858,11 +858,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_GEMMA2:
             {
                 hparams.n_swa = 4096; // default value of gemma 2
+                hparams.n_swa_pattern = 2;
+                hparams.attn_soft_cap = true;
+
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING,      hparams.f_attn_logit_softcapping, false);
                 ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING,     hparams.f_final_logit_softcapping, false);
-                hparams.attn_soft_cap = true;
 
                 switch (hparams.n_layer) {
                     case 26: type = LLM_TYPE_2B; break;
@@ -873,6 +875,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GEMMA3:
             {
+                hparams.n_swa_pattern = 6;
+
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
@@ -952,6 +956,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_COHERE2:
             {
+                hparams.n_swa_pattern = 4;
+
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
                 ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -7374,12 +7380,8 @@ struct llm_build_gemma3 : public llm_graph_context {
         // TODO: is causal == true correct? might need some changes
         auto * inp_attn = build_attn_inp_kv_unified(true, true);
 
-        // "5-to-1 interleaved attention"
-        // 5 layers of local attention followed by 1 layer of global attention
-        static const int sliding_window_pattern = 6;
-
         for (int il = 0; il < n_layer; ++il) {
-            const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+            const bool is_sliding = hparams.is_sliding(il);
 
             const float freq_base_l  = is_sliding ? 10000.0f : freq_base;
             const float freq_scale_l = is_sliding ? 1.0f     : freq_scale;
@@ -7970,13 +7972,8 @@ struct llm_build_cohere2 : public llm_graph_context {
 
         auto * inp_attn = build_attn_inp_kv_unified(true, true);
 
-        // sliding window switch pattern
-        const int32_t sliding_window_pattern = 4;
-
         for (int il = 0; il < n_layer; ++il) {
-            // three layers sliding window attention (window size 4096) and ROPE
-            // fourth layer uses global attention without positional embeddings
-            const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+            const bool is_sliding = hparams.is_sliding(il);
 
             // norm
             cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM, il);
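
Not part of the patch: a minimal standalone sketch of the layer-selection rule that `llama_hparams::is_sliding()` now centralizes. With a pattern of `n_swa_pattern`, the first `n_swa_pattern - 1` layers of every group use sliding-window attention and the last layer of the group uses global attention (pattern 2 for Gemma 2, 6 for Gemma 3, 4 for Cohere2). The helper name and the window size used below are illustrative only.

```cpp
// Illustrative sketch (hypothetical helper, not part of llama.cpp): re-implements the
// expression used by llama_hparams::is_sliding() so the pattern can be inspected in isolation.
#include <cstdint>
#include <cstdio>

static bool is_sliding_layer(uint32_t il, uint32_t n_swa, uint32_t n_swa_pattern) {
    // same check as in the patch: a non-zero window, a non-zero pattern, and
    // every layer except the last one in each group of n_swa_pattern slides
    return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
}

int main() {
    // Gemma 3 uses n_swa_pattern = 6: layers 0-4 slide, layer 5 is global, and so on.
    // The window size (n_swa) comes from the GGUF metadata; 1024 here is just a placeholder.
    for (uint32_t il = 0; il < 12; ++il) {
        printf("layer %2u: %s\n", (unsigned) il, is_sliding_layer(il, 1024, 6) ? "sliding" : "global");
    }
    return 0;
}
```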