Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-04-13 01:56:09 +00:00)
llama : Support llama 4 text-only (#12791)
* llama4 conversion
* initial support, no chat template
* clean up a bit
* fix tokenizer conversion
* correct hparams
* try this
* fix shexp
* ffn_inp_normed
* chat template
* clean up model conversion
* add_bos
* add scale_before_ffn
* fix order
* weight_before_ffn
* llm_graph_input_attn_temp
* add chunk attn mask
* build_inp_attn_scale()
* add comment about ggml_repeat
* clarify comments
* fix build
This commit is contained in:
parent 82974011f3
commit 1466621e73
convert_hf_to_gguf.py
@@ -714,6 +714,9 @@ class Model:
         if chkhsh == "96a5f08be6259352137b512d4157e333e21df7edd3fcd152990608735a65b224":
             # ref: https://huggingface.co/inclusionAI/Ling-lite
             res = "bailingmoe"
+        if chkhsh == "d353350c764d8c3b39c763113960e4fb4919bea5fbf208a0e3b22e8469dc7406":
+            # ref: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct
+            res = "llama4"
 
         if res is None:
             logger.warning("\n")
@@ -1608,6 +1611,7 @@ class StableLMModel(Model):
 @Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
 class LlamaModel(Model):
     model_arch = gguf.MODEL_ARCH.LLAMA
+    undo_permute = True
 
     def set_vocab(self):
         try:
@@ -1672,10 +1676,11 @@ class LlamaModel(Model):
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")
 
-        if name.endswith(("q_proj.weight", "q_proj.bias")):
-            data_torch = LlamaModel.permute(data_torch, n_head, n_head)
-        if name.endswith(("k_proj.weight", "k_proj.bias")):
-            data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
+        if self.undo_permute:
+            if name.endswith(("q_proj.weight", "q_proj.bias")):
+                data_torch = LlamaModel.permute(data_torch, n_head, n_head)
+            if name.endswith(("k_proj.weight", "k_proj.bias")):
+                data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
 
         # process the experts separately
         if name.find("block_sparse_moe.experts") != -1:
@@ -1752,6 +1757,61 @@ class LlamaModel(Model):
             raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@Model.register("Llama4ForConditionalGeneration")
+class Llama4Model(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.LLAMA4
+    has_vision: bool = False
+    undo_permute = False
+
+    # TODO @ngxson : avoid duplicate this code everywhere by at least support "text_config"
+    # same with llama, but we need to merge the text_config into the root level of hparams
+    def __init__(self, *args, **kwargs):
+        hparams = kwargs["hparams"] if "hparams" in kwargs else Model.load_hparams(args[0])
+        if "text_config" in hparams:
+            hparams = {**hparams, **hparams["text_config"]}
+        kwargs["hparams"] = hparams
+        super().__init__(*args, **kwargs)
+        if "vision_config" in hparams:
+            logger.info("Has vision encoder, but it will be ignored")
+            self.has_vision = True
+        # IMPORTANT: the normal "intermediate_size" is renamed to "intermediate_size_mlp", we need to undo this
+        self.hparams["intermediate_size_moe"] = self.hparams["intermediate_size"]
+        self.hparams["intermediate_size"] = self.hparams["intermediate_size_mlp"]
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+        self.gguf_writer.add_add_bos_token(True)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_interleave_moe_layer_step(self.hparams["interleave_moe_layer_step"])
+        self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"])
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        name = name.replace("language_model.", "")
+        name = name.replace("feed_forward.", "mlp.")  # a bit hacky for now
+        name = name.replace(".router.weight", ".gate.weight")  # a bit hacky for now
+
+        # split the gate_up into gate and up
+        if "gate_up_proj" in name:
+            name_up = name.replace("gate_up_proj", "up_proj.weight")
+            name_gate = name.replace("gate_up_proj", "gate_proj.weight")
+            dim_half = data_torch.shape[-1] // 2
+            gate_proj_weight, up_proj_weight = data_torch.transpose(-1, -2).split(dim_half, dim=-2)
+            return [
+                (self.map_tensor_name(name_gate), gate_proj_weight),
+                (self.map_tensor_name(name_up), up_proj_weight)
+            ]
+
+        if name.endswith("down_proj"):
+            name += ".weight"
+            data_torch = data_torch.transpose(-1, -2)
+
+        if "multi_modal_projector" in name or "vision_model" in name:
+            return []
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @Model.register("Mistral3ForConditionalGeneration")
 class Mistral3Model(LlamaModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
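For intuition, here is a small standalone sketch (hedged; the shapes are made up for illustration, not taken from the model) of the gate_up_proj split performed in modify_tensors above: the fused tensor is transposed, then halved along what is now the second-to-last dimension.

import torch

# hypothetical fused expert tensor laid out as [n_embd, 2 * n_ff], as handled above
n_embd, n_ff = 4, 3
gate_up = torch.arange(n_embd * 2 * n_ff, dtype=torch.float32).reshape(n_embd, 2 * n_ff)

dim_half = gate_up.shape[-1] // 2
gate, up = gate_up.transpose(-1, -2).split(dim_half, dim=-2)
print(gate.shape, up.shape)  # torch.Size([3, 4]) torch.Size([3, 4])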
convert_hf_to_gguf_update.py
@@ -113,6 +113,7 @@ models = [
     {"name": "superbpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k", },
     {"name": "trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", },
     {"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", },
+    {"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
 ]
gguf-py/gguf/constants.py
@@ -116,6 +116,7 @@ class Keys:
         RESIDUAL_SCALE = "{arch}.residual_scale"
         EMBEDDING_SCALE = "{arch}.embedding_scale"
         TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
+        INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step"
 
     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"
@@ -227,6 +228,7 @@ class GGUFType:
 
 class MODEL_ARCH(IntEnum):
     LLAMA = auto()
+    LLAMA4 = auto()
     DECI = auto()
     FALCON = auto()
     BAICHUAN = auto()
@@ -431,6 +433,7 @@ class MODEL_TENSOR(IntEnum):
 
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.LLAMA: "llama",
+    MODEL_ARCH.LLAMA4: "llama4",
     MODEL_ARCH.DECI: "deci",
     MODEL_ARCH.FALCON: "falcon",
     MODEL_ARCH.BAICHUAN: "baichuan",
@@ -654,6 +657,29 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
     ],
+    MODEL_ARCH.LLAMA4: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+    ],
     MODEL_ARCH.DECI: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
gguf-py/gguf/gguf_writer.py
@@ -746,6 +746,9 @@ class GGUFWriter:
     def add_token_shift_count(self, count: int) -> None:
         self.add_uint32(Keys.LLM.TOKEN_SHIFT_COUNT.format(arch=self.arch), count)
 
+    def add_interleave_moe_layer_step(self, value: int) -> None:
+        self.add_uint32(Keys.LLM.INTERLEAVE_MOE_LAYER_STEP.format(arch=self.arch), value)
+
     def add_layer_norm_eps(self, value: float) -> None:
         self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
include/llama.h
@@ -110,6 +110,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
         LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
         LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
+        LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
     };
 
     enum llama_rope_type {
models/ggml-vocab-llama4.gguf.inp (new file, 112 lines)
@@ -0,0 +1,112 @@
ied 4 ½ months
__ggml_vocab_test__
Führer
__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__
Hello world
__ggml_vocab_test__
 Hello world
__ggml_vocab_test__
Hello World
__ggml_vocab_test__
 Hello World
__ggml_vocab_test__
 Hello World!
__ggml_vocab_test__
Hello, world!
__ggml_vocab_test__
 Hello, world!
__ggml_vocab_test__
 this is 🦙.cpp
__ggml_vocab_test__
w048 7tuijk dsdfhu
__ggml_vocab_test__
нещо на Български
__ggml_vocab_test__
កាន់តែពិសេសអាចខលចេញ
__ggml_vocab_test__
🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
__ggml_vocab_test__
Hello
__ggml_vocab_test__
 Hello
__ggml_vocab_test__
  Hello
__ggml_vocab_test__
   Hello
__ggml_vocab_test__
    Hello
__ggml_vocab_test__
    Hello
    Hello
__ggml_vocab_test__
 (
__ggml_vocab_test__

 =
__ggml_vocab_test__
' era
__ggml_vocab_test__
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
__ggml_vocab_test__
!!!!!!
__ggml_vocab_test__
3
__ggml_vocab_test__
33
__ggml_vocab_test__
333
__ggml_vocab_test__
3333
__ggml_vocab_test__
33333
__ggml_vocab_test__
333333
__ggml_vocab_test__
3333333
__ggml_vocab_test__
33333333
__ggml_vocab_test__
333333333
__ggml_vocab_test__
Cửa Việt
__ggml_vocab_test__
 discards
__ggml_vocab_test__



🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
__ggml_vocab_test__
models/ggml-vocab-llama4.gguf.out (new file, 46 lines)
@@ -0,0 +1,46 @@
1190 220 32 220 18215 7112
50 16800 258

220
256
277
197
198
368
2946
3271
19873 3817
39715 3817
19873 7353
39715 7353
39715 7353 13
19873 24 3817 13
39715 24 3817 13
544 373 9522 112 247 26 36315
99 39923 220 35 9607 21498 21470 3679 9433
1595 7653 633 79829 34051 1636
8755 102595 115960 21125 148305 96819 102816 39048 14105 22528 160234
114590 222 330 14879 21 51358 127 12817 93293 117 24204 330 68239 881 120327 170428 21 89101 330 7384 88230 511 947 1492 3742 7233 21
19873
39715
220 39715
256 39715
277 39715
277 39715 198 277 39715
330
198 319
19 7359
19873 24 386 87799 13 2403 583 650 51358 223 1663 155736 1522 42056 7544 13336 28785 29 4412 20645
17931 4959
31
1922
12325
12325 31
12325 1922
12325 12325
12325 12325 31
12325 12325 1922
12325 12325 12325
47 19811 12077
3260 3579
198 7283 51499 191231 20192 3271 3322 9287 2143 17860 114590 222 330 14879 21 51358 127 12817 93293 117 24204 330 68239 881 120327 170428 21 89101 9522 112 247 172394 247 220 31 220 1922 220 12325 220 12325 31 220 12325 1922 220 12325 12325 220 12325 12325 31 220 12325 12325 1922 220 31 26 31 220 31 396 31 220 31 1043 31 117131 102595 115960 21125 148305 96819 102816 80883 223 1663 155736 1522 42056 7544 13336 28785 29 4412 20645 79745 150278 117079 633 79829 34051 1636 25611 41990 109428 1488 91054 24072 17931 4959 29795 9296 16517 1806 481 96 1386 36633 1609 24 481 1109 650 5074 43 481 57 702 5074 27088 2170 536 24 481 48 650 1933 1696 30262 43 1665 19 32818 262 27236 56
src/llama-arch.cpp
@@ -6,6 +6,7 @@
 
 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LLAMA, "llama" },
+    { LLM_ARCH_LLAMA4, "llama4" },
     { LLM_ARCH_DECI, "deci" },
     { LLM_ARCH_FALCON, "falcon" },
     { LLM_ARCH_GROK, "grok" },
@@ -114,6 +115,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
     { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
     { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
+    { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },
 
     { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -233,6 +235,35 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_LLAMA4,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+        },
+    },
     {
         LLM_ARCH_DECI,
         {
src/llama-arch.h
@@ -10,6 +10,7 @@
 
 enum llm_arch {
     LLM_ARCH_LLAMA,
+    LLM_ARCH_LLAMA4,
     LLM_ARCH_DECI,
     LLM_ARCH_FALCON,
     LLM_ARCH_BAICHUAN,
@@ -118,6 +119,7 @@ enum llm_kv {
     LLM_KV_RESIDUAL_SCALE,
     LLM_KV_EMBEDDING_SCALE,
     LLM_KV_TOKEN_SHIFT_COUNT,
+    LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
src/llama-chat.cpp
@@ -61,6 +61,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
     { "yandex", LLM_CHAT_TEMPLATE_YANDEX },
     { "bailing", LLM_CHAT_TEMPLATE_BAILING },
+    { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
 };
 
 llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -174,6 +175,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_YANDEX;
     } else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) {
         return LLM_CHAT_TEMPLATE_BAILING;
+    } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
+        return LLM_CHAT_TEMPLATE_LLAMA4;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -608,7 +611,16 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<role>ASSISTANT</role>";
         }
-    } else {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA4) {
+        // Llama 4
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|header_start|>" << role << "<|header_end|>\n\n" << trim(message->content) << "<|eot|>";
+        }
+        if (add_ass) {
+            ss << "<|header_start|>assistant<|header_end|>\n\n";
+        }
+    } else {
         // template not supported
         return -1;
     }
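As a quick illustration (a minimal sketch, not part of the patch), the template logic above renders a conversation like this:

def render_llama4(chat, add_assistant_prompt=True):
    # mirrors the C++ loop above: role header, blank line, trimmed content, <|eot|>
    out = ""
    for msg in chat:
        out += f"<|header_start|>{msg['role']}<|header_end|>\n\n{msg['content'].strip()}<|eot|>"
    if add_assistant_prompt:
        out += "<|header_start|>assistant<|header_end|>\n\n"
    return out

print(render_llama4([{"role": "user", "content": "Hi"}]))
# <|header_start|>user<|header_end|>
#
# Hi<|eot|><|header_start|>assistant<|header_end|>
#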
src/llama-chat.h
@@ -40,6 +40,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_MEGREZ,
     LLM_CHAT_TEMPLATE_YANDEX,
     LLM_CHAT_TEMPLATE_BAILING,
+    LLM_CHAT_TEMPLATE_LLAMA4,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
src/llama-graph.cpp
@@ -59,6 +59,22 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
+    if (ubatch->pos && attn_scale) {
+        const int64_t n_tokens = ubatch->n_tokens;
+
+        std::vector<float> attn_scale_data(n_tokens, 0.0f);
+        for (int i = 0; i < n_tokens; ++i) {
+            const float pos = ubatch->pos[i];
+            attn_scale_data[i] = std::log(
+                std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
+            ) * f_attn_temp_scale + 1.0;
+        }
+
+        ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(attn_scale));
+    }
+}
+
 void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
         const int64_t n_tokens = ubatch->n_tokens;
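For a numeric feel for the temperature-tuning formula above, here is a hedged standalone re-derivation using the fixed n_attn_temp_floor_scale = 8192 and f_attn_temp_scale = 0.1 that appear later in llama-hparams.h:

import math

def attn_temp_scale(pos, floor_scale=8192, temp_scale=0.1):
    # attn_scale = log(floor((pos + 1) / floor_scale) + 1) * temp_scale + 1
    return math.log(math.floor((pos + 1) / floor_scale) + 1) * temp_scale + 1

print(attn_temp_scale(0))       # 1.0     -- short contexts are untouched
print(attn_temp_scale(8191))    # ~1.0693 -- first step up at the floor boundary
print(attn_temp_scale(131072))  # ~1.2833 -- grows logarithmically with position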
@@ -458,9 +474,17 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
                     }
 
                     // may need to cut off old tokens for sliding window
+                    // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
                     if (data_swa) {
-                        if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
-                            f = -INFINITY;
+                        if (hparams.n_attn_chunk) {
+                            llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
+                            if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
+                                f = -INFINITY;
+                            }
+                        } else {
+                            if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
+                                f = -INFINITY;
+                            }
                         }
                         data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                     }
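The chunked rule above can be stated compactly: a query only sees keys from the start of its own chunk onward. A standalone sketch (assumption: causality is enforced elsewhere, as in the surrounding mask code):

def chunk_visible(pos_q, pos_k, n_attn_chunk):
    # the key must not precede the chunk that the query falls in
    pos_chunk_start = (pos_q // n_attn_chunk) * n_attn_chunk
    return pos_k >= pos_chunk_start

# chunk size 4 for illustration: a query at pos 5 sits in chunk [4, 8)
print([chunk_visible(5, k, 4) for k in range(6)])
# [False, False, False, False, True, True]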
@@ -812,8 +836,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         float w_scale,
         llama_expert_gating_func_type gating_op,
         int il) const {
-    int64_t n_embd = cur->ne[0];
-    int64_t n_tokens = cur->ne[1];
+    const int64_t n_embd = cur->ne[0];
+    const int64_t n_tokens = cur->ne[1];
+    const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
 
     ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
     cb(logits, "ffn_moe_logits", il);
@@ -841,6 +866,12 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(selection_probs, "ffn_moe_probs_biased", il);
     }
 
+    // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
+    // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
+    if (arch == LLM_ARCH_LLAMA4) {
+        selection_probs = logits;
+    }
+
     // select experts
     ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
     cb(selected_experts->src[0], "ffn_moe_argsort", il);
@@ -867,6 +898,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     }
 
     cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
+
+    if (weight_before_ffn) {
+        // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
+        ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
+        repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
+        cur = ggml_mul(ctx0, repeated, weights);
+        cb(cur, "ffn_moe_weighted", il);
+    }
+
     ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(up, "ffn_moe_up", il);
@@ -894,7 +934,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
-    experts = ggml_mul(ctx0, experts, weights);
+    if (!weight_before_ffn) {
+        experts = ggml_mul(ctx0, experts, weights);
+        cb(cur, "ffn_moe_weighted", il);
+    }
 
     // aggregate experts
     ggml_tensor * moe_out = nullptr;
@@ -914,6 +957,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         moe_out = ggml_cont(ctx0, moe_out);
     }
 
+    cb(moe_out, "ffn_moe_out", il);
+
     return moe_out;
 }
@@ -981,6 +1026,19 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
     return cur;
 }
 
+ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
+    auto inp = std::make_unique<llm_graph_input_attn_temp>(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
+
+    auto & cur = inp->attn_scale;
+
+    cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token());
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
 ggml_tensor * llm_graph_context::build_inp_out_ids() const {
     auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
src/llama-graph.h
@@ -100,6 +100,23 @@ public:
     const int64_t n_pos_per_token = 1;
 };
 
+// temperature tuning, used by llama4
+class llm_graph_input_attn_temp : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_temp(int64_t n_pos_per_token, uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
+        : n_pos_per_token(n_pos_per_token), n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
+    virtual ~llm_graph_input_attn_temp() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
+
+    const int64_t n_pos_per_token = 1;
+
+    const uint32_t n_attn_temp_floor_scale;
+    const float f_attn_temp_scale;
+};
+
 class llm_graph_input_pos_bucket : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
@@ -470,6 +487,7 @@ struct llm_graph_context {
 
     ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
     ggml_tensor * build_inp_pos() const;
+    ggml_tensor * build_inp_attn_scale() const;
     ggml_tensor * build_inp_out_ids() const;
    ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
src/llama-hparams.h
@@ -112,6 +112,14 @@ struct llama_hparams {
     bool use_alibi = false;
     bool attn_soft_cap = false;
 
+    uint32_t n_moe_layer_step = 0;
+    bool use_kq_norm = true;
+    uint32_t n_attn_chunk = 0;
+    // values below seems to be fixed on llama4
+    uint32_t n_no_rope_layer_step = 4;
+    uint32_t n_attn_temp_floor_scale = 8192;
+    float f_attn_temp_scale = 0.1;
+
     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
     llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
src/llama-model.cpp
@@ -90,6 +90,8 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_57B_A14B: return "57B.A14B";
         case LLM_TYPE_27B: return "27B";
         case LLM_TYPE_290B: return "290B";
+        case LLM_TYPE_17B_16E: return "17Bx16E (Scout)";
+        case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
         default: return "?B";
     }
 }
@@ -550,6 +552,25 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 }
             }
         } break;
+        case LLM_ARCH_LLAMA4:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+                ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
+                hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full
+                hparams.n_attn_chunk = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
+                hparams.n_swa = 1; // TODO @ngxson : this is added to trigger the SWA branch (we store the chunked attn mask in the SWA tensor), will need to clean this up later
+
+                switch (hparams.n_expert) {
+                    case 16: type = LLM_TYPE_17B_16E; break;
+                    case 128: type = LLM_TYPE_17B_128E; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+
+                if (type == LLM_TYPE_17B_128E) {
+                    hparams.use_kq_norm = false;
+                }
+            } break;
         case LLM_ARCH_DECI:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
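The "3 chunked - 1 full" comment describes the layer pattern implied by n_swa_pattern = 4. A sketch, assuming the usual llama.cpp convention that every n_swa_pattern-th layer is a full-attention layer (an assumption, not stated in this hunk):

n_swa_pattern = 4
kinds = ["chunked" if (il % n_swa_pattern) < n_swa_pattern - 1 else "full"
         for il in range(8)]
print(kinds)
# ['chunked', 'chunked', 'chunked', 'full', 'chunked', 'chunked', 'chunked', 'full']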
@@ -1690,6 +1711,56 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     }
                 }
             } break;
+        case LLM_ARCH_LLAMA4:
+            {
+                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                // output
+                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                // if output is NULL, init from the input tok embed
+                if (output == NULL) {
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                }
+
+                GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Llama 4 requires n_moe_layer_step > 0");
+                for (int i = 0; i < n_layer; ++i) {
+                    bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0;
+
+                    auto & layer = layers[i];
+
+                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+
+                    if (is_moe_layer) {
+                        int n_ff_exp = hparams.n_ff_exp;
+
+                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
+
+                        // Shared expert
+                        const int64_t n_ff_shexp = n_ff_exp;
+                        layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
+                        layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0);
+                        layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), { n_embd, n_ff_shexp}, 0);
+                    } else {
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                    }
+                }
+            } break;
         case LLM_ARCH_DECI:
             {
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
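For intuition about interleave_moe_layer_step, a tiny standalone check of the is_moe_layer rule used above:

def moe_layers(n_layer, n_moe_layer_step):
    # mirrors: is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0
    return [i for i in range(n_layer) if (i + 1) % n_moe_layer_step == 0]

print(moe_layers(8, 1))  # [0, 1, 2, 3, 4, 5, 6, 7] -- every layer is MoE
print(moe_layers(8, 2))  # [1, 3, 5, 7]             -- dense and MoE layers interleave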
@@ -4203,12 +4274,22 @@ struct llm_build_llama : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
+        // temperature tuning
+        ggml_tensor * inp_attn_scale = nullptr;
+        if (arch == LLM_ARCH_LLAMA4) {
+            inp_attn_scale = build_inp_attn_scale();
+        }
+
         auto * inp_attn = build_attn_inp_kv_unified();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
+            bool use_rope = arch == LLM_ARCH_LLAMA4
+                ? (il + 1) % hparams.n_no_rope_layer_step != 0
+                : true;
+
             // norm
             cur = build_norm(inpL,
                     model.layers[il].attn_norm, NULL,
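A quick check of the use_rope rule above with the fixed n_no_rope_layer_step = 4: every fourth layer skips RoPE (a NoPE layer) and instead multiplies Q by the temperature scale input.

n_no_rope_layer_step = 4
print([(il + 1) % n_no_rope_layer_step != 0 for il in range(8)])
# [True, True, True, False, True, True, True, False] -- layers 3 and 7 skip RoPE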
@@ -4246,25 +4327,38 @@ struct llm_build_llama : public llm_graph_context {
             Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
             Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-            Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                    );
-
-            Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                    );
+            if (use_rope) {
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+            } else if (inp_attn_scale) {
+                Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
+            }
 
             cb(Qcur, "Qcur", il);
             cb(Kcur, "Kcur", il);
             cb(Vcur, "Vcur", il);
 
+            if (arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) {
+                // Llama4TextL2Norm
+                Qcur = ggml_rms_norm(ctx0, Qcur, 1e-6);
+                Kcur = ggml_rms_norm(ctx0, Kcur, 1e-6);
+                cb(Qcur, "Qcur_normed", il);
+                cb(Kcur, "Kcur_normed", il);
+            }
+
             cur = build_attn(inp_attn, gf,
                     model.layers[il].wo, model.layers[il].bo,
                     Qcur, Kcur, Vcur, nullptr, kq_scale, il);
             cb(cur, "attn_out", il);
         }
 
         if (il == n_layer - 1) {
@@ -4282,7 +4376,7 @@ struct llm_build_llama : public llm_graph_context {
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
 
-            // feed-forward network
+            // feed-forward network (non-MoE)
             if (model.layers[il].ffn_gate_inp == nullptr) {
 
                 cur = build_norm(ffn_inp,
@@ -4297,6 +4391,38 @@ struct llm_build_llama : public llm_graph_context {
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, il);
                 cb(cur, "ffn_out", il);
+
+            } else if (arch == LLM_ARCH_LLAMA4) {
+                // llama4 MoE
+                ggml_tensor * ffn_inp_normed = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                ggml_tensor * moe_out = build_moe_ffn(ffn_inp_normed,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, false,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID,
+                        il);
+
+                // Shared experts
+                ggml_tensor * shexp_out = build_ffn(ffn_inp_normed,
+                        model.layers[il].ffn_up_shexp, NULL, NULL,
+                        model.layers[il].ffn_gate_shexp, NULL, NULL,
+                        model.layers[il].ffn_down_shexp, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(shexp_out, "ffn_moe_shexp", il);
+
+                cur = ggml_add(ctx0, moe_out, shexp_out);
+                cb(cur, "ffn_moe_out_merged", il);
+
             } else {
                 // MoE branch
                 cur = build_norm(ffn_inp,
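Putting the routing pieces together, a hedged standalone sketch of the per-token expert selection implied by the build_moe_ffn changes earlier in this commit: top-k on the raw router logits, sigmoid applied only after top_k, weights applied to the expert input (weight_before_ffn), and the shared expert output added at the end.

import math

def llama4_route(logits, n_expert_used):
    # top-k over raw logits (llama4 has no exp_probs_b bias)
    topk = sorted(range(len(logits)), key=lambda e: logits[e], reverse=True)[:n_expert_used]
    # sigmoid gating, applied after top_k; the weight scales each expert's
    # *input* rather than its output (weight_before_ffn)
    return {e: 1.0 / (1.0 + math.exp(-logits[e])) for e in topk}

print(llama4_route([0.2, -1.0, 1.5, 0.3], 2))
# {2: 0.817..., 3: 0.574...}; final output = sum(expert_e(w_e * x)) + shared_expert(x)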
@@ -12091,6 +12217,7 @@ llm_graph_result_ptr llama_model::build_graph(
 
     switch (arch) {
         case LLM_ARCH_LLAMA:
+        case LLM_ARCH_LLAMA4:
         case LLM_ARCH_MINICPM:
         case LLM_ARCH_GRANITE:
         case LLM_ARCH_GRANITE_MOE:
@@ -12440,6 +12567,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
 
         // use what we call a normal RoPE, operating on pairs of consecutive head values
         case LLM_ARCH_LLAMA:
+        case LLM_ARCH_LLAMA4:
         case LLM_ARCH_DECI:
         case LLM_ARCH_BAICHUAN:
         case LLM_ARCH_STARCODER:
src/llama-model.h
@@ -86,6 +86,8 @@ enum llm_type {
     LLM_TYPE_57B_A14B,
     LLM_TYPE_27B,
     LLM_TYPE_290B,
+    LLM_TYPE_17B_16E, // llama4 Scout
+    LLM_TYPE_17B_128E, // llama4 Maverick
 };
 
 struct llama_layer_posnet {
src/llama-vocab.cpp
@@ -1616,7 +1616,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "megrez") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
         } else if (
-                tokenizer_pre == "gpt-4o") {
+                tokenizer_pre == "gpt-4o" ||
+                tokenizer_pre == "llama4") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
             clean_spaces = false;
         } else if (