llama : Support llama 4 text-only (#12791)

* llama4 conversion

* initial support, no chat template

* clean up a bit

* fix tokenizer conversion

* correct hparams

* try this

* fix shexp

* ffn_inp_normed

* chat template

* clean up model conversion

* add_bos

* add scale_before_ffn

* fix order

* weight_before_ffn

* llm_graph_input_attn_temp

* add chunk attn mask

* build_inp_attn_scale()

* add comment about ggml_repeat

* clarify comments

* fix build
This commit is contained in:
Xuan-Son Nguyen 2025-04-07 23:06:44 +02:00 committed by GitHub
parent 82974011f3
commit 1466621e73
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 532 additions and 22 deletions

View File

@ -714,6 +714,9 @@ class Model:
if chkhsh == "96a5f08be6259352137b512d4157e333e21df7edd3fcd152990608735a65b224":
# ref: https://huggingface.co/inclusionAI/Ling-lite
res = "bailingmoe"
if chkhsh == "d353350c764d8c3b39c763113960e4fb4919bea5fbf208a0e3b22e8469dc7406":
# ref: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct
res = "llama4"
if res is None:
logger.warning("\n")
@ -1608,6 +1611,7 @@ class StableLMModel(Model):
@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
class LlamaModel(Model):
model_arch = gguf.MODEL_ARCH.LLAMA
undo_permute = True
def set_vocab(self):
try:
@ -1672,10 +1676,11 @@ class LlamaModel(Model):
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
if self.undo_permute:
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
# process the experts separately
if name.find("block_sparse_moe.experts") != -1:
@ -1752,6 +1757,61 @@ class LlamaModel(Model):
raise ValueError(f"Unprocessed experts: {experts}")
@Model.register("Llama4ForConditionalGeneration")
class Llama4Model(LlamaModel):
model_arch = gguf.MODEL_ARCH.LLAMA4
has_vision: bool = False
undo_permute = False
# TODO @ngxson : avoid duplicate this code everywhere by at least support "text_config"
# same with llama, but we need to merge the text_config into the root level of hparams
def __init__(self, *args, **kwargs):
hparams = kwargs["hparams"] if "hparams" in kwargs else Model.load_hparams(args[0])
if "text_config" in hparams:
hparams = {**hparams, **hparams["text_config"]}
kwargs["hparams"] = hparams
super().__init__(*args, **kwargs)
if "vision_config" in hparams:
logger.info("Has vision encoder, but it will be ignored")
self.has_vision = True
# IMPORTANT: the normal "intermediate_size" is renamed to "intermediate_size_mlp", we need to undo this
self.hparams["intermediate_size_moe"] = self.hparams["intermediate_size"]
self.hparams["intermediate_size"] = self.hparams["intermediate_size_mlp"]
def set_vocab(self):
self._set_vocab_gpt2()
self.gguf_writer.add_add_bos_token(True)
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_interleave_moe_layer_step(self.hparams["interleave_moe_layer_step"])
self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
name = name.replace("language_model.", "")
name = name.replace("feed_forward.", "mlp.") # a bit hacky for now
name = name.replace(".router.weight", ".gate.weight") # a bit hacky for now
# split the gate_up into gate and up
if "gate_up_proj" in name:
name_up = name.replace("gate_up_proj", "up_proj.weight")
name_gate = name.replace("gate_up_proj", "gate_proj.weight")
dim_half = data_torch.shape[-1] // 2
gate_proj_weight, up_proj_weight = data_torch.transpose(-1, -2).split(dim_half, dim=-2)
return [
(self.map_tensor_name(name_gate), gate_proj_weight),
(self.map_tensor_name(name_up), up_proj_weight)
]
if name.endswith("down_proj"):
name += ".weight"
data_torch = data_torch.transpose(-1, -2)
if "multi_modal_projector" in name or "vision_model" in name:
return []
return super().modify_tensors(data_torch, name, bid)
@Model.register("Mistral3ForConditionalGeneration")
class Mistral3Model(LlamaModel):
model_arch = gguf.MODEL_ARCH.LLAMA

View File

@ -113,6 +113,7 @@ models = [
{"name": "superbpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k", },
{"name": "trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", },
{"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", },
{"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
]

View File

@ -116,6 +116,7 @@ class Keys:
RESIDUAL_SCALE = "{arch}.residual_scale"
EMBEDDING_SCALE = "{arch}.embedding_scale"
TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step"
class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
@ -227,6 +228,7 @@ class GGUFType:
class MODEL_ARCH(IntEnum):
LLAMA = auto()
LLAMA4 = auto()
DECI = auto()
FALCON = auto()
BAICHUAN = auto()
@ -431,6 +433,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.LLAMA4: "llama4",
MODEL_ARCH.DECI: "deci",
MODEL_ARCH.FALCON: "falcon",
MODEL_ARCH.BAICHUAN: "baichuan",
@ -654,6 +657,29 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.LLAMA4: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
MODEL_ARCH.DECI: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,

View File

@ -746,6 +746,9 @@ class GGUFWriter:
def add_token_shift_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.TOKEN_SHIFT_COUNT.format(arch=self.arch), count)
def add_interleave_moe_layer_step(self, value: int) -> None:
self.add_uint32(Keys.LLM.INTERLEAVE_MOE_LAYER_STEP.format(arch=self.arch), value)
def add_layer_norm_eps(self, value: float) -> None:
self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)

