mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-04-16 03:26:08 +00:00
graph : simplify attn input build for unified KV cache (#12381)
ggml-ci
This commit is contained in:
parent
081bee8c64
commit
c522ce4143
@ -1311,29 +1311,23 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
return cur;
|
||||
}
|
||||
|
||||
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(
|
||||
bool causal,
|
||||
bool swa) const {
|
||||
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
|
||||
|
||||
const auto n_kv = kv_self->n;
|
||||
|
||||
inp->self_kq_mask = causal
|
||||
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
|
||||
: ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
|
||||
if (swa) {
|
||||
if (hparams.n_swa_pattern > 1) {
|
||||
GGML_ASSERT(hparams.n_swa > 0);
|
||||
|
||||
inp->self_kq_mask_swa = causal
|
||||
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
|
||||
: ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
|
||||
|
@ -509,9 +509,7 @@ struct llm_graph_context {
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified(
|
||||
bool causal,
|
||||
bool swa) const;
|
||||
llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
|
||||
|
||||
ggml_tensor * build_attn(
|
||||
llm_graph_input_attn_kv_unified * inp,
|
||||
|
@ -784,9 +784,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
hparams.n_swa = 2047;
|
||||
} else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) {
|
||||
// default value for Phi-3-mini-128k-instruct
|
||||
// note: this seems incorrect because the window is bigger than the train context?
|
||||
hparams.n_swa = 262144;
|
||||
} else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) {
|
||||
// default value for Phi-3-medium-128k-instruct
|
||||
// note: this seems incorrect because the window is equal to the train context?
|
||||
hparams.n_swa = 131072;
|
||||
}
|
||||
bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
|
||||
@ -3710,6 +3712,7 @@ void llama_model::print_info() const {
|
||||
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
|
||||
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
|
||||
LLAMA_LOG_INFO("%s: n_swa_pattern = %u\n", __func__, hparams.n_swa_pattern);
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
|
||||
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
|
||||
@ -3871,7 +3874,7 @@ struct llm_build_llama : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
@ -4034,7 +4037,7 @@ struct llm_build_deci : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
@ -4192,7 +4195,7 @@ struct llm_build_baichuan : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr;
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -4310,7 +4313,7 @@ struct llm_build_xverse : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -4418,7 +4421,7 @@ struct llm_build_falcon : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * attn_norm;
|
||||
@ -4543,7 +4546,7 @@ struct llm_build_grok : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -4697,7 +4700,7 @@ struct llm_build_dbrx : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -4821,7 +4824,7 @@ struct llm_build_starcoder : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
|
||||
cb(pos, "pos_embd", -1);
|
||||
@ -4924,7 +4927,7 @@ struct llm_build_refact : public llm_graph_context {
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -5187,7 +5190,7 @@ struct llm_build_bloom : public llm_graph_context {
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
inpL = build_norm(inpL,
|
||||
model.tok_norm,
|
||||
@ -5292,7 +5295,7 @@ struct llm_build_mpt : public llm_graph_context {
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
if (model.pos_embd) {
|
||||
// inp_pos - contains the positions
|
||||
@ -5436,7 +5439,7 @@ struct llm_build_stablelm : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
// norm
|
||||
@ -5587,7 +5590,7 @@ struct llm_build_qwen : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -5703,7 +5706,7 @@ struct llm_build_qwen2 : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -5818,7 +5821,7 @@ struct llm_build_qwen2vl : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
int sections[4];
|
||||
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
|
||||
@ -5938,7 +5941,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -6087,7 +6090,7 @@ struct llm_build_phi2 : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
attn_norm_output = build_norm(inpL,
|
||||
@ -6211,7 +6214,7 @@ struct llm_build_phi3 : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, true);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
auto * residual = inpL;
|
||||
@ -6357,7 +6360,7 @@ struct llm_build_plamo : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
|
||||
@ -6465,7 +6468,7 @@ struct llm_build_gpt2 : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
|
||||
cb(pos, "pos_embd", -1);
|
||||
@ -6573,7 +6576,7 @@ struct llm_build_codeshell : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
cur = build_norm(inpL,
|
||||
@ -6686,7 +6689,7 @@ struct llm_build_orion : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -6807,7 +6810,7 @@ struct llm_build_internlm2 : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -6937,7 +6940,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -7141,7 +7144,7 @@ struct llm_build_gemma : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
// norm
|
||||
@ -7251,7 +7254,7 @@ struct llm_build_gemma2 : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, true);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
// norm
|
||||
@ -7386,7 +7389,7 @@ struct llm_build_gemma3 : public llm_graph_context {
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// TODO: is causal == true correct? might need some changes
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, true);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
@ -7515,7 +7518,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -7828,7 +7831,7 @@ struct llm_build_command_r : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
|
||||
@ -7978,7 +7981,7 @@ struct llm_build_cohere2 : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, true);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
@ -8110,7 +8113,7 @@ struct llm_build_olmo : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -8232,7 +8235,7 @@ struct llm_build_olmo2 : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -8358,7 +8361,7 @@ struct llm_build_olmoe : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -8481,7 +8484,7 @@ struct llm_build_openelm : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const int64_t n_head = hparams.n_head(il);
|
||||
@ -8611,7 +8614,7 @@ struct llm_build_gptneox : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
cur = build_norm(inpL,
|
||||
@ -8757,7 +8760,7 @@ struct llm_build_arctic : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -8889,7 +8892,7 @@ struct llm_build_deepseek : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
||||
|
||||
@ -9054,7 +9057,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -9274,7 +9277,7 @@ struct llm_build_bitnet : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -9532,7 +9535,7 @@ struct llm_build_t5_dec : public llm_graph_context {
|
||||
|
||||
const int64_t n_outputs_enc = embd_enc->ne[1];
|
||||
|
||||
auto * inp_attn_self = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn_self = build_attn_inp_kv_unified();
|
||||
auto * inp_attn_cross = build_attn_inp_cross();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
@ -9698,7 +9701,7 @@ struct llm_build_jais : public llm_graph_context {
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
cur = build_norm(inpL,
|
||||
@ -9794,7 +9797,7 @@ struct llm_build_chatglm : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -9926,7 +9929,7 @@ struct llm_build_nemotron : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -10049,7 +10052,7 @@ struct llm_build_exaone : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@ -10565,7 +10568,7 @@ struct llm_build_chameleon : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified(true, false);
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
Loading…
x
Reference in New Issue
Block a user