diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c9ac2957f..62284fdd4 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1806,10 +1806,6 @@ class Llama4Model(LlamaModel): self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): - name = name.replace("language_model.", "") - name = name.replace("feed_forward.", "mlp.") # a bit hacky for now - name = name.replace(".router.weight", ".gate.weight") # a bit hacky for now - # split the gate_up into gate and up if "gate_up_proj" in name: name_up = name.replace("gate_up_proj", "up_proj.weight") diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 50bef12e3..a9e681f8e 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -30,6 +30,7 @@ class TensorNameMap: "rwkv.embeddings", # rwkv6 "model.embeddings", # rwkv7 "model.word_embeddings", # bailingmoe + "language_model.model.embed_tokens", # llama4 ), # Token type embeddings @@ -67,6 +68,7 @@ class TensorNameMap: "output_layer", # chatglm "head", # rwkv "head.out", # wavtokenizer + "language_model.lm_head", # llama4 ), # Output norm @@ -89,6 +91,7 @@ class TensorNameMap: "rwkv.ln_out", # rwkv6 "model.ln_out", # rwkv7 "backbone.final_layer_norm", # wavtokenizer + "language_model.model.norm", # llama4 ), # Rope frequencies @@ -130,6 +133,7 @@ class TensorNameMap: "transformer.layers.{bid}.attn_norm", # openelm "rwkv.blocks.{bid}.ln1", # rwkv6 "model.layers.{bid}.ln1", # rwkv7 + "language_model.model.layers.{bid}.input_layernorm", # llama4 ), # Attention norm 2 @@ -169,6 +173,7 @@ class TensorNameMap: "model.layers.{bid}.attention.wq", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok "transformer.h.{bid}.attn.attention.q_proj", # exaone + "language_model.model.layers.{bid}.self_attn.q_proj", # llama4 ), # Attention key @@ -183,6 +188,7 @@ class TensorNameMap: "model.layers.{bid}.attention.wk", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok "transformer.h.{bid}.attn.attention.k_proj", # exaone + "language_model.model.layers.{bid}.self_attn.k_proj", # llama4 ), # Attention value @@ -196,6 +202,7 @@ class TensorNameMap: "model.layers.{bid}.attention.wv", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok "transformer.h.{bid}.attn.attention.v_proj", # exaone + "language_model.model.layers.{bid}.self_attn.v_proj", # llama4 ), # Attention output @@ -222,6 +229,7 @@ class TensorNameMap: "encoder.layers.{bid}.self_attention.dense", # chatglm "transformer.layers.{bid}.attn.out_proj", # openelm "transformer.h.{bid}.attn.attention.out_proj", # exaone + "language_model.model.layers.{bid}.self_attn.o_proj", # llama4 ), # Attention output norm @@ -259,6 +267,7 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.rms_norm_2", # Grok "encoder.layers.{bid}.post_attention_layernorm", # chatglm "transformer.layers.{bid}.ffn_norm", # openelm + "language_model.model.layers.{bid}.post_attention_layernorm", # llama4 ), # Post feed-forward norm @@ -278,6 +287,7 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.router", # Grok "transformer.blocks.{bid}.ffn.router.layer", # dbrx "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe + "language_model.model.layers.{bid}.feed_forward.router", # llama4 ), MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( @@ -315,6 +325,7 @@ class TensorNameMap: "model.layers.{bid}.residual_mlp.w3", # arctic "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm "transformer.h.{bid}.mlp.c_fc_1", # exaone + "language_model.model.layers.{bid}.feed_forward.up_proj", # llama4 ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -323,11 +334,13 @@ class TensorNameMap: "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged) + "language_model.model.layers.{bid}.feed_forward.experts.up_proj", # llama4 ), MODEL_TENSOR.FFN_UP_SHEXP: ( "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2 + "language_model.model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4 ), # AWQ-activation gate @@ -348,6 +361,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.linear_1", # refact "model.layers.{bid}.residual_mlp.w1", # arctic "transformer.h.{bid}.mlp.c_fc_0", # exaone + "language_model.model.layers.{bid}.feed_forward.gate_proj", # llama4 ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -356,11 +370,13 @@ class TensorNameMap: "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged) + "language_model.model.layers.{bid}.feed_forward.experts.gate_proj", # llama4 ), MODEL_TENSOR.FFN_GATE_SHEXP: ( "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2 + "language_model.model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4 ), # Feed-forward down @@ -389,6 +405,7 @@ class TensorNameMap: "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm "model.layers.h.{bid}.mlp.c_proj", # exaone + "language_model.model.layers.{bid}.feed_forward.down_proj", # llama4 ), MODEL_TENSOR.FFN_DOWN_EXP: ( @@ -398,11 +415,13 @@ class TensorNameMap: "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe "model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged) + "language_model.model.layers.{bid}.feed_forward.experts.down_proj", # llama4 ), MODEL_TENSOR.FFN_DOWN_SHEXP: ( "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2 + "language_model.model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4 ), MODEL_TENSOR.ATTN_Q_NORM: (