View File

@ -110,6 +110,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
};
enum llama_rope_type {

View File

@ -0,0 +1,112 @@
ied 4 ½ months
__ggml_vocab_test__
Führer
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
Hello world
__ggml_vocab_test__
Hello world
__ggml_vocab_test__
Hello World
__ggml_vocab_test__
Hello World
__ggml_vocab_test__
Hello World!
__ggml_vocab_test__
Hello, world!
__ggml_vocab_test__
Hello, world!
__ggml_vocab_test__
this is 🦙.cpp
__ggml_vocab_test__
w048 7tuijk dsdfhu
__ggml_vocab_test__
нещо на Български
__ggml_vocab_test__
កាន់តែពិសេសអាចខលចេញ
__ggml_vocab_test__
🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
Hello
__ggml_vocab_test__
(
__ggml_vocab_test__
=
__ggml_vocab_test__
' era
__ggml_vocab_test__
Hello, y'all! How are you 😁 ?我想在apple工作1314151天
__ggml_vocab_test__
!!!!!!
__ggml_vocab_test__
3
__ggml_vocab_test__
33
__ggml_vocab_test__
333
__ggml_vocab_test__
3333
__ggml_vocab_test__
33333
__ggml_vocab_test__
333333
__ggml_vocab_test__
3333333
__ggml_vocab_test__
33333333
__ggml_vocab_test__
333333333
__ggml_vocab_test__
Cửa Việt
__ggml_vocab_test__
discards
__ggml_vocab_test__
🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天 ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
__ggml_vocab_test__

View File

@ -0,0 +1,46 @@
1190 220 32 220 18215 7112
50 16800 258
220
256
277
197
198
368
2946
3271
19873 3817
39715 3817
19873 7353
39715 7353
39715 7353 13
19873 24 3817 13
39715 24 3817 13
544 373 9522 112 247 26 36315
99 39923 220 35 9607 21498 21470 3679 9433
1595 7653 633 79829 34051 1636
8755 102595 115960 21125 148305 96819 102816 39048 14105 22528 160234
114590 222 330 14879 21 51358 127 12817 93293 117 24204 330 68239 881 120327 170428 21 89101 330 7384 88230 511 947 1492 3742 7233 21
19873
39715
220 39715
256 39715
277 39715
277 39715 198 277 39715
330
198 319
19 7359
19873 24 386 87799 13 2403 583 650 51358 223 1663 155736 1522 42056 7544 13336 28785 29 4412 20645
17931 4959
31
1922
12325
12325 31
12325 1922
12325 12325
12325 12325 31
12325 12325 1922
12325 12325 12325
47 19811 12077
3260 3579
198 7283 51499 191231 20192 3271 3322 9287 2143 17860 114590 222 330 14879 21 51358 127 12817 93293 117 24204 330 68239 881 120327 170428 21 89101 9522 112 247 172394 247 220 31 220 1922 220 12325 220 12325 31 220 12325 1922 220 12325 12325 220 12325 12325 31 220 12325 12325 1922 220 31 26 31 220 31 396 31 220 31 1043 31 117131 102595 115960 21125 148305 96819 102816 80883 223 1663 155736 1522 42056 7544 13336 28785 29 4412 20645 79745 150278 117079 633 79829 34051 1636 25611 41990 109428 1488 91054 24072 17931 4959 29795 9296 16517 1806 481 96 1386 36633 1609 24 481 1109 650 5074 43 481 57 702 5074 27088 2170 536 24 481 48 650 1933 1696 30262 43 1665 19 32818 262 27236 56

View File

@ -6,6 +6,7 @@
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_LLAMA, "llama" },
{ LLM_ARCH_LLAMA4, "llama4" },
{ LLM_ARCH_DECI, "deci" },
{ LLM_ARCH_FALCON, "falcon" },
{ LLM_ARCH_GROK, "grok" },
@ -114,6 +115,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
{ LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
{ LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@ -233,6 +235,35 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_LLAMA4,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_DECI,
{

View File

@ -10,6 +10,7 @@
enum llm_arch {
LLM_ARCH_LLAMA,
LLM_ARCH_LLAMA4,
LLM_ARCH_DECI,
LLM_ARCH_FALCON,
LLM_ARCH_BAICHUAN,
@ -118,6 +119,7 @@ enum llm_kv {
LLM_KV_RESIDUAL_SCALE,
LLM_KV_EMBEDDING_SCALE,
LLM_KV_TOKEN_SHIFT_COUNT,
LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,

View File

@ -61,6 +61,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
{ "yandex", LLM_CHAT_TEMPLATE_YANDEX },
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
};
llm_chat_template llm_chat_template_from_str(const std::string & name) {
@ -174,6 +175,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_YANDEX;
} else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) {
return LLM_CHAT_TEMPLATE_BAILING;
} else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
return LLM_CHAT_TEMPLATE_LLAMA4;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@ -608,7 +611,16 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<role>ASSISTANT</role>";
}
} else {
} else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA4) {
// Llama 4
for (auto message : chat) {
std::string role(message->role);
ss << "<|header_start|>" << role << "<|header_end|>\n\n" << trim(message->content) << "<|eot|>";
}
if (add_ass) {
ss << "<|header_start|>assistant<|header_end|>\n\n";
}
} else {
// template not supported
return -1;
}

View File

@ -40,6 +40,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_MEGREZ,
LLM_CHAT_TEMPLATE_YANDEX,
LLM_CHAT_TEMPLATE_BAILING,
LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_UNKNOWN,
};

