diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2bf97475f..89522dee8 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4422,6 +4422,10 @@ class DeepseekV2Model(Model): self._set_vocab_gpt2() def set_gguf_parameters(self): + + # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group) + self.hparams["num_key_value_heads"] = 1 + super().set_gguf_parameters() hparams = self.hparams @@ -4430,8 +4434,13 @@ class DeepseekV2Model(Model): if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) - self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) - self.gguf_writer.add_value_length(hparams["v_head_dim"]) + + # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA + self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(hparams["kv_lora_rank"]) + self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) + self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) @@ -4500,6 +4509,26 @@ class DeepseekV2Model(Model): else: return [] + # note: MLA with the absorption optimization, needs these two split and k_b_proj transposed + if name.endswith("kv_b_proj.weight"): + name_kb = name.replace("kv_b_proj", "k_b_proj") + name_vb = name.replace("kv_b_proj", "v_b_proj") + + n_head_kv = self.hparams["num_key_value_heads"] + v_head_dim = self.hparams["v_head_dim"] + qk_nope_head_dim = self.hparams["qk_nope_head_dim"] + + assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim) + + kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1]) + k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1) + k_b = k_b.transpose(1, 2) + + return [ + (self.map_tensor_name(name_kb), k_b), + (self.map_tensor_name(name_vb), v_b) + ] + return [(self.map_tensor_name(name), data_torch)] def prepare_tensors(self): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 162070e6e..8fcde2626 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -139,6 +139,8 @@ class Keys: REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count" SLIDING_WINDOW = "{arch}.attention.sliding_window" SCALE = "{arch}.attention.scale" + KEY_LENGTH_MLA = "{arch}.attention.key_length_mla" + VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla" class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" @@ -382,6 +384,8 @@ class MODEL_TENSOR(IntEnum): ATTN_Q_B = auto() ATTN_KV_A_MQA = auto() ATTN_KV_B = auto() + ATTN_K_B = auto() + ATTN_V_B = auto() ATTN_Q_A_NORM = auto() ATTN_KV_A_NORM = auto() FFN_SUB_NORM = auto() @@ -590,6 +594,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", + MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b", + MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b", MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", @@ -1517,6 +1523,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_Q_B, MODEL_TENSOR.ATTN_KV_A_MQA, MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V_B, MODEL_TENSOR.ATTN_Q_A_NORM, MODEL_TENSOR.ATTN_KV_A_NORM, MODEL_TENSOR.ATTN_OUT, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 485550aad..aef03db15 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -689,6 +689,12 @@ class GGUFWriter: def add_value_length(self, length: int) -> None: self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length) + def add_key_length_mla(self, length: int) -> None: + self.add_uint32(Keys.Attention.KEY_LENGTH_MLA.format(arch=self.arch), length) + + def add_value_length_mla(self, length: int) -> None: + self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length) + def add_max_alibi_bias(self, bias: float) -> None: self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 35154e9b5..0bc75cf51 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -677,6 +677,14 @@ class TensorNameMap: "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2 ), + MODEL_TENSOR.ATTN_K_B: ( + "model.layers.{bid}.self_attn.k_b_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_V_B: ( + "model.layers.{bid}.self_attn.v_b_proj", # deepseek2 + ), + MODEL_TENSOR.ATTN_Q_A_NORM: ( "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2 ), diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index a6fddc7fd..62e1480bb 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -140,6 +140,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, + { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, + { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, @@ -1103,6 +1105,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, + { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, @@ -1563,23 +1567,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 2c2099b3c..98ca00a1b 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -144,6 +144,8 @@ enum llm_kv { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, LLM_KV_ATTENTION_SLIDING_WINDOW, LLM_KV_ATTENTION_SCALE, + LLM_KV_ATTENTION_KEY_LENGTH_MLA, + LLM_KV_ATTENTION_VALUE_LENGTH_MLA, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, @@ -306,6 +308,8 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_KV_A_MQA, LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM, LLM_TENSOR_ATTN_SUB_NORM, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 4735e98ea..d3ef1cbde 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -10,6 +10,7 @@ #include #include #include +#include // // llama_context @@ -473,7 +474,6 @@ ggml_tensor * llama_context::build_rope_shift( const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; const auto & yarn_ext_factor = cparams.yarn_ext_factor; - const auto & yarn_attn_factor = cparams.yarn_attn_factor; const auto & yarn_beta_fast = cparams.yarn_beta_fast; const auto & yarn_beta_slow = cparams.yarn_beta_slow; @@ -482,6 +482,10 @@ ggml_tensor * llama_context::build_rope_shift( const auto & n_rot = hparams.n_rot; const auto & rope_type = hparams.rope_type; + // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. + // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. + const float yarn_attn_factor_scaled = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor; + ggml_tensor * tmp; if (ggml_is_quantized(cur->type)) { @@ -500,14 +504,14 @@ ggml_tensor * llama_context::build_rope_shift( tmp = ggml_rope_ext_inplace(ctx0, tmp, shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow); tmp = ggml_cpy(ctx0, tmp, cur); } else { // we rotate only the first n_rot dimensions tmp = ggml_rope_ext_inplace(ctx0, cur, shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow); } return tmp; @@ -2274,6 +2278,11 @@ llama_context * llama_init_from_model( params.flash_attn = false; } + if (params.flash_attn && model->arch == LLM_ARCH_DEEPSEEK2) { + LLAMA_LOG_WARN("%s: flash_attn is not compatible with Deepseek2 - forcing off\n", __func__); + params.flash_attn = false; + } + if (ggml_is_quantized(params.type_v) && !params.flash_attn) { LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); return nullptr; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index cd955d63b..5d0222b98 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1188,6 +1188,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * v, ggml_tensor * kq_b, ggml_tensor * kq_mask, + ggml_tensor * v_mla, bool v_trans, float kq_scale) const { //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); @@ -1199,7 +1200,8 @@ ggml_tensor * llm_graph_context::build_attn_mha( //const auto & n_embd_head_k = hparams.n_embd_head_k; //const auto & n_embd_head_v = hparams.n_embd_head_v; - const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0]; + // note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla + const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1]; const auto n_tokens = q->ne[1]; const auto n_head = q->ne[2]; @@ -1267,6 +1269,11 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA + if (v_mla) { + kqv = ggml_mul_mat(ctx0, v_mla, kqv); + } + ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); @@ -1304,6 +1311,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, + ggml_tensor * v_mla, float kq_scale, int il) const { GGML_UNUSED(n_tokens); @@ -1325,7 +1333,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); //cb(k, "v", il); - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale); + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale); cb(cur, "kqv_out", il); @@ -1379,6 +1387,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, + ggml_tensor * v_mla, float kq_scale, int il) const { // these nodes are added to the graph together so that they are not reordered @@ -1464,7 +1473,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v, 0); - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale); + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1504,6 +1513,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, + ggml_tensor * v_mla, float kq_scale, int il) const { // these nodes are added to the graph together so that they are not reordered @@ -1523,7 +1533,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); //cb(k, "v", il); - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale); + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale); cb(cur, "kqv_out", il); @@ -1692,4 +1702,3 @@ void llm_graph_context::build_pooling( ggml_build_forward_expand(gf, cur); } - diff --git a/src/llama-graph.h b/src/llama-graph.h index 5b6618f9e..d192dc149 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -505,11 +505,12 @@ struct llm_graph_context { ggml_tensor * build_attn_mha( ggml_cgraph * gf, - ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q] - ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k] - ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false) + ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q] + ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k] + ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false) ggml_tensor * kq_b, ggml_tensor * kq_mask, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] bool v_trans, float kq_scale) const; @@ -524,6 +525,7 @@ struct llm_graph_context { ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; @@ -538,6 +540,7 @@ struct llm_graph_context { ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; @@ -552,6 +555,7 @@ struct llm_graph_context { ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 4e0b57190..80fcd65df 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -43,6 +43,10 @@ struct llama_hparams { uint32_t n_expert_used = 0; uint32_t n_rel_attn_bkts = 0; + // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA + uint32_t n_embd_head_k_mla = 0; + uint32_t n_embd_head_v_mla = 0; + // for WavTokenizer struct llama_hparams_posnet posnet; struct llama_hparams_convnext convnext; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index dbf5f1187..7c9d46d81 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -27,7 +27,7 @@ bool llama_kv_cache_unified::init( recurrent = llama_model_is_recurrent(&model); v_trans = !recurrent && !cparams.flash_attn; - can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA + can_shift = !recurrent; LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n", __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b74dd72cf..248c61748 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1156,6 +1156,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); @@ -3205,8 +3207,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { { const bool is_lite = (hparams.n_layer == 27); + const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; + const int64_t n_embd_head_v_mla = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + const int64_t n_embd_head_qk_rope = hparams.n_rot; - const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; const int64_t q_lora_rank = hparams.n_lora_q; const int64_t kv_lora_rank = hparams.n_lora_kv; @@ -3232,14 +3240,22 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!is_lite) { layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); } else { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); } - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + if (is_mla) { + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); + } else { + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v_mla)}, 0); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4290,6 +4306,8 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla); + LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla); LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); @@ -4496,7 +4514,7 @@ struct llm_build_llama : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -4709,7 +4727,7 @@ struct llm_build_deci : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1) { @@ -4851,7 +4869,7 @@ struct llm_build_baichuan : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -4966,7 +4984,7 @@ struct llm_build_xverse : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5091,7 +5109,7 @@ struct llm_build_falcon : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5221,7 +5239,7 @@ struct llm_build_grok : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1) { @@ -5372,7 +5390,7 @@ struct llm_build_dbrx : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5486,7 +5504,7 @@ struct llm_build_starcoder : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5585,7 +5603,7 @@ struct llm_build_refact : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5739,7 +5757,7 @@ struct llm_build_bert : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) { @@ -5856,7 +5874,7 @@ struct llm_build_bloom : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5997,7 +6015,7 @@ struct llm_build_mpt : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6143,7 +6161,7 @@ struct llm_build_stablelm : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6266,7 +6284,7 @@ struct llm_build_qwen : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6386,7 +6404,7 @@ struct llm_build_qwen2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6507,7 +6525,7 @@ struct llm_build_qwen2vl : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6634,7 +6652,7 @@ struct llm_build_qwen2moe : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6787,7 +6805,7 @@ struct llm_build_qwen3 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6908,7 +6926,7 @@ struct llm_build_qwen3moe : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -7048,7 +7066,7 @@ struct llm_build_phi2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1) { @@ -7177,7 +7195,7 @@ struct llm_build_phi3 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1) { @@ -7312,7 +7330,7 @@ struct llm_build_plamo : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } ggml_tensor * sa_out = cur; @@ -7419,7 +7437,7 @@ struct llm_build_gpt2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -7535,7 +7553,7 @@ struct llm_build_codeshell : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -7664,7 +7682,7 @@ struct llm_build_orion : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -7791,7 +7809,7 @@ struct llm_build_internlm2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -7988,7 +8006,7 @@ struct llm_build_minicpm3 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - q_states, k_states, v_states, nullptr, kq_scale, il); + q_states, k_states, v_states, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1) { @@ -8118,7 +8136,7 @@ struct llm_build_gemma : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1) { @@ -8240,7 +8258,7 @@ struct llm_build_gemma2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } cur = build_norm(cur, @@ -8381,7 +8399,7 @@ struct llm_build_gemma3 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, hparams.f_attention_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il); } cur = build_norm(cur, @@ -8521,7 +8539,7 @@ struct llm_build_starcoder2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -8856,7 +8874,7 @@ struct llm_build_command_r : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -8991,7 +9009,7 @@ struct llm_build_cohere2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9122,7 +9140,7 @@ struct llm_build_olmo : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, nullptr, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9242,7 +9260,7 @@ struct llm_build_olmo2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } cur = build_norm(cur, @@ -9375,7 +9393,7 @@ struct llm_build_olmoe : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9508,7 +9526,7 @@ struct llm_build_openelm : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9622,7 +9640,7 @@ struct llm_build_gptneox : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9772,7 +9790,7 @@ struct llm_build_arctic : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9927,7 +9945,7 @@ struct llm_build_deepseek : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1) { @@ -10017,16 +10035,23 @@ struct llm_build_deepseek2 : public llm_graph_context { llm_build_deepseek2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { bool is_lite = (hparams.n_layer == 27); + const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; + const int64_t n_embd_head_v = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope; + + const uint32_t kv_lora_rank = hparams.n_lora_kv; + // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); - const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k)); + const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(n_embd_head_k)); const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); - const uint32_t n_embd_head_qk_rope = hparams.n_rot; - const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; - const uint32_t kv_lora_rank = hparams.n_lora_kv; - ggml_tensor * cur; ggml_tensor * inpL; @@ -10051,16 +10076,14 @@ struct llm_build_deepseek2 : public llm_graph_context { { ggml_tensor * q = NULL; if (!is_lite) { - // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens} q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); cb(q, "q", il); q = build_norm(q, - model.layers[il].attn_q_a_norm, NULL, + model.layers[il].attn_q_a_norm, nullptr, LLM_NORM_RMS, il); cb(q, "q", il); - // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens} q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); cb(q, "q", il); } else { @@ -10068,96 +10091,125 @@ struct llm_build_deepseek2 : public llm_graph_context { cb(q, "q", il); } - // split into {n_head * n_embd_head_qk_nope, n_tokens} - ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + // split into {n_embd_head_qk_nope, n_head, n_tokens} + ggml_tensor * q_nope = ggml_view_3d(ctx0, q, + n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, 0); cb(q_nope, "q_nope", il); - // and {n_head * n_embd_head_qk_rope, n_tokens} - ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + // and {n_embd_head_qk_rope, n_head, n_tokens} + ggml_tensor * q_pe = ggml_view_3d(ctx0, q, + n_embd_head_qk_rope, n_head, n_tokens, + ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, ggml_row_size(q->type, n_embd_head_qk_nope)); cb(q_pe, "q_pe", il); - // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} - ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); - cb(kv_pe_compresseed, "kv_pe_compresseed", il); + ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_cmpr_pe, "kv_cmpr_pe", il); // split into {kv_lora_rank, n_tokens} - ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens, - kv_pe_compresseed->nb[1], + ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, + kv_lora_rank, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0); - cb(kv_compressed, "kv_compressed", il); + cb(kv_cmpr, "kv_cmpr", il); - // and {n_embd_head_qk_rope, n_tokens} - ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, - kv_pe_compresseed->nb[1], - kv_pe_compresseed->nb[1], - ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); + // and {n_embd_head_qk_rope, 1, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, + n_embd_head_qk_rope, 1, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank)); cb(k_pe, "k_pe", il); - // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont - kv_compressed = ggml_cont(ctx0, kv_compressed); - kv_compressed = build_norm(kv_compressed, - model.layers[il].attn_kv_a_norm, NULL, - LLM_NORM_RMS, il); - cb(kv_compressed, "kv_compressed", il); - - // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} - ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); - cb(kv, "kv", il); - - // split into {n_head * n_embd_head_qk_nope, n_tokens} - ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), - 0); - cb(k_nope, "k_nope", il); - - // and {n_head * n_embd_head_v, n_tokens} - ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), - ggml_row_size(kv->type, (n_embd_head_qk_nope))); - cb(v_states, "v_states", il); - - v_states = ggml_cont(ctx0, v_states); - cb(v_states, "v_states", il); - - v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, - ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), - 0); - cb(v_states, "v_states", il); - - q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this - q_pe = ggml_rope_ext( - ctx0, q_pe, inp_pos, nullptr, + q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); + ); cb(q_pe, "q_pe", il); - // shared RoPE key - k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this - k_pe = ggml_rope_ext( - ctx0, k_pe, inp_pos, nullptr, + k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); + ); cb(k_pe, "k_pe", il); - ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); - cb(q_states, "q_states", il); + kv_cmpr = build_norm(kv_cmpr, + model.layers[il].attn_kv_a_norm, nullptr, + LLM_NORM_RMS, il); + cb(kv_cmpr, "kv_cmpr", il); - ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); - cb(k_states, "k_states", il); + if (is_mla) { + // {n_embd_head_qk_nope, n_tokens, n_head} + q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + cb(q_nope, "q_nope_perm", il); - cur = build_attn(inp_attn, gf, - model.layers[il].wo, NULL, - q_states, k_states, v_states, nullptr, kq_scale, il); + // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head} + ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope); + cb(q_nope_absorbed, "q_nope_absorbed", il); + + // {kv_lora_rank, n_head, n_tokens} + q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3); + cb(q_nope_absorbed, "q_nope_absorbed_perm", il); + + // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens} + // note: rope must go first for in-place context shifting in build_rope_shift() + ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0); + cb(Qcur, "Qcur", il); + + kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens); + cb(kv_cmpr, "kv_cmpr_reshape", il); + + // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} + ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0); + cb(Kcur, "Kcur", il); + + // {kv_lora_rank, 1, n_tokens} + ggml_tensor * Vcur = kv_cmpr; + cb(Vcur, "Vcur", il); + + // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group) + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, model.layers[il].wv_b, kq_scale, il); + } else { + ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr); + cb(kv, "kv", il); + + // split into {n_embd_head_qk_nope, n_head, n_tokens} + ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, + n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v), + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head, + 0); + cb(k_nope, "k_nope_view", il); + + // and {n_embd_head_v, n_head, n_tokens} + ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, + n_embd_head_v, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v), + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head, + ggml_row_size(kv->type, n_embd_head_qk_nope)); + cb(Vcur, "Vcur_view", il); + + Vcur = ggml_cont(ctx0, Vcur); + cb(Vcur, "Vcur_cont", il); + + // note: rope must go first for in-place context shifting in build_rope_shift() + ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0); + cb(Kcur, "Kcur", il); + + // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups) + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + } } if (il == n_layer - 1) { @@ -10323,7 +10375,7 @@ struct llm_build_bitnet : public llm_graph_context { cur = build_attn(inp_attn, gf, NULL, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cur = build_norm(cur, model.layers[il].attn_sub_norm, NULL, @@ -10446,7 +10498,7 @@ struct llm_build_t5_enc : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo_enc, nullptr, - Qcur, Kcur, Vcur, kq_b, 1.0f, il); + Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il); cb(cur, "kqv_out", il); } @@ -10552,7 +10604,7 @@ struct llm_build_t5_dec : public llm_graph_context { cur = build_attn(inp_attn_self, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, kq_b, 1.0f, il); + Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il); cb(cur, "kqv_out", il); } @@ -10584,7 +10636,7 @@ struct llm_build_t5_dec : public llm_graph_context { cur = build_attn(inp_attn_cross, gf, model.layers[il].wo_cross, nullptr, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); //ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); @@ -10717,7 +10769,7 @@ struct llm_build_jais : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/float(n_embd_head), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il); } if (il == n_layer - 1) { @@ -10849,7 +10901,7 @@ struct llm_build_chatglm : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -10982,7 +11034,7 @@ struct llm_build_glm4 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -11126,7 +11178,7 @@ struct llm_build_nemotron : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -11257,7 +11309,7 @@ struct llm_build_exaone : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -12159,7 +12211,7 @@ struct llm_build_chameleon : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, nullptr, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); if (hparams.swin_norm) { cur = build_norm(cur, @@ -12515,7 +12567,7 @@ struct llm_build_plm : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - q_states, k_states, v_states, nullptr, kq_scale, il); + q_states, k_states, v_states, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1) { @@ -12638,7 +12690,7 @@ struct llm_build_bailingmoe : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_rot)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il); } if (il == n_layer - 1) { diff --git a/src/llama-model.h b/src/llama-model.h index 0f18dac16..fd82d106c 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -171,6 +171,8 @@ struct llama_layer { struct ggml_tensor * wq_b = nullptr; struct ggml_tensor * wkv_a_mqa = nullptr; struct ggml_tensor * wkv_b = nullptr; + struct ggml_tensor * wk_b = nullptr; + struct ggml_tensor * wv_b = nullptr; struct ggml_tensor * wq_cross = nullptr; struct ggml_tensor * wk_cross = nullptr; struct ggml_tensor * wv_cross = nullptr;