Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2025-04-14 10:36:07 +00:00
kv-cache : separate recurrent vs non-recurrent impl (wip)
ggml-ci
This commit is contained in:
parent 916c83bfe7
commit 19eb81e083
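The commit splits the KV-cache handling into a unified (attention) cache and a recurrent (Mamba/RWKV) cache behind a common llama_kv_cache interface. As a condensed orientation sketch only, boiled down from the first hunk below (logging and asserts omitted, all names taken from the diff), the context now picks its cache roughly like this:

// condensed sketch of the selection logic in llama_context's constructor; not the verbatim diff
if (!llama_model_is_recurrent(&model)) {
    // attention models: unified KV cache sized to the padded context
    auto * kv = static_cast<llama_kv_cache_unified *>(model.create_memory());
    cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv->get_padding(cparams));
    kv_size = cparams.n_ctx;
    type_k  = params.type_k;
    type_v  = params.type_v;
    kv_self.reset(kv);
} else {
    // recurrent models: one cell per tracked sequence, F32 states
    auto * kv = static_cast<llama_kv_cache_recurrent *>(model.create_memory());
    kv_size = std::max((uint32_t) 1, params.n_seq_max);
    type_k  = GGML_TYPE_F32;
    type_v  = GGML_TYPE_F32;
    kv_self.reset(kv);
}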
@@ -178,24 +178,37 @@ llama_context::llama_context(
// init the memory module
// TODO: for now, always create a unified KV cache
if (!hparams.vocab_only) {
kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));

LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);

cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));

LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);

uint32_t kv_size = cparams.n_ctx;
uint32_t kv_size = 0;
ggml_type type_k = params.type_k;
ggml_type type_v = params.type_v;

if (llama_model_is_recurrent(&model)) {
if (!llama_model_is_recurrent(&model)) {
//kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
auto * kv = static_cast<llama_kv_cache_unified *>(model.create_memory());

LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);

cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv->get_padding(cparams));

LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);

kv_size = cparams.n_ctx;
type_k = params.type_k;
type_v = params.type_v;

kv_self.reset(kv);
} else {
auto * kv = static_cast<llama_kv_cache_recurrent *>(model.create_memory());

LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);

// Mamba needs at least as many KV cells as there are sequences kept at any time
kv_size = std::max((uint32_t) 1, params.n_seq_max);
// it's probably best to keep as much precision as possible for the states
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states

kv_self.reset(kv);
}

GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
@@ -304,7 +317,7 @@ llama_context::llama_context(
int n_nodes_tg = -1;

// simulate full KV cache
kv_self->n = kv_self->size;
kv_self->set_full();

cross.v_embd.clear();

@@ -553,7 +566,9 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(

//GGML_ASSERT(kv_self->size == n_ctx);

auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
const auto * kv = static_cast<const llama_kv_cache_unified *>(kv_self.get());

auto inp = std::make_unique<llm_graph_input_k_shift>(kv);

inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
ggml_set_input(inp->k_shift);

@@ -569,16 +584,16 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;

ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
ggml_tensor * rope_factors = kv->cbs.get_rope_factors(n_ctx_per_seq(), il);

ggml_tensor * k =
ggml_view_3d(ctx0, kv_self->k_l[il],
n_embd_head_k, n_head_kv, kv_self->size,
ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
ggml_view_3d(ctx0, kv->k_l[il],
n_embd_head_k, n_head_kv, kv->size,
ggml_row_size(kv->k_l[il]->type, n_embd_head_k),
ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
0);

ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv->k_l[il]->buffer);

ggml_build_forward_expand(gf, cur);
}

@@ -593,9 +608,11 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
ggml_cgraph * gf) const {
auto res = std::make_unique<llm_graph_result>();

auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());

const auto & hparams = model.hparams;

const auto & ids = kv_self->defrag_info.ids;
const auto & ids = kv->defrag_info.ids;

#if 0
// CPU defrag
@@ -685,40 +702,40 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv->k_l[il],
n_embd_k_gqa, nm,
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*i));

ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv->k_l[il],
n_embd_k_gqa, nm,
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*id));

ggml_tensor * view_v_src;
ggml_tensor * view_v_dst;

if (cparams.flash_attn) {
// NOTE: the V cache is not transposed when using flash attention
view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
view_v_src = ggml_view_2d(ctx0, kv->v_l[il],
n_embd_v_gqa, nm,
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*i));

view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
view_v_dst = ggml_view_2d(ctx0, kv->v_l[il],
n_embd_v_gqa, nm,
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*id));
} else {
view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
view_v_src = ggml_view_2d(ctx0, kv->v_l[il],
nm, n_embd_v_gqa,
ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
ggml_row_size(kv_self->v_l[il]->type, i));
ggml_row_size(kv->v_l[il]->type, kv->size),
ggml_row_size(kv->v_l[il]->type, i));

view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
view_v_dst = ggml_view_2d(ctx0, kv->v_l[il],
nm, n_embd_v_gqa,
ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
ggml_row_size(kv_self->v_l[il]->type, id));
ggml_row_size(kv->v_l[il]->type, kv->size),
ggml_row_size(kv->v_l[il]->type, id));
}

ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
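An aside on the offsets in the defrag views above, stated with the symbols from the hunk (a hedged reading, not code from the diff): each KV cell of layer il stores n_embd_k_gqa K elements, so the view over nm cells starting at cell i has row stride ggml_row_size(type, n_embd_k_gqa) and byte offset ggml_row_size(type, n_embd_k_gqa*i); the same pattern with id gives the destination view.

// hedged helper illustrating the K-view offset arithmetic used above (not part of the diff)
static size_t k_cell_offset(ggml_type type, int64_t n_embd_k_gqa, int64_t cell) {
    // `cell` rows of n_embd_k_gqa elements each, converted to bytes
    return ggml_row_size(type, n_embd_k_gqa * cell);
}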
@@ -735,13 +752,11 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
}

void llama_context::kv_self_update() {
auto & kv = kv_self;

bool need_reserve = false;

if (kv->has_shift) {
if (!kv->get_can_shift()) {
GGML_ABORT("The current context does not support K-shift");
if (kv_self->get_has_shift()) {
if (!kv_self->get_can_shift()) {
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
}

LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);

@@ -764,6 +779,8 @@ void llama_context::kv_self_update() {
}

{
auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());

kv->has_shift = false;

for (uint32_t i = 0; i < kv->size; ++i) {

@@ -773,9 +790,11 @@ void llama_context::kv_self_update() {
}

// defragment the KV cache if needed
if (kv->do_defrag) {
if (kv_self->get_do_defrag()) {
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);

auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());

if (kv->defrag_prepare(graph_max_nodes())) {
ggml_backend_sched_reset(sched.get());

@@ -804,7 +823,7 @@ void llama_context::kv_self_update() {
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

// simulate full KV cache
kv_self->n = kv_self->size;
kv_self->set_full();

llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

@@ -1024,8 +1043,8 @@ int llama_context::encode(llama_batch & inp_batch) {
}

// temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
// TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);

const llama_batch & batch = batch_allocr.batch;
const int32_t n_tokens = batch.n_tokens;

@@ -1189,8 +1208,8 @@ int llama_context::decode(llama_batch & inp_batch) {
}

// temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
// TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);

const llama_batch & batch = batch_allocr.batch;

@@ -1245,8 +1264,10 @@ int llama_context::decode(llama_batch & inp_batch) {

const bool logits_all = n_outputs_all == n_tokens_all;

const bool is_recurrent = llama_model_is_recurrent(&model);

sbatch.from_batch(batch, n_embd,
/* simple_split */ !kv_self->recurrent,
/* simple_split */ !is_recurrent,
/* logits_all */ logits_all);

// reserve output buffer
@@ -1265,7 +1286,7 @@ int llama_context::decode(llama_batch & inp_batch) {

const auto & n_ubatch = cparams.n_ubatch;

if (kv_self->recurrent) {
if (is_recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = sbatch.split_seq(cparams.n_ubatch);

@@ -1303,17 +1324,19 @@ int llama_context::decode(llama_batch & inp_batch) {
return 1;
}

if (!kv_self->recurrent) {
if (!is_recurrent) {
auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());

// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = kv_self->get_padding(cparams);
kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)));
const uint32_t pad = kv->get_padding(cparams);
kv->n = std::min(kv->size, std::max(pad, GGML_PAD(kv->cell_max(), pad)));

//printf("kv.n = %5d, kv.used = %5d, kv.head = %5d\n", kv->n, kv->used, kv->head);
}
}

//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);

ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

@@ -1453,10 +1476,12 @@ int llama_context::decode(llama_batch & inp_batch) {
//synchronize();

// decide if we need to defrag the kv cache
if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
if (!llama_model_is_recurrent(&model) && cparams.causal_attn && cparams.defrag_thold > 0.0f) {
auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());

// - do not defrag small contexts (i.e. < 2048 tokens)
// - count the padding towards the number of used tokens
const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
const float fragmentation = kv->n >= 2048 ? std::max(0.0f, 1.0f - float(kv->used + kv->get_padding(cparams))/float(kv->n)) : 0.0f;

// queue defragmentation for next llama_kv_cache_update
if (fragmentation > cparams.defrag_thold) {
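A pattern repeated throughout the hunks above: members that only exist on the unified cache (n, cell_max(), defrag_info, ...) are reached by first checking llama_model_is_recurrent(&model) and then downcasting the base pointer. Shown in isolation as an illustration, condensed from the decode() hunk rather than quoted verbatim:

// guard-then-cast pattern used by this WIP commit wherever unified-only state is needed
if (!llama_model_is_recurrent(&model)) {
    auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());

    const uint32_t pad = kv->get_padding(cparams);
    kv->n = std::min(kv->size, std::max(pad, GGML_PAD(kv->cell_max(), pad)));
}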
@@ -201,7 +201,7 @@ private:

llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably

std::unique_ptr<llama_kv_cache_unified> kv_self;
std::unique_ptr<llama_kv_cache> kv_self;

// TODO: remove
bool logits_all = false;

@@ -258,7 +258,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {

//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_recurrent *>(kv_self)->cells[i];

// prevent out-of-bound sources
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {

@@ -291,7 +291,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {

//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_recurrent *>(kv_self)->cells[i];

data[i] = (float) (kv_cell.src >= 0);

@@ -1021,7 +1021,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
}

ggml_tensor * llm_graph_context::build_inp_s_copy() const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);

@@ -1038,7 +1038,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
}

ggml_tensor * llm_graph_context::build_inp_s_mask() const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
@@ -1332,8 +1332,6 @@ ggml_tensor * llm_graph_context::build_attn(

// store to KV cache
{
GGML_ASSERT(!kv_self->recurrent);

const auto kv_head = kv_self->head;

GGML_ASSERT(kv_self->size == n_ctx);

@@ -1482,7 +1480,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
ggml_tensor * state_mask,
int32_t n_state,
int32_t n_seqs) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

const auto n_kv = kv_self->n;
const auto kv_head = kv_self->head;

@@ -1514,7 +1512,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

const auto token_shift_count = hparams.token_shift_count;

@@ -1535,7 +1533,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

const auto token_shift_count = hparams.token_shift_count;
const auto n_embd = hparams.n_embd;
@@ -19,6 +19,7 @@ struct llama_cparams;

class llama_memory_i;
class llama_kv_cache_unified;
class llama_kv_cache_recurrent;

// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {

@@ -171,26 +172,26 @@ public:

class llm_graph_input_s_copy : public llm_graph_input_i {
public:
llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_copy() = default;

void set_input(const llama_ubatch * ubatch) override;

ggml_tensor * s_copy; // I32 [kv_size]

const llama_kv_cache_unified * kv_self;
const llama_kv_cache_recurrent * kv_self;
};

class llm_graph_input_s_mask : public llm_graph_input_i {
public:
llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_mask() = default;

void set_input(const llama_ubatch * ubatch) override;

ggml_tensor * s_mask; // F32 [1, n_kv]

const llama_kv_cache_unified * kv_self;
const llama_kv_cache_recurrent * kv_self;
};

class llm_graph_input_cross_embd : public llm_graph_input_i {

(File diff suppressed because it is too large)
@@ -15,17 +15,44 @@ struct llama_hparams;
struct llama_ubatch;

struct llama_kv_cache : public llama_memory_i {
virtual ~llama_kv_cache() = default;

using llama_memory_i::llama_memory_i;

// TODO: become constructor
virtual bool init(
const llama_model & model, // TODO: do not reference the model
const llama_cparams & cparams,
ggml_type type_k,
ggml_type type_v,
uint32_t kv_size,
bool offload) = 0;

virtual void restore() = 0; // call if batch processing fails - restores the cache state
virtual void commit() = 0; // call after successful batch processing - clears any pending state

virtual int32_t get_n_tokens() const = 0;
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache

virtual bool get_has_shift() const = 0;
virtual bool get_do_defrag() const = 0;

virtual llama_pos get_pos_max() const = 0;

virtual bool get_can_shift() const = 0;

bool get_can_edit() const override { return get_can_shift(); }

virtual bool find_slot(const llama_ubatch & batch) = 0;

// simulate full cache, used for allocating worst-case compute buffers
virtual void set_full() = 0;

virtual size_t size_k_bytes() const = 0;
virtual size_t size_v_bytes() const = 0;

virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
};

struct llama_kv_cache_guard {
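Since llama_context now holds the cache as std::unique_ptr<llama_kv_cache> (see the earlier header hunk), code that does not care which implementation is active can stay on this interface. A small usage sketch under that assumption, relying only on the semantics implied by the comments above (hypothetical helper names, not from the diff):

// sketch: implementation-agnostic use of the llama_kv_cache interface (hypothetical helpers)
static void prepare_worst_case(llama_kv_cache & kv) {
    kv.set_full(); // simulate a full cache when sizing worst-case compute buffers
}

static bool apply_batch(llama_kv_cache & kv, const llama_ubatch & ubatch, bool processed_ok) {
    if (!kv.find_slot(ubatch)) {
        return false;  // no space for this ubatch
    }
    if (processed_ok) {
        kv.commit();   // keep the pending cell updates
    } else {
        kv.restore();  // roll back to the state before find_slot()
    }
    return processed_ok;
}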
@@ -74,12 +101,6 @@ public:
std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
};

llama_kv_cache_unified(
const llama_hparams & hparams,
callbacks cbs);

virtual ~llama_kv_cache_unified() = default;

// TODO: become constructor
bool init(
const llama_model & model, // TODO: do not reference the model
@@ -87,21 +108,30 @@ public:
ggml_type type_k,
ggml_type type_v,
uint32_t kv_size,
bool offload);
bool offload) override;

llama_kv_cache_unified(
const llama_hparams & hparams,
callbacks cbs);

~llama_kv_cache_unified() = default;

int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;

bool get_has_shift() const override;
bool get_do_defrag() const override;

size_t total_size() const;

// TODO: better data structures to reduce the cost of this operation
llama_pos pos_max() const;
llama_pos get_pos_max() const override;

void clear() override;
void defrag() override;

virtual void restore() override;
virtual void commit() override;
void restore() override;
void commit() override;

bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;

@@ -117,7 +147,7 @@ public:
// updates the cache head
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
bool find_slot(const llama_ubatch & batch);
bool find_slot(const llama_ubatch & batch) override;

// TODO: maybe not needed
uint32_t get_padding(const llama_cparams & cparams) const;

@@ -125,8 +155,10 @@ public:
// find how many cells are currently in use
uint32_t cell_max() const;

size_t size_k_bytes() const;
size_t size_v_bytes() const;
void set_full() override;

size_t size_k_bytes() const override;
size_t size_v_bytes() const override;

// defrag

@@ -151,8 +183,8 @@ public:

// state write/load

void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;

// members

@@ -163,9 +195,6 @@ public:
bool has_shift = false;
bool do_defrag = false;

// TODO: remove this and implement llama_kv_cache_recurrent instead
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token

bool v_trans = true; // the value tensor is transposed
bool can_shift = false;
@@ -198,11 +227,124 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};

// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
//class llama_kv_cache_recurrent : public llama_kv_cache_unified {
//public:
//    using llama_kv_cache_unified::llama_kv_cache_unified;
//};
class llama_kv_cache_recurrent : public llama_kv_cache {
public:
// can be used to query data from the model if needed
struct callbacks {
std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
};

llama_kv_cache_recurrent(
const llama_hparams & hparams,
callbacks cbs);

~llama_kv_cache_recurrent() = default;

// TODO: become constructor
bool init(
const llama_model & model, // TODO: do not reference the model
const llama_cparams & cparams,
ggml_type type_k,
ggml_type type_v,
uint32_t kv_size,
bool offload) override;

int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;

bool get_has_shift() const override;
bool get_do_defrag() const override;

size_t total_size() const;

// TODO: better data structures to reduce the cost of this operation
llama_pos get_pos_max() const override;

void clear() override;
void defrag() override;

void restore() override;
void commit() override;

bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

llama_pos seq_pos_max(llama_seq_id seq_id) const override;

bool get_can_shift() const override;

// find an empty slot of size "n_tokens" in the cache
// updates the cache head
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
bool find_slot(const llama_ubatch & batch) override;

// TODO: maybe not needed
uint32_t get_padding(const llama_cparams & cparams) const;

// find how many cells are currently in use
uint32_t cell_max() const;

void set_full() override;

size_t size_k_bytes() const override;
size_t size_v_bytes() const override;

// commit/restore cache

struct slot_range {
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
uint32_t c1 = 0;
};

// pending cell updates that are not yet committed
struct {
std::vector<slot_range> ranges;
} pending;

// state write/load

void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;

// members

const llama_hparams & hparams;

callbacks cbs;

// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_impl also uses it, so it
// cannot be freely changed after a slot has been allocated.
uint32_t head = 0;
uint32_t size = 0;
uint32_t used = 0; // used cells (i.e. at least one seq_id)

// computed before each graph build
uint32_t n = 0;

std::vector<llama_kv_cell> cells;

std::vector<ggml_tensor *> k_l; // per layer
std::vector<ggml_tensor *> v_l;

private:
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;

std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;

void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;

bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};

//
// kv cache view
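Both cache implementations expose the restore()/commit() pair declared in the base interface, and the header also keeps a llama_kv_cache_guard. As an illustrative RAII-style sketch of how that pair is meant to be used, going only by the interface comments (the real guard in the header may differ):

// illustrative sketch, not the actual llama_kv_cache_guard from the header
struct kv_cache_guard_sketch {
    llama_kv_cache & kv;
    bool committed = false;

    explicit kv_cache_guard_sketch(llama_kv_cache & kv) : kv(kv) {}

    ~kv_cache_guard_sketch() {
        if (!committed) {
            kv.restore(); // batch processing failed: roll back pending slot ranges
        }
    }

    void commit() {
        kv.commit();      // batch processing succeeded: keep the changes
        committed = true;
    }
};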
@@ -8128,7 +8128,7 @@ struct llm_build_mamba : public llm_graph_context {
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

const auto kv_head = kv_self->head;

@@ -10691,7 +10691,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs;

@@ -11087,7 +11087,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
ggml_tensor *& first_layer_value,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs;

@@ -12057,7 +12057,7 @@ llama_memory_i * llama_model::create_memory() const {
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
{
res = new llama_kv_cache_unified(hparams, {
res = new llama_kv_cache_recurrent(hparams, {
/*.get_rope_factors =*/ nullptr
});
} break;