View File

@ -59,6 +59,22 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
}
}
void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
if (ubatch->pos && attn_scale) {
const int64_t n_tokens = ubatch->n_tokens;
std::vector<float> attn_scale_data(n_tokens, 0.0f);
for (int i = 0; i < n_tokens; ++i) {
const float pos = ubatch->pos[i];
attn_scale_data[i] = std::log(
std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
) * f_attn_temp_scale + 1.0;
}
ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(attn_scale));
}
}
void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
if (pos_bucket) {
const int64_t n_tokens = ubatch->n_tokens;
@ -458,9 +474,17 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
}
// may need to cut off old tokens for sliding window
// TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
if (data_swa) {
if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
f = -INFINITY;
if (hparams.n_attn_chunk) {
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
f = -INFINITY;
}
} else {
if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
f = -INFINITY;
}
}
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
}
@ -812,8 +836,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
float w_scale,
llama_expert_gating_func_type gating_op,
int il) const {
int64_t n_embd = cur->ne[0];
int64_t n_tokens = cur->ne[1];
const int64_t n_embd = cur->ne[0];
const int64_t n_tokens = cur->ne[1];
const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
cb(logits, "ffn_moe_logits", il);
@ -841,6 +866,12 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cb(selection_probs, "ffn_moe_probs_biased", il);
}
// llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
// see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
if (arch == LLM_ARCH_LLAMA4) {
selection_probs = logits;
}
// select experts
ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
cb(selected_experts->src[0], "ffn_moe_argsort", il);
@ -867,6 +898,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
}
cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
if (weight_before_ffn) {
// TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
cur = ggml_mul(ctx0, repeated, weights);
cb(cur, "ffn_moe_weighted", il);
}
ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(up, "ffn_moe_up", il);
@ -894,7 +934,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
cb(experts, "ffn_moe_down", il);
experts = ggml_mul(ctx0, experts, weights);
if (!weight_before_ffn) {
experts = ggml_mul(ctx0, experts, weights);
cb(cur, "ffn_moe_weighted", il);
}
// aggregate experts
ggml_tensor * moe_out = nullptr;
@ -914,6 +957,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
moe_out = ggml_cont(ctx0, moe_out);
}
cb(moe_out, "ffn_moe_out", il);
return moe_out;
}
@ -981,6 +1026,19 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
return cur;
}
ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
auto inp = std::make_unique<llm_graph_input_attn_temp>(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
auto & cur = inp->attn_scale;
cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token());
ggml_set_input(cur);
res->add_input(std::move(inp));
return cur;
}
ggml_tensor * llm_graph_context::build_inp_out_ids() const {
auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);

