diff --git a/README.md b/README.md index e56042f1c..a129d27d5 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - [x] [Flan T5](https://huggingface.co/models?search=flan-t5) - [x] [Open Elm models](https://huggingface.co/collections/apple/openelm-instruct-models-6619ad295d7ae9f868b759ca) - [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b) + [GLMEdge-1.5b](https://huggingface.co/THUDM/glm-edge-1.5b-chat) + [GLMEdge-4b](https://huggingface.co/THUDM/glm-edge-4b-chat) +- [x] [GLM-4-0414](https://huggingface.co/collections/THUDM/glm-4-0414-67f3cbcb34dd9d252707cb2e) - [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966) - [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct) - [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 4aff03a34..2bf97475f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -735,6 +735,9 @@ class Model: if chkhsh == "d353350c764d8c3b39c763113960e4fb4919bea5fbf208a0e3b22e8469dc7406": # ref: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct res = "llama4" + if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2": + # ref: https://huggingface.co/THUDM/glm-4-9b-hf + res = "glm4" if res is None: logger.warning("\n") @@ -4897,6 +4900,22 @@ class JaisModel(Model): self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias) +@Model.register("Glm4ForCausalLM") +class Glm4Model(Model): + model_arch = gguf.MODEL_ARCH.GLM4 + + def set_vocab(self): + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if self.hparams["rope_scaling"].get("type") == "yarn": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"]) + + @Model.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration") class ChatGLMModel(Model): model_arch = gguf.MODEL_ARCH.CHATGLM @@ -5588,7 +5607,6 @@ def main() -> None: with torch.inference_mode(): output_type = ftype_map[args.outtype] model_architecture = hparams["architectures"][0] - try: model_class = Model.from_model_architecture(model_architecture) except NotImplementedError: diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index ce6104da4..160c9fe0e 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -114,6 +114,7 @@ models = [ {"name": "trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", }, {"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", }, {"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", }, + {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", }, ] diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 0410654dd..162070e6e 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -280,6 +280,7 @@ class MODEL_ARCH(IntEnum): DEEPSEEK = auto() DEEPSEEK2 = auto() CHATGLM = auto() + GLM4 = auto() BITNET = auto() T5 = auto() T5ENCODER = auto() @@ -487,6 +488,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.DEEPSEEK: "deepseek", MODEL_ARCH.DEEPSEEK2: "deepseek2", MODEL_ARCH.CHATGLM: "chatglm", + MODEL_ARCH.GLM4: "glm4", MODEL_ARCH.BITNET: "bitnet", MODEL_ARCH.T5: "t5", MODEL_ARCH.T5ENCODER: "t5encoder", @@ -1561,6 +1563,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.GLM4 : [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_POST_NORM, + ], MODEL_ARCH.BITNET: [ MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index a9e681f8e..35154e9b5 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -13,7 +13,7 @@ class TensorNameMap: "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone "transformer.word_embeddings", # falcon "word_embeddings", # bloom - "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 + "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 "tok_embeddings", # llama-pth "embeddings.word_embeddings", # bert nomic-bert "language_model.embedding.word_embeddings", # persimmon @@ -241,7 +241,8 @@ class TensorNameMap: ), MODEL_TENSOR.ATTN_POST_NORM: ( - "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 + "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge + "model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414 ), # Rotary embeddings @@ -278,6 +279,7 @@ class TensorNameMap: # Post feed-forward norm MODEL_TENSOR.FFN_POST_NORM: ( "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2 + "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414 ), MODEL_TENSOR.FFN_GATE_INP: ( @@ -316,7 +318,7 @@ class TensorNameMap: "h.{bid}.mlp.c_fc", # gpt2 "transformer.h.{bid}.mlp.fc1", # phi2 "model.layers.{bid}.mlp.fc1", # phi2 - "model.layers.{bid}.mlp.gate_up_proj", # phi3 + "model.layers.{bid}.mlp.gate_up_proj", # phi3 glm-4-0414 "model.layers.layers.{bid}.mlp.up_proj", # plamo "model.layers.{bid}.feed_forward.w3", # internlm2 "encoder.layers.{bid}.mlp.fc11", # nomic-bert diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 264f8c5b9..a6fddc7fd 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -54,6 +54,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_DEEPSEEK, "deepseek" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_CHATGLM, "chatglm" }, + { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, @@ -1152,6 +1153,25 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, }, }, + { + LLM_ARCH_GLM4, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_BITNET, { diff --git a/src/llama-arch.h b/src/llama-arch.h index 201935281..2c2099b3c 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -58,6 +58,7 @@ enum llm_arch { LLM_ARCH_DEEPSEEK, LLM_ARCH_DEEPSEEK2, LLM_ARCH_CHATGLM, + LLM_ARCH_GLM4, LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, @@ -256,6 +257,8 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_LAYER_OUT_NORM, + LLM_TENSOR_POST_ATTN_NORM, + LLM_TENSOR_POST_MLP_NORM, LLM_TENSOR_SSM_IN, LLM_TENSOR_SSM_CONV1D, LLM_TENSOR_SSM_X, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index ff847701e..b74dd72cf 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1205,6 +1205,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GLM4: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_9B; break; + case 61: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -3476,6 +3485,45 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); } } break; + case LLM_ARCH_GLM4: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; case LLM_ARCH_NEMOTRON: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -10854,6 +10902,157 @@ struct llm_build_chatglm : public llm_graph_context { } }; +struct llm_build_glm4 : public llm_graph_context { + llm_build_glm4(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // Pre-attention norm + cur = build_norm(inpL, + model.layers[il].attn_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv == nullptr) { + Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + } else { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + if (model.layers[il].bqkv) { + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + } + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Post-attention norm (new!) + cur = build_norm(cur, + model.layers[il].attn_post_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + // Add the input (residual connection after post-attention norm) + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + // Pre-MLP norm + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // MLP + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + // Post-MLP norm + cur = build_norm(cur, + model.layers[il].ffn_post_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "post_mlp_norm", il); + } + + // Add residual connection after post-MLP norm + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + // Final norm + cur = build_norm(inpL, + model.output_norm, + NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // Output projection + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_nemotron : public llm_graph_context { llm_build_nemotron(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -12735,6 +12934,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_GLM4: + { + llm = std::make_unique(*this, params, gf); + } break; case LLM_ARCH_BITNET: { llm = std::make_unique(*this, params, gf); @@ -12932,6 +13135,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_DEEPSEEK2: case LLM_ARCH_PLM: case LLM_ARCH_CHATGLM: + case LLM_ARCH_GLM4: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_CHAMELEON: diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 0feabd95a..464ff01e0 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1572,6 +1572,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { pre_type = LLAMA_VOCAB_PRE_TYPE_PORO; clean_spaces = false; } else if ( + tokenizer_pre == "glm4" || tokenizer_pre == "chatglm-bpe") { pre_type = LLAMA_VOCAB_PRE_TYPE_CHATGLM4; special_bos_id = LLAMA_TOKEN_NULL;