kv-cache : separate recurrent vs non-recurrent impl (wip)

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-04-07 16:50:58 +03:00
parent 916c83bfe7
commit 19eb81e083
7 changed files with 1245 additions and 386 deletions

View File

@ -178,24 +178,37 @@ llama_context::llama_context(
// init the memory module
// TODO: for now, always create a unified KV cache
if (!hparams.vocab_only) {
kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
uint32_t kv_size = cparams.n_ctx;
uint32_t kv_size = 0;
ggml_type type_k = params.type_k;
ggml_type type_v = params.type_v;
if (llama_model_is_recurrent(&model)) {
if (!llama_model_is_recurrent(&model)) {
//kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
auto * kv = static_cast<llama_kv_cache_unified *>(model.create_memory());
LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv->get_padding(cparams));
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
kv_size = cparams.n_ctx;
type_k = params.type_k;
type_v = params.type_v;
kv_self.reset(kv);
} else {
auto * kv = static_cast<llama_kv_cache_recurrent *>(model.create_memory());
LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
// Mamba needs at least as many KV cells as there are sequences kept at any time
kv_size = std::max((uint32_t) 1, params.n_seq_max);
// it's probably best to keep as much precision as possible for the states
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
kv_self.reset(kv);
}
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
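
For readers skimming the hunk above: the new branch boils down to choosing the cache kind, its size, and the state precision up front. A standalone sketch with simplified stand-in types (the real code also pads n_ctx to the unified cache's padding, which is omitted here):

// Minimal standalone sketch of the cache-selection logic above; the structs
// are stand-ins for the real llama.cpp types, not the actual API.
#include <algorithm>
#include <cstdint>
#include <cstdio>

enum class kv_kind { unified, recurrent };

struct kv_plan {
    kv_kind  kind;
    uint32_t kv_size;
    bool     f32_states; // recurrent caches keep conv/ssm states in F32
};

static kv_plan plan_kv(bool is_recurrent, uint32_t n_ctx, uint32_t n_seq_max) {
    if (!is_recurrent) {
        // unified cache: one cell per context position, user-chosen K/V types
        return { kv_kind::unified, n_ctx, /*f32_states=*/false };
    }
    // Mamba needs at least as many KV cells as there are sequences kept at any time
    return { kv_kind::recurrent, std::max<uint32_t>(1, n_seq_max), /*f32_states=*/true };
}

int main() {
    const kv_plan p = plan_kv(/*is_recurrent=*/true, /*n_ctx=*/4096, /*n_seq_max=*/8);
    std::printf("kind=%s kv_size=%u f32_states=%d\n",
            p.kind == kv_kind::recurrent ? "recurrent" : "unified",
            (unsigned) p.kv_size, (int) p.f32_states);
}
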
@ -304,7 +317,7 @@ llama_context::llama_context(
int n_nodes_tg = -1;
// simulate full KV cache
kv_self->n = kv_self->size;
kv_self->set_full();
cross.v_embd.clear();
@ -553,7 +566,9 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
//GGML_ASSERT(kv_self->size == n_ctx);
auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
const auto * kv = static_cast<const llama_kv_cache_unified *>(kv_self.get());
auto inp = std::make_unique<llm_graph_input_k_shift>(kv);
inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
ggml_set_input(inp->k_shift);
@ -569,16 +584,16 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
ggml_tensor * rope_factors = kv->cbs.get_rope_factors(n_ctx_per_seq(), il);
ggml_tensor * k =
ggml_view_3d(ctx0, kv_self->k_l[il],
n_embd_head_k, n_head_kv, kv_self->size,
ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
ggml_view_3d(ctx0, kv->k_l[il],
n_embd_head_k, n_head_kv, kv->size,
ggml_row_size(kv->k_l[il]->type, n_embd_head_k),
ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
0);
ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv->k_l[il]->buffer);
ggml_build_forward_expand(gf, cur);
}
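
kv_self is now held through the base llama_kv_cache pointer (see the header change below), so unified-only paths such as this K-shift builder downcast once at the top. A toy illustration of the pattern, with stand-in class bodies:

// Toy sketch of the base-pointer + downcast pattern; the classes here are
// placeholders, not the real llama.cpp cache implementations.
#include <memory>

struct kv_cache_base { virtual ~kv_cache_base() = default; };
struct kv_cache_unified_stub   : kv_cache_base { int k_shift_work = 0; };
struct kv_cache_recurrent_stub : kv_cache_base { int ssm_work     = 0; };

struct context_stub {
    std::unique_ptr<kv_cache_base> kv_self; // set once at init, kind never changes

    void build_k_shift() {
        // only reached for non-recurrent models, so the cast is known to be valid
        auto * kv = static_cast<kv_cache_unified_stub *>(kv_self.get());
        kv->k_shift_work = 1; // placeholder for the real graph construction
    }
};

int main() {
    context_stub ctx;
    ctx.kv_self = std::make_unique<kv_cache_unified_stub>();
    ctx.build_k_shift();
}
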
@ -593,9 +608,11 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
ggml_cgraph * gf) const {
auto res = std::make_unique<llm_graph_result>();
auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
const auto & hparams = model.hparams;
const auto & ids = kv_self->defrag_info.ids;
const auto & ids = kv->defrag_info.ids;
#if 0
// CPU defrag
@ -685,40 +702,40 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv->k_l[il],
n_embd_k_gqa, nm,
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*i));
ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv->k_l[il],
n_embd_k_gqa, nm,
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*id));
ggml_tensor * view_v_src;
ggml_tensor * view_v_dst;
if (cparams.flash_attn) {
// NOTE: the V cache is not transposed when using flash attention
view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
view_v_src = ggml_view_2d(ctx0, kv->v_l[il],
n_embd_v_gqa, nm,
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*i));
view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
view_v_dst = ggml_view_2d(ctx0, kv->v_l[il],
n_embd_v_gqa, nm,
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*id));
} else {
view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
view_v_src = ggml_view_2d(ctx0, kv->v_l[il],
nm, n_embd_v_gqa,
ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
ggml_row_size(kv_self->v_l[il]->type, i));
ggml_row_size(kv->v_l[il]->type, kv->size),
ggml_row_size(kv->v_l[il]->type, i));
view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
view_v_dst = ggml_view_2d(ctx0, kv->v_l[il],
nm, n_embd_v_gqa,
ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
ggml_row_size(kv_self->v_l[il]->type, id));
ggml_row_size(kv->v_l[il]->type, kv->size),
ggml_row_size(kv->v_l[il]->type, id));
}
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
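
The flash-attention branch above exists because without flash attention the V cache is stored transposed, so the copy views swap extents and use the whole cache size as the row stride. A standalone sketch of the two 2D-view layouts, using a plain per-element byte size where the real code calls ggml_row_size (an approximation for illustration):

// Sketch of the V-cache view parameters used by the defrag copy above.
#include <cstddef>
#include <cstdint>
#include <cstdio>

struct view2d {
    int64_t ne0, ne1;   // extents
    size_t  nb1, offs;  // row stride and starting offset, in bytes
};

static view2d v_view(bool flash_attn, int64_t n_embd_v_gqa, int64_t nm,
                     int64_t kv_size, size_t elt, int64_t cell) {
    if (flash_attn) {
        // V stored like K: rows are cells, nm consecutive rows of n_embd_v_gqa values
        return { n_embd_v_gqa, nm, (size_t)(n_embd_v_gqa*elt), (size_t)(n_embd_v_gqa*cell*elt) };
    }
    // transposed V: rows are embedding channels, each row spans all kv_size cells
    return { nm, n_embd_v_gqa, (size_t)(kv_size*elt), (size_t)(cell*elt) };
}

int main() {
    const view2d a = v_view(true,  128, 16, 4096, 2, 100);
    const view2d b = v_view(false, 128, 16, 4096, 2, 100);
    std::printf("fa: %lldx%lld nb1=%zu | no-fa: %lldx%lld nb1=%zu\n",
            (long long) a.ne0, (long long) a.ne1, a.nb1,
            (long long) b.ne0, (long long) b.ne1, b.nb1);
}
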
@ -735,13 +752,11 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
}
void llama_context::kv_self_update() {
auto & kv = kv_self;
bool need_reserve = false;
if (kv->has_shift) {
if (!kv->get_can_shift()) {
GGML_ABORT("The current context does not support K-shift");
if (kv_self->get_has_shift()) {
if (!kv_self->get_can_shift()) {
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
}
LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
@ -764,6 +779,8 @@ void llama_context::kv_self_update() {
}
{
auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
kv->has_shift = false;
for (uint32_t i = 0; i < kv->size; ++i) {
@ -773,9 +790,11 @@ void llama_context::kv_self_update() {
}
// defragment the KV cache if needed
if (kv->do_defrag) {
if (kv_self->get_do_defrag()) {
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
if (kv->defrag_prepare(graph_max_nodes())) {
ggml_backend_sched_reset(sched.get());
@ -804,7 +823,7 @@ void llama_context::kv_self_update() {
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
// simulate full KV cache
kv_self->n = kv_self->size;
kv_self->set_full();
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
@ -1024,8 +1043,8 @@ int llama_context::encode(llama_batch & inp_batch) {
}
// temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
// TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
const llama_batch & batch = batch_allocr.batch;
const int32_t n_tokens = batch.n_tokens;
@ -1189,8 +1208,8 @@ int llama_context::decode(llama_batch & inp_batch) {
}
// temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
// TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
const llama_batch & batch = batch_allocr.batch;
@ -1245,8 +1264,10 @@ int llama_context::decode(llama_batch & inp_batch) {
const bool logits_all = n_outputs_all == n_tokens_all;
const bool is_recurrent = llama_model_is_recurrent(&model);
sbatch.from_batch(batch, n_embd,
/* simple_split */ !kv_self->recurrent,
/* simple_split */ !is_recurrent,
/* logits_all */ logits_all);
// reserve output buffer
@ -1265,7 +1286,7 @@ int llama_context::decode(llama_batch & inp_batch) {
const auto & n_ubatch = cparams.n_ubatch;
if (kv_self->recurrent) {
if (is_recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = sbatch.split_seq(cparams.n_ubatch);
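
Only the recurrent/pooled path is visible in this hunk; the remaining branches below are reconstructed from the surrounding decode() code, so treat this as a sketch rather than a verbatim copy:

// Sketch of the ubatch-splitting decision that is_recurrent now drives.
enum class split_mode {
    simple, // ignore sequence boundaries (non-recurrent models)
    equal,  // equal-length sequence slices (recurrent, no pooled embeddings)
    seq,    // sequence-wise split (recurrent + pooled embeddings)
};

static split_mode pick_split(bool is_recurrent, bool embd_pooled) {
    if (!is_recurrent) {
        return split_mode::simple;
    }
    // recurrent state slots are per sequence, so splits must respect sequences
    return embd_pooled ? split_mode::seq : split_mode::equal;
}

int main() {
    return pick_split(/*is_recurrent=*/true, /*embd_pooled=*/false) == split_mode::equal ? 0 : 1;
}
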
@ -1303,17 +1324,19 @@ int llama_context::decode(llama_batch & inp_batch) {
return 1;
}
if (!kv_self->recurrent) {
if (!is_recurrent) {
auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = kv_self->get_padding(cparams);
kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)));
const uint32_t pad = kv->get_padding(cparams);
kv->n = std::min(kv->size, std::max(pad, GGML_PAD(kv->cell_max(), pad)));
//printf("kv.n = %5d, kv.used = %5d, kv.head = %5d\n", kv->n, kv->used, kv->head);
}
}
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
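
The kv->n heuristic is easy to reason about in isolation: attend only to a padded prefix of the cache rather than all kv->size cells. A minimal sketch, with GGML_PAD re-implemented for uint32_t and the padding value chosen only as an example:

// Standalone sketch of the kv->n computation in the hunk above.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t ggml_pad_u32(uint32_t x, uint32_t n) { return ((x + n - 1) / n) * n; }

static uint32_t kv_n_heuristic(uint32_t kv_size, uint32_t cell_max, uint32_t pad) {
    return std::min(kv_size, std::max(pad, ggml_pad_u32(cell_max, pad)));
}

int main() {
    // e.g. 100 used cells with padding 32 -> attend to 128 cells, not the full 4096
    std::printf("n = %u\n", (unsigned) kv_n_heuristic(4096, 100, 32));
}
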
@ -1453,10 +1476,12 @@ int llama_context::decode(llama_batch & inp_batch) {
//synchronize();
// decide if we need to defrag the kv cache
if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
if (!llama_model_is_recurrent(&model) && cparams.causal_attn && cparams.defrag_thold > 0.0f) {
auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
// - do not defrag small contexts (i.e. < 2048 tokens)
// - count the padding towards the number of used tokens
const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
const float fragmentation = kv->n >= 2048 ? std::max(0.0f, 1.0f - float(kv->used + kv->get_padding(cparams))/float(kv->n)) : 0.0f;
// queue defragmentation for next llama_kv_cache_update
if (fragmentation > cparams.defrag_thold) {
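
The fragmentation estimate on its own, as a standalone sketch (the padding value is just an example; the real one comes from kv->get_padding(cparams)):

// Share of the attended window (kv->n) that is neither used nor counted as
// padding; windows smaller than 2048 cells are never defragmented.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static float kv_fragmentation(uint32_t n, uint32_t used, uint32_t pad) {
    return n >= 2048 ? std::max(0.0f, 1.0f - float(used + pad)/float(n)) : 0.0f;
}

int main() {
    // e.g. 2048-cell window, 1200 used cells, padding 32 -> ~0.40 fragmentation
    std::printf("fragmentation = %.2f\n", kv_fragmentation(2048, 1200, 32));
}
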

View File

@ -201,7 +201,7 @@ private:
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
std::unique_ptr<llama_kv_cache_unified> kv_self;
std::unique_ptr<llama_kv_cache> kv_self;
// TODO: remove
bool logits_all = false;

View File

@ -258,7 +258,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_recurrent *>(kv_self)->cells[i];
// prevent out-of-bound sources
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
@ -291,7 +291,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_recurrent *>(kv_self)->cells[i];
data[i] = (float) (kv_cell.src >= 0);
@ -1021,7 +1021,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
}
ggml_tensor * llm_graph_context::build_inp_s_copy() const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
@ -1038,7 +1038,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
}
ggml_tensor * llm_graph_context::build_inp_s_mask() const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
@ -1332,8 +1332,6 @@ ggml_tensor * llm_graph_context::build_attn(
// store to KV cache
{
GGML_ASSERT(!kv_self->recurrent);
const auto kv_head = kv_self->head;
GGML_ASSERT(kv_self->size == n_ctx);
@ -1482,7 +1480,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
ggml_tensor * state_mask,
int32_t n_state,
int32_t n_seqs) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto n_kv = kv_self->n;
const auto kv_head = kv_self->head;
@ -1514,7 +1512,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto token_shift_count = hparams.token_shift_count;
@ -1535,7 +1533,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto token_shift_count = hparams.token_shift_count;
const auto n_embd = hparams.n_embd;

View File

@ -19,6 +19,7 @@ struct llama_cparams;
class llama_memory_i;
class llama_kv_cache_unified;
class llama_kv_cache_recurrent;
// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
@ -171,26 +172,26 @@ public:
class llm_graph_input_s_copy : public llm_graph_input_i {
public:
llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_copy() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_copy; // I32 [kv_size]
const llama_kv_cache_unified * kv_self;
const llama_kv_cache_recurrent * kv_self;
};
class llm_graph_input_s_mask : public llm_graph_input_i {
public:
llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_mask() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_mask; // F32 [1, n_kv]
const llama_kv_cache_unified * kv_self;
const llama_kv_cache_recurrent * kv_self;
};
class llm_graph_input_cross_embd : public llm_graph_input_i {

File diff suppressed because it is too large

View File

@ -15,17 +15,44 @@ struct llama_hparams;
struct llama_ubatch;
struct llama_kv_cache : public llama_memory_i {
virtual ~llama_kv_cache() = default;
using llama_memory_i::llama_memory_i;
// TODO: become constructor
virtual bool init(
const llama_model & model, // TODO: do not reference the model
const llama_cparams & cparams,
ggml_type type_k,
ggml_type type_v,
uint32_t kv_size,
bool offload) = 0;
virtual void restore() = 0; // call if batch processing fails - restores the cache state
virtual void commit() = 0; // call after successful batch processing - clears any pending state
virtual int32_t get_n_tokens() const = 0;
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
virtual bool get_has_shift() const = 0;
virtual bool get_do_defrag() const = 0;
virtual llama_pos get_pos_max() const = 0;
virtual bool get_can_shift() const = 0;
bool get_can_edit() const override { return get_can_shift(); }
virtual bool find_slot(const llama_ubatch & batch) = 0;
// simulate full cache, used for allocating worst-case compute buffers
virtual void set_full() = 0;
virtual size_t size_k_bytes() const = 0;
virtual size_t size_v_bytes() const = 0;
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
};
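
To make the intent of the interface concrete, here is a toy cache driven purely through these virtuals: find_slot() stages a batch, commit() keeps it, restore() rolls it back, and set_full() is what the context uses when reserving worst-case compute buffers. The class below is a stand-in, not the unified or recurrent implementation:

// Toy illustration of driving a cache through the abstract interface above.
#include <cstdint>
#include <cstdio>

struct ubatch_sketch { uint32_t n_tokens; };

struct kv_cache_iface {
    virtual ~kv_cache_iface() = default;
    virtual bool find_slot(const ubatch_sketch & batch) = 0;
    virtual void restore() = 0;  // roll back pending changes after a failed batch
    virtual void commit () = 0;  // make pending changes permanent after success
    virtual void set_full() = 0; // simulate a full cache for worst-case graph reserve
};

struct toy_cache : kv_cache_iface {
    uint32_t used = 0, pending = 0, size = 8;
    bool find_slot(const ubatch_sketch & b) override {
        if (used + pending + b.n_tokens > size) return false;
        pending += b.n_tokens;
        return true;
    }
    void restore() override { pending = 0; }
    void commit () override { used += pending; pending = 0; }
    void set_full() override { used = size; pending = 0; }
};

static bool decode_sketch(kv_cache_iface & kv, const ubatch_sketch & batch) {
    if (!kv.find_slot(batch)) { kv.restore(); return false; }
    // ... build and run the graph ...
    kv.commit();
    return true;
}

int main() {
    toy_cache kv;
    const bool first  = decode_sketch(kv, {6}); // fits: 6 <= 8
    const bool second = decode_sketch(kv, {6}); // rejected: 6 + 6 > 8
    std::printf("first=%d second=%d\n", (int) first, (int) second);
}
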
struct llama_kv_cache_guard {
@ -74,12 +101,6 @@ public:
std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
};
llama_kv_cache_unified(
const llama_hparams & hparams,
callbacks cbs);
virtual ~llama_kv_cache_unified() = default;
// TODO: become constructor
bool init(
const llama_model & model, // TODO: do not reference the model
@ -87,21 +108,30 @@ public:
ggml_type type_k,
ggml_type type_v,
uint32_t kv_size,
bool offload);
bool offload) override;
llama_kv_cache_unified(
const llama_hparams & hparams,
callbacks cbs);
~llama_kv_cache_unified() = default;
int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;
bool get_has_shift() const override;
bool get_do_defrag() const override;
size_t total_size() const;
// TODO: better data structures to reduce the cost of this operation
llama_pos pos_max() const;
llama_pos get_pos_max() const override;
void clear() override;
void defrag() override;
virtual void restore() override;
virtual void commit() override;
void restore() override;
void commit() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
@ -117,7 +147,7 @@ public:
// updates the cache head
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
bool find_slot(const llama_ubatch & batch);
bool find_slot(const llama_ubatch & batch) override;
// TODO: maybe not needed
uint32_t get_padding(const llama_cparams & cparams) const;
@ -125,8 +155,10 @@ public:
// find how many cells are currently in use
uint32_t cell_max() const;
size_t size_k_bytes() const;
size_t size_v_bytes() const;
void set_full() override;
size_t size_k_bytes() const override;
size_t size_v_bytes() const override;
// defrag
@ -151,8 +183,8 @@ public:
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
// members
@ -163,9 +195,6 @@ public:
bool has_shift = false;
bool do_defrag = false;
// TODO: remove this and implement llama_kv_cache_recurrent instead
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
bool v_trans = true; // the value tensor is transposed
bool can_shift = false;
@ -198,11 +227,124 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
//class llama_kv_cache_recurrent : public llama_kv_cache_unified {
//public:
// using llama_kv_cache_unified::llama_kv_cache_unified;
//};
class llama_kv_cache_recurrent : public llama_kv_cache {
public:
// can be used to query data from the model if needed
struct callbacks {
std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
};
llama_kv_cache_recurrent(
const llama_hparams & hparams,
callbacks cbs);
~llama_kv_cache_recurrent() = default;
// TODO: become constructor
bool init(
const llama_model & model, // TODO: do not reference the model
const llama_cparams & cparams,
ggml_type type_k,
ggml_type type_v,
uint32_t kv_size,
bool offload) override;
int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;
bool get_has_shift() const override;
bool get_do_defrag() const override;
size_t total_size() const;
// TODO: better data structures to reduce the cost of this operation
llama_pos get_pos_max() const override;
void clear() override;
void defrag() override;
void restore() override;
void commit() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
bool get_can_shift() const override;
// find an empty slot of size "n_tokens" in the cache
// updates the cache head
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
bool find_slot(const llama_ubatch & batch) override;
// TODO: maybe not needed
uint32_t get_padding(const llama_cparams & cparams) const;
// find how many cells are currently in use
uint32_t cell_max() const;
void set_full() override;
size_t size_k_bytes() const override;
size_t size_v_bytes() const override;
// commit/restore cache
struct slot_range {
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
uint32_t c1 = 0;
};
// pending cell updates that are not yet committed
struct {
std::vector<slot_range> ranges;
} pending;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
// members
const llama_hparams & hparams;
callbacks cbs;
// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_impl also uses it, so it
// cannot be freely changed after a slot has been allocated.
uint32_t head = 0;
uint32_t size = 0;
uint32_t used = 0; // used cells (i.e. at least one seq_id)
// computed before each graph build
uint32_t n = 0;
std::vector<llama_kv_cell> cells;
std::vector<ggml_tensor *> k_l; // per layer
std::vector<ggml_tensor *> v_l;
private:
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
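
The pending.ranges bookkeeping above mirrors the unified cache: find_slot() records which cells a batch claimed so that a failed batch can be undone. The actual rollback lives in llama-kv-cache.cpp, whose diff is suppressed above, so the following is only a plausible toy reading of the declared fields:

// Toy sketch of pending slot-range tracking for commit()/restore().
#include <cstdint>
#include <vector>

struct slot_range_sketch { uint32_t c0 = 0, c1 = 0; }; // cell indices, not positions

struct pending_ranges {
    std::vector<slot_range_sketch> ranges;

    void on_find_slot(uint32_t head, uint32_t n_cells) {
        ranges.push_back({ head, head + n_cells }); // remember what this batch claimed
    }
    template <typename F>
    void restore(F && clear_cell) {                 // failed batch: undo the claimed cells
        for (const auto & r : ranges) {
            for (uint32_t c = r.c0; c < r.c1; ++c) clear_cell(c);
        }
        ranges.clear();
    }
    void commit() { ranges.clear(); }               // successful batch: nothing to undo
};

int main() {
    pending_ranges p;
    p.on_find_slot(/*head=*/2, /*n_cells=*/3);
    int cleared = 0;
    p.restore([&](uint32_t) { ++cleared; });
    return cleared == 3 ? 0 : 1;
}
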
//
// kv cache view

View File

@ -8128,7 +8128,7 @@ struct llm_build_mamba : public llm_graph_context {
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto kv_head = kv_self->head;
@ -10691,7 +10691,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs;
@ -11087,7 +11087,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
ggml_tensor *& first_layer_value,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs;
@ -12057,7 +12057,7 @@ llama_memory_i * llama_model::create_memory() const {
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
{
res = new llama_kv_cache_unified(hparams, {
res = new llama_kv_cache_recurrent(hparams, {
/*.get_rope_factors =*/ nullptr
});
} break;
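
Taken together with the context changes, create_memory() is now the single point that decides which cache a model gets. A condensed sketch of that dispatch (constructors and the callbacks struct are stand-ins):

// Recurrent architectures (Mamba, RWKV variants) get the recurrent cache,
// everything else keeps the unified cache; class bodies are placeholders.
#include <memory>

struct memory_i_sketch { virtual ~memory_i_sketch() = default; };
struct kv_cache_unified_sketch   : memory_i_sketch {};
struct kv_cache_recurrent_sketch : memory_i_sketch {};

static std::unique_ptr<memory_i_sketch> create_memory_sketch(bool is_recurrent_arch) {
    if (is_recurrent_arch) {
        return std::make_unique<kv_cache_recurrent_sketch>();
    }
    return std::make_unique<kv_cache_unified_sketch>();
}

int main() {
    return create_memory_sketch(/*is_recurrent_arch=*/true) ? 0 : 1;
}
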