View File

@ -100,6 +100,23 @@ public:
const int64_t n_pos_per_token = 1;
};
// temperature tuning, used by llama4
class llm_graph_input_attn_temp : public llm_graph_input_i {
public:
llm_graph_input_attn_temp(int64_t n_pos_per_token, uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
: n_pos_per_token(n_pos_per_token), n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
virtual ~llm_graph_input_attn_temp() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
const int64_t n_pos_per_token = 1;
const uint32_t n_attn_temp_floor_scale;
const float f_attn_temp_scale;
};
class llm_graph_input_pos_bucket : public llm_graph_input_i {
public:
llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
@ -470,6 +487,7 @@ struct llm_graph_context {
ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
ggml_tensor * build_inp_pos() const;
ggml_tensor * build_inp_attn_scale() const;
ggml_tensor * build_inp_out_ids() const;
ggml_tensor * build_inp_mean() const;
ggml_tensor * build_inp_cls() const;

View File

@ -112,6 +112,14 @@ struct llama_hparams {
bool use_alibi = false;
bool attn_soft_cap = false;
uint32_t n_moe_layer_step = 0;
bool use_kq_norm = true;
uint32_t n_attn_chunk = 0;
// values below seems to be fixed on llama4
uint32_t n_no_rope_layer_step = 4;
uint32_t n_attn_temp_floor_scale = 8192;
float f_attn_temp_scale = 0.1;
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;

View File

@ -90,6 +90,8 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_57B_A14B: return "57B.A14B";
case LLM_TYPE_27B: return "27B";
case LLM_TYPE_290B: return "290B";
case LLM_TYPE_17B_16E: return "17Bx16E (Scout)";
case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
default: return "?B";
}
}
@ -550,6 +552,25 @@ void llama_model::load_hparams(llama_model_loader & ml) {
}
}
} break;
case LLM_ARCH_LLAMA4:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full
hparams.n_attn_chunk = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
hparams.n_swa = 1; // TODO @ngxson : this is added to trigger the SWA branch (we store the chunked attn mask in the SWA tensor), will need to clean this up later
switch (hparams.n_expert) {
case 16: type = LLM_TYPE_17B_16E; break;
case 128: type = LLM_TYPE_17B_128E; break;
default: type = LLM_TYPE_UNKNOWN;
}
if (type == LLM_TYPE_17B_128E) {
hparams.use_kq_norm = false;
}
} break;
case LLM_ARCH_DECI:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@ -1690,6 +1711,56 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
}
} break;
case LLM_ARCH_LLAMA4:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Llama 4 requires n_moe_layer_step > 0");
for (int i = 0; i < n_layer; ++i) {
bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0;
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
if (is_moe_layer) {
int n_ff_exp = hparams.n_ff_exp;
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
// Shared expert
const int64_t n_ff_shexp = n_ff_exp;
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
} else {
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
}
} break;
case LLM_ARCH_DECI:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -4203,12 +4274,22 @@ struct llm_build_llama : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
// temperature tuning
ggml_tensor * inp_attn_scale = nullptr;
if (arch == LLM_ARCH_LLAMA4) {
inp_attn_scale = build_inp_attn_scale();
}
auto * inp_attn = build_attn_inp_kv_unified();
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
bool use_rope = arch == LLM_ARCH_LLAMA4
? (il + 1) % hparams.n_no_rope_layer_step != 0
: true;
// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
@ -4246,25 +4327,38 @@ struct llm_build_llama : public llm_graph_context {
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
if (use_rope) {
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
} else if (inp_attn_scale) {
Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
if (arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) {
// Llama4TextL2Norm
Qcur = ggml_rms_norm(ctx0, Qcur, 1e-6);
Kcur = ggml_rms_norm(ctx0, Kcur, 1e-6);
cb(Qcur, "Qcur_normed", il);
cb(Kcur, "Kcur_normed", il);
}
cur = build_attn(inp_attn, gf,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, kq_scale, il);
cb(cur, "attn_out", il);
}
if (il == n_layer - 1) {
@ -4282,7 +4376,7 @@ struct llm_build_llama : public llm_graph_context {
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
// feed-forward network (non-MoE)
if (model.layers[il].ffn_gate_inp == nullptr) {
cur = build_norm(ffn_inp,
@ -4297,6 +4391,38 @@ struct llm_build_llama : public llm_graph_context {
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else if (arch == LLM_ARCH_LLAMA4) {
// llama4 MoE
ggml_tensor * ffn_inp_normed = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
ggml_tensor * moe_out = build_moe_ffn(ffn_inp_normed,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, false,
false, 0.0,
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID,
il);
// Shared experts
ggml_tensor * shexp_out = build_ffn(ffn_inp_normed,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(shexp_out, "ffn_moe_shexp", il);
cur = ggml_add(ctx0, moe_out, shexp_out);
cb(cur, "ffn_moe_out_merged", il);
} else {
// MoE branch
cur = build_norm(ffn_inp,
@ -12091,6 +12217,7 @@ llm_graph_result_ptr llama_model::build_graph(
switch (arch) {
case LLM_ARCH_LLAMA:
case LLM_ARCH_LLAMA4:
case LLM_ARCH_MINICPM:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
@ -12440,6 +12567,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
// use what we call a normal RoPE, operating on pairs of consecutive head values
case LLM_ARCH_LLAMA:
case LLM_ARCH_LLAMA4:
case LLM_ARCH_DECI:
case LLM_ARCH_BAICHUAN:
case LLM_ARCH_STARCODER:

View File

@ -86,6 +86,8 @@ enum llm_type {
LLM_TYPE_57B_A14B,
LLM_TYPE_27B,
LLM_TYPE_290B,
LLM_TYPE_17B_16E, // llama4 Scout
LLM_TYPE_17B_128E, // llama4 Maverick
};
struct llama_layer_posnet {

View File

@ -1616,7 +1616,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "megrez") {
pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
} else if (
tokenizer_pre == "gpt-4o") {
tokenizer_pre == "gpt-4o" ||
tokenizer_pre == "llama4") {
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
clean_spaces = false;
} else if (