From 19eb81e0831d28896af636a73a88dd48b0ae9829 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Mon, 7 Apr 2025 16:50:58 +0300
Subject: [PATCH] kv-cache : separate recurrent vs non-recurrent impl (wip)

ggml-ci
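
Roughly, the split this patch works toward looks like the sketch below. It is
illustrative only: the method names are taken from the call sites changed in
this patch, while the exact base-class layout and the llama_pos typedef are
assumptions, not the final interface.

    #include <cstdint>

    using llama_pos = int32_t;  // assumption: mirrors the typedef in llama.h

    // common interface held by llama_context as std::unique_ptr<llama_kv_cache>
    // (in the real code this also derives from llama_memory_i)
    struct llama_kv_cache {
        virtual ~llama_kv_cache() = default;

        // state queried through the base pointer in llama_context
        virtual bool      get_has_shift() const = 0;
        virtual bool      get_do_defrag() const = 0;
        virtual bool      get_can_shift() const = 0;
        virtual llama_pos get_pos_max()   const = 0;

        // simulate a full cache when reserving worst-case graphs
        virtual void set_full() = 0;
    };

    // attention KV cache: one cell per token, supports K-shift and defrag
    class llama_kv_cache_unified   : public llama_kv_cache { /* k_l, v_l, defrag_info, ... */ };

    // recurrent state cache (Mamba/RWKV): one cell per sequence, F32 states
    class llama_kv_cache_recurrent : public llama_kv_cache { /* k_l, v_l, cell tails, ... */ };

Call sites that still need unified-only state (defrag_info, k_l/v_l, n, head)
downcast with static_cast<llama_kv_cache_unified *>(kv_self.get()), and the
recurrent-only paths in llama-graph.cpp cast to llama_kv_cache_recurrent
instead.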
---
 src/llama-context.cpp  |  139 +++--
 src/llama-context.h    |    2 +-
 src/llama-graph.cpp    |   16 +-
 src/llama-graph.h      |    9 +-
 src/llama-kv-cache.cpp | 1269 +++++++++++++++++++++++++++++++---------
 src/llama-kv-cache.h   |  188 +++++-
 src/llama-model.cpp    |    8 +-
 7 files changed, 1245 insertions(+), 386 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 4735e98ea..e80ff22fd 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -178,24 +178,37 @@ llama_context::llama_context(
     // init the memory module
     // TODO: for now, always create a unified KV cache
     if (!hparams.vocab_only) {
-        kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
-
-        LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
-        cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));
-
-        LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-        uint32_t kv_size = cparams.n_ctx;
+        uint32_t kv_size = 0;
         ggml_type type_k = params.type_k;
         ggml_type type_v = params.type_v;
 
-        if (llama_model_is_recurrent(&model)) {
+        if (!llama_model_is_recurrent(&model)) {
+            //kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
+            auto * kv = static_cast<llama_kv_cache_unified *>(model.create_memory());
+
+            LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
+
+            cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv->get_padding(cparams));
+
+            LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+            kv_size = cparams.n_ctx;
+            type_k = params.type_k;
+            type_v = params.type_v;
+
+            kv_self.reset(kv);
+        } else {
+            auto * kv = static_cast<llama_kv_cache_recurrent *>(model.create_memory());
+
+            LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
+
             // Mamba needs at least as many KV cells as there are sequences kept at any time
             kv_size = std::max((uint32_t) 1, params.n_seq_max);
             // it's probably best to keep as much precision as possible for the states
             type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
             type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
+
+            kv_self.reset(kv);
         }
 
         GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
@@ -304,7 +317,7 @@ llama_context::llama_context(
         int n_nodes_tg  = -1;
 
         // simulate full KV cache
-        kv_self->n = kv_self->size;
+        kv_self->set_full();
 
         cross.v_embd.clear();
 
@@ -553,7 +566,9 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
 
     //GGML_ASSERT(kv_self->size == n_ctx);
 
-    auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
+    const auto * kv = static_cast<const llama_kv_cache_unified *>(kv_self.get());
+
+    auto inp = std::make_unique<llm_graph_input_k_shift>(kv);
 
     inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
     ggml_set_input(inp->k_shift);
@@ -569,16 +584,16 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
         const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
         const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
 
-        ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
+        ggml_tensor * rope_factors = kv->cbs.get_rope_factors(n_ctx_per_seq(), il);
 
         ggml_tensor * k =
-            ggml_view_3d(ctx0, kv_self->k_l[il],
-                n_embd_head_k, n_head_kv, kv_self->size,
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+            ggml_view_3d(ctx0, kv->k_l[il],
+                n_embd_head_k, n_head_kv, kv->size,
+                ggml_row_size(kv->k_l[il]->type, n_embd_head_k),
+                ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
                 0);
 
-        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
+        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv->k_l[il]->buffer);
 
         ggml_build_forward_expand(gf, cur);
     }
@@ -593,9 +608,11 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
         ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();
 
+    auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
     const auto & hparams = model.hparams;
 
-    const auto & ids = kv_self->defrag_info.ids;
+    const auto & ids = kv->defrag_info.ids;
 
 #if 0
     // CPU defrag
@@ -685,40 +702,40 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
             const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
-            ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
+            ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv->k_l[il],
                     n_embd_k_gqa, nm,
-                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
+                    ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
+                    ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*i));
 
-            ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
+            ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv->k_l[il],
                     n_embd_k_gqa, nm,
-                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
+                    ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
+                    ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*id));
 
             ggml_tensor * view_v_src;
             ggml_tensor * view_v_dst;
 
             if (cparams.flash_attn) {
                 // NOTE: the V cache is not transposed when using flash attention
-                view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
+                view_v_src = ggml_view_2d(ctx0, kv->v_l[il],
                         n_embd_v_gqa, nm,
-                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
+                        ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa),
+                        ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*i));
 
-                view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
+                view_v_dst = ggml_view_2d(ctx0, kv->v_l[il],
                         n_embd_v_gqa, nm,
-                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
+                        ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa),
+                        ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*id));
             } else {
-                view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
+                view_v_src = ggml_view_2d(ctx0, kv->v_l[il],
                         nm, n_embd_v_gqa,
-                        ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
-                        ggml_row_size(kv_self->v_l[il]->type, i));
+                        ggml_row_size(kv->v_l[il]->type, kv->size),
+                        ggml_row_size(kv->v_l[il]->type, i));
 
-                view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
+                view_v_dst = ggml_view_2d(ctx0, kv->v_l[il],
                         nm, n_embd_v_gqa,
-                        ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
-                        ggml_row_size(kv_self->v_l[il]->type, id));
+                        ggml_row_size(kv->v_l[il]->type, kv->size),
+                        ggml_row_size(kv->v_l[il]->type, id));
             }
 
             ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
@@ -735,13 +752,11 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
 }
 
 void llama_context::kv_self_update() {
-    auto & kv = kv_self;
-
     bool need_reserve = false;
 
-    if (kv->has_shift) {
-        if (!kv->get_can_shift()) {
-            GGML_ABORT("The current context does not support K-shift");
+    if (kv_self->get_has_shift()) {
+        if (!kv_self->get_can_shift()) {
+            GGML_ABORT("The current KV cache / model configuration does not support K-shift");
         }
 
         LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
@@ -764,6 +779,8 @@ void llama_context::kv_self_update() {
         }
 
         {
+            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
             kv->has_shift = false;
 
             for (uint32_t i = 0; i < kv->size; ++i) {
@@ -773,9 +790,11 @@ void llama_context::kv_self_update() {
     }
 
     // defragment the KV cache if needed
-    if (kv->do_defrag) {
+    if (kv_self->get_do_defrag()) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
 
+        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
         if (kv->defrag_prepare(graph_max_nodes())) {
             ggml_backend_sched_reset(sched.get());
 
@@ -804,7 +823,7 @@ void llama_context::kv_self_update() {
         uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
         // simulate full KV cache
-        kv_self->n = kv_self->size;
+        kv_self->set_full();
 
         llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
         llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
@@ -1024,8 +1043,8 @@ int llama_context::encode(llama_batch & inp_batch) {
     }
 
     // temporary allocate memory for the input batch if needed
-    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
+    // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
     const int32_t n_tokens = batch.n_tokens;
@@ -1189,8 +1208,8 @@ int llama_context::decode(llama_batch & inp_batch) {
     }
 
     // temporary allocate memory for the input batch if needed
-    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
+    // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
 
@@ -1245,8 +1264,10 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     const bool logits_all = n_outputs_all == n_tokens_all;
 
+    const bool is_recurrent = llama_model_is_recurrent(&model);
+
     sbatch.from_batch(batch, n_embd,
-            /* simple_split */ !kv_self->recurrent,
+            /* simple_split */ !is_recurrent,
             /* logits_all   */ logits_all);
 
     // reserve output buffer
@@ -1265,7 +1286,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 
         const auto & n_ubatch = cparams.n_ubatch;
 
-        if (kv_self->recurrent) {
+        if (is_recurrent) {
             if (embd_pooled) {
                 // Pooled embeddings cannot be split across ubatches (yet)
                 ubatch = sbatch.split_seq(cparams.n_ubatch);
@@ -1303,17 +1324,19 @@ int llama_context::decode(llama_batch & inp_batch) {
                 return 1;
             }
 
-            if (!kv_self->recurrent) {
+            if (!is_recurrent) {
+                auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
                 // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = kv_self->get_padding(cparams);
-                kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)));
+                const uint32_t pad = kv->get_padding(cparams);
+                kv->n = std::min(kv->size, std::max(pad, GGML_PAD(kv->cell_max(), pad)));
+
+                //printf("kv.n = %5d, kv.used = %5d, kv.head = %5d\n", kv->n, kv->used, kv->head);
             }
         }
 
-        //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
-
         ggml_backend_sched_reset(sched.get());
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
@@ -1453,10 +1476,12 @@ int llama_context::decode(llama_batch & inp_batch) {
     //synchronize();
 
     // decide if we need to defrag the kv cache
-    if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
+    if (!llama_model_is_recurrent(&model) && cparams.causal_attn && cparams.defrag_thold > 0.0f) {
+        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
         // - do not defrag small contexts (i.e. < 2048 tokens)
         // - count the padding towards the number of used tokens
-        const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
+        const float fragmentation = kv->n >= 2048 ? std::max(0.0f, 1.0f - float(kv->used + kv->get_padding(cparams))/float(kv->n)) : 0.0f;
 
         // queue defragmentation for next llama_kv_cache_update
         if (fragmentation > cparams.defrag_thold) {
diff --git a/src/llama-context.h b/src/llama-context.h
index 04facb544..8216c9f7e 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -201,7 +201,7 @@ private:
 
     llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
 
-    std::unique_ptr<llama_kv_cache_unified> kv_self;
+    std::unique_ptr<llama_kv_cache> kv_self;
 
     // TODO: remove
     bool logits_all = false;
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index cec203df4..1797a8bd7 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -258,7 +258,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
 
             //////////////////////////////////////////////
             // TODO: this should not mutate the KV cache !
-            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
+            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_recurrent *>(kv_self)->cells[i];
 
             // prevent out-of-bound sources
             if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
@@ -291,7 +291,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
 
             //////////////////////////////////////////////
             // TODO: this should not mutate the KV cache !
-            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
+            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_recurrent *>(kv_self)->cells[i];
 
             data[i] = (float) (kv_cell.src >= 0);
 
@@ -1021,7 +1021,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
 
@@ -1038,7 +1038,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
 
@@ -1332,8 +1332,6 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        GGML_ASSERT(!kv_self->recurrent);
-
         const auto kv_head = kv_self->head;
 
         GGML_ASSERT(kv_self->size == n_ctx);
@@ -1482,7 +1480,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
          ggml_tensor * state_mask,
              int32_t   n_state,
              int32_t   n_seqs) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto n_kv    = kv_self->n;
     const auto kv_head = kv_self->head;
@@ -1514,7 +1512,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
          ggml_tensor * state_mask,
   const llama_ubatch & ubatch,
                  int   il) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto token_shift_count = hparams.token_shift_count;
 
@@ -1535,7 +1533,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
          ggml_tensor * token_shift,
   const llama_ubatch & ubatch,
                  int   il) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
diff --git a/src/llama-graph.h b/src/llama-graph.h
index bdf19ed01..7745ae658 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -19,6 +19,7 @@ struct llama_cparams;
 
 class llama_memory_i;
 class llama_kv_cache_unified;
+class llama_kv_cache_recurrent;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -171,26 +172,26 @@ public:
 
 class llm_graph_input_s_copy : public llm_graph_input_i {
 public:
-    llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+    llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
     virtual ~llm_graph_input_s_copy() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache_recurrent * kv_self;
 };
 
 class llm_graph_input_s_mask : public llm_graph_input_i {
 public:
-    llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+    llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
     virtual ~llm_graph_input_s_mask() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_mask; // F32 [1, n_kv]
 
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache_recurrent * kv_self;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index dbf5f1187..b9248b939 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -11,6 +11,10 @@
 #include <map>
 #include <stdexcept>
 
+//
+// llama_kv_cache_unified
+//
+
 llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
 }
 
@@ -25,9 +29,10 @@ bool llama_kv_cache_unified::init(
 
     has_shift = false;
 
-    recurrent = llama_model_is_recurrent(&model);
-    v_trans   = !recurrent && !cparams.flash_attn;
-    can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
+    GGML_ASSERT(!llama_model_is_recurrent(&model));
+
+    v_trans   = !cparams.flash_attn;
+    can_shift = model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
 
     LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
             __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift);
@@ -135,6 +140,14 @@ int32_t llama_kv_cache_unified::get_used_cells() const {
     return used;
 }
 
+bool llama_kv_cache_unified::get_has_shift() const {
+    return has_shift;
+}
+
+bool llama_kv_cache_unified::get_do_defrag() const {
+    return do_defrag;
+}
+
 size_t llama_kv_cache_unified::total_size() const {
     size_t size = 0;
     for (const auto & buf : bufs) {
@@ -144,7 +157,7 @@ size_t llama_kv_cache_unified::total_size() const {
     return size;
 }
 
-llama_pos llama_kv_cache_unified::pos_max() const {
+llama_pos llama_kv_cache_unified::get_pos_max() const {
     llama_pos pos_max = -1;
     for (const auto & cell : cells) {
         pos_max = std::max(pos_max, cell.pos);
@@ -179,35 +192,6 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    // models like Mamba or RWKV can't have a state partially erased
-    if (recurrent) {
-        if (seq_id >= (int64_t) size) {
-            // could be fatal
-            return false;
-        }
-        if (0 <= seq_id) {
-            int32_t & tail_id = cells[seq_id].tail;
-            if (tail_id >= 0) {
-                const llama_kv_cell & cell = cells[tail_id];
-                // partial intersection is invalid
-                if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
-                    return false;
-                }
-                // invalidate tails which will be cleared
-                if (p0 <= cell.pos && cell.pos < p1) {
-                    tail_id = -1;
-                }
-            }
-        } else {
-            // seq_id is negative, then the range should include everything or nothing
-            if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
-                return false;
-            }
-        }
-
-        return true;
-    }
-
     for (uint32_t i = 0; i < size; ++i) {
         if (cells[i].pos >= p0 && cells[i].pos < p1) {
             if (seq_id < 0) {
@@ -254,34 +238,6 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    if (recurrent) {
-        if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
-            llama_kv_cell & tail_src = cells[seq_id_src];
-            llama_kv_cell & tail_dst = cells[seq_id_dst];
-            if (tail_dst.tail >= 0) {
-                // clear destination seq_id if it wasn't empty
-                llama_kv_cell & cell_dst = cells[tail_dst.tail];
-
-                cell_dst.seq_id.erase(seq_id_dst);
-                tail_dst.tail = -1;
-                if (cell_dst.seq_id.empty()) {
-                    cell_dst.pos = -1;
-                    cell_dst.delta = -1;
-                    cell_dst.src = -1;
-                    used -= 1;
-                }
-            }
-            if (tail_src.tail >= 0) {
-                llama_kv_cell & cell_src = cells[tail_src.tail];
-
-                cell_src.seq_id.insert(seq_id_dst);
-                tail_dst.tail = tail_src.tail;
-            }
-        }
-
-        return;
-    }
-
     // otherwise, this is the KV of a Transformer-like model
     head = 0;
 
@@ -296,9 +252,10 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
     uint32_t new_head = size;
 
     for (uint32_t i = 0; i < size; ++i) {
-        if (recurrent && (llama_seq_id) i != seq_id) {
-            cells[i].tail = -1;
-        }
+        // TODO: remove tail
+        //if (recurrent && (llama_seq_id) i != seq_id) {
+        //    cells[i].tail = -1;
+        //}
 
         if (!cells[i].has_seq_id(seq_id)) {
             if (cells[i].pos >= 0) {
@@ -344,20 +301,6 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
         return;
     }
 
-    if (recurrent) {
-        // for Mamba-like or RWKV models, only the pos needs to be shifted
-        if (0 <= seq_id && seq_id < (int64_t) size) {
-            const int32_t tail_id = cells[seq_id].tail;
-            if (tail_id >= 0) {
-                llama_kv_cell & cell = cells[tail_id];
-                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
-                    cell.pos += delta;
-                }
-            }
-        }
-        return;
-    }
-
     for (uint32_t i = 0; i < size; ++i) {
         if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
             has_shift = true;
@@ -400,21 +343,6 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
         return;
     }
 
-    if (recurrent) {
-        // for Mamba-like or RWKV models, only the pos needs to be changed
-        if (0 <= seq_id && seq_id < (int64_t) size) {
-            const int32_t tail_id = cells[seq_id].tail;
-            if (tail_id >= 0) {
-                llama_kv_cell & cell = cells[tail_id];
-                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
-                    cell.pos /= d;
-                }
-            }
-        }
-
-        return;
-    }
-
     for (uint32_t i = 0; i < size; ++i) {
         if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
             has_shift = true;
@@ -441,9 +369,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 void llama_kv_cache_unified::defrag() {
-    if (!recurrent) {
-        do_defrag = true;
-    }
+    do_defrag = true;
 }
 
 void llama_kv_cache_unified::restore() {
@@ -451,12 +377,6 @@ void llama_kv_cache_unified::restore() {
         return;
     }
 
-    // TODO: tmp - move to llama_kv_cache_recurrent
-    if (recurrent) {
-        seq_rm(-1, -1, -1);
-        return;
-    }
-
     uint32_t new_head = size;
 
     for (auto & range : pending.ranges) {
@@ -481,11 +401,6 @@ void llama_kv_cache_unified::restore() {
 }
 
 void llama_kv_cache_unified::commit() {
-    // TODO: tmp - move to llama_kv_cache_recurrent
-    if (recurrent) {
-        return;
-    }
-
     if (pending.ranges.empty()) {
         LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
                 __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
@@ -511,169 +426,6 @@ bool llama_kv_cache_unified::find_slot(
         head = 0;
     }
 
-    if (recurrent) {
-        // For recurrent state architectures (like Mamba or RWKV),
-        // each cache cell can store the state for a whole sequence.
-        // A slot should be always be contiguous.
-
-        // can only process batches with an equal number of new tokens in each sequence
-        GGML_ASSERT(ubatch.equal_seqs);
-
-        int32_t min = size - 1;
-        int32_t max = 0;
-
-        // everything should fit if all seq_ids are smaller than the max
-        for (uint32_t s = 0; s < n_seqs; ++s) {
-            const uint32_t n_seq_id = ubatch.n_seq_id[s];
-            for (uint32_t j = 0; j < n_seq_id; ++j) {
-                const llama_seq_id seq_id = ubatch.seq_id[s][j];
-
-                if (seq_id < 0 || (uint32_t) seq_id >= size) {
-                    // too big seq_id
-                    // TODO: would it be possible to resize the cache instead?
-                    LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
-                    return false;
-                }
-                if (j > 0) {
-                    llama_kv_cell & seq = cells[seq_id];
-                    if (seq.tail >= 0) {
-                        llama_kv_cell & cell = cells[seq.tail];
-                        // clear cells from seq_ids that become shared
-                        // (should not normally happen, but let's handle it anyway)
-                        cell.seq_id.erase(seq_id);
-                        seq.tail = -1;
-                        if (cell.seq_id.empty()) {
-                            cell.pos = -1;
-                            cell.src = -1;
-                            used -= 1;
-                        }
-                    }
-                }
-            }
-        }
-
-#ifndef NDEBUG
-        {
-            std::vector<int32_t> tails_verif;
-            tails_verif.assign(size, -1);
-            for (uint32_t i = 0; i < size; ++i) {
-                llama_kv_cell & cell = cells[i];
-                for (llama_seq_id seq_id : cell.seq_id) {
-                    if (tails_verif[seq_id] != -1) {
-                        LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
-                    }
-                    tails_verif[seq_id] = i;
-                }
-            }
-            for (uint32_t i = 0; i < size; ++i) {
-                if (tails_verif[i] != cells[i].tail) {
-                    LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
-                }
-            }
-        }
-#endif
-
-        // find next empty cell
-        uint32_t next_empty_cell = head;
-
-        for (uint32_t i = 0; i < size; ++i) {
-            if (next_empty_cell >= size) { next_empty_cell -= size; }
-            llama_kv_cell & cell = cells[next_empty_cell];
-            if (cell.is_empty()) { break; }
-            next_empty_cell += 1;
-        }
-
-        // find usable cell range
-        for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch.seq_id[s][0];
-            llama_kv_cell & seq_meta = cells[seq_id];
-            bool has_cell = false;
-            if (seq_meta.tail >= 0) {
-                llama_kv_cell & cell = cells[seq_meta.tail];
-                GGML_ASSERT(cell.has_seq_id(seq_id));
-                // does this seq_id "own" the cell?
-                if (cell.seq_id.size() == 1) { has_cell = true; }
-            }
-            if (!has_cell) {
-                llama_kv_cell & empty_cell = cells[next_empty_cell];
-                GGML_ASSERT(empty_cell.is_empty());
-                // copy old tail into the empty cell
-                if (seq_meta.tail >= 0) {
-                    llama_kv_cell & orig_cell = cells[seq_meta.tail];
-                    empty_cell.pos = orig_cell.pos;
-                    empty_cell.src = orig_cell.src;
-                    orig_cell.seq_id.erase(seq_id);
-                    empty_cell.seq_id.insert(seq_id); // will be overwritten
-                }
-                seq_meta.tail = next_empty_cell;
-                // find next empty cell
-                if (s + 1 < n_seqs) {
-                    next_empty_cell += 1;
-                    for (uint32_t i = 0; i < size; ++i) {
-                        if (next_empty_cell >= size) { next_empty_cell -= size; }
-                        llama_kv_cell & cell = cells[next_empty_cell];
-                        if (cell.is_empty()) { break; }
-                        next_empty_cell += 1;
-                    }
-                }
-            }
-            if (min > seq_meta.tail) { min = seq_meta.tail; }
-            if (max < seq_meta.tail) { max = seq_meta.tail; }
-        }
-
-        // gather and re-order
-        for (uint32_t s = 0; s < n_seqs; ++s) {
-            int32_t dst_id = s + min;
-            int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
-            if (dst_id != src_id) {
-                llama_kv_cell & dst_cell = cells[dst_id];
-                llama_kv_cell & src_cell = cells[src_id];
-
-                std::swap(dst_cell.pos, src_cell.pos);
-                std::swap(dst_cell.src, src_cell.src);
-                std::swap(dst_cell.seq_id, src_cell.seq_id);
-
-                // swap tails (assuming they NEVER overlap)
-                for (const llama_seq_id seq_id : src_cell.seq_id) {
-                    cells[seq_id].tail = src_id;
-                }
-                for (const llama_seq_id seq_id : dst_cell.seq_id) {
-                    cells[seq_id].tail = dst_id;
-                }
-            }
-        }
-
-        // update the pos of the used seqs
-        for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
-            int32_t cell_id = s + min;
-            llama_kv_cell & cell = cells[cell_id];
-
-            if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
-                // What should happen when the pos backtracks or skips a value?
-                // Clearing the state mid-batch would require special-casing which isn't done.
-                LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                    __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
-            }
-            cell.pos = last_pos;
-            cell.seq_id.clear();
-            for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
-                const llama_seq_id seq_id = ubatch.seq_id[s][j];
-                cell.seq_id.insert(seq_id);
-                cells[seq_id].tail = cell_id;
-            }
-        }
-
-        // allow getting the range of used cells, from head to head + n
-        head = min;
-        n    = max - min + 1;
-        used = std::count_if(cells.begin(), cells.end(),
-            [](const llama_kv_cell& cell){ return !cell.is_empty(); });
-
-        // sanity check
-        return n >= n_seqs;
-    }
-
     // otherwise, one cell per token.
 
     if (n_tokens > size) {
@@ -745,6 +497,10 @@ uint32_t llama_kv_cache_unified::cell_max() const {
     return 0;
 }
 
+void llama_kv_cache_unified::set_full() {
+    n = size;
+}
+
 size_t llama_kv_cache_unified::size_k_bytes() const {
     size_t size_k_bytes = 0;
 
@@ -1133,15 +889,6 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
                 }
 
                 cell.seq_id.insert(seq_id);
-
-                if (recurrent) {
-                    int32_t & tail = cells[seq_id].tail;
-                    if (tail != -1) {
-                        LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
-                        return false;
-                    }
-                    tail = i;
-                }
             }
         }
 
@@ -1149,14 +896,6 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         used = cell_count;
     }
 
-    if (recurrent) {
-        for (uint32_t i = 0; i < cell_count; ++i) {
-            uint32_t cell_id = head + i;
-            // make sure the recurrent states will keep their restored state
-            cells[cell_id].src = cell_id;
-        }
-    }
-
     return true;
 }
 
@@ -1174,7 +913,961 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
         LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
         return false;
     }
-    if (v_trans != (bool) v_trans) {
+    if (this->v_trans != (bool) v_trans) {
+        LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
+        return false;
+    }
+
+    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+
+        // Read type of key
+        int32_t k_type_i_ref;
+        io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
+        const int32_t k_type_i = (int32_t) k_l[il]->type;
+        if (k_type_i != k_type_i_ref) {
+            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
+            return false;
+        }
+
+        // Read row size of key
+        uint64_t k_size_row_ref;
+        io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
+        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        if (k_size_row != k_size_row_ref) {
+            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
+            return false;
+        }
+
+        if (cell_count) {
+            // Read and set the keys for the whole cell range
+            ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
+        }
+    }
+
+    if (!this->v_trans) {
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
+            // Read type of value
+            int32_t v_type_i_ref;
+            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+            const int32_t v_type_i = (int32_t)v_l[il]->type;
+            if (v_type_i != v_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                return false;
+            }
+
+            // Read row size of value
+            uint64_t v_size_row_ref;
+            io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
+            const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+            if (v_size_row != v_size_row_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
+                return false;
+            }
+
+            if (cell_count) {
+                // Read and set the values for the whole cell range
+                ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
+            }
+        }
+    } else {
+        // For each layer, read the values for each cell (transposed)
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
+            // Read type of value
+            int32_t v_type_i_ref;
+            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+            const int32_t v_type_i = (int32_t)v_l[il]->type;
+            if (v_type_i != v_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                return false;
+            }
+
+            // Read element size of value
+            uint32_t v_size_el_ref;
+            io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
+            const size_t v_size_el = ggml_type_size(v_l[il]->type);
+            if (v_size_el != v_size_el_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
+                return false;
+            }
+
+            // Read GQA embedding size
+            uint32_t n_embd_v_gqa_ref;
+            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
+            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
+                return false;
+            }
+
+            if (cell_count) {
+                // For each row in the transposed matrix, read the values for the whole cell range
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    const size_t dst_offset = (head + j * size) * v_size_el;
+                    ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                }
+            }
+        }
+    }
+
+    return true;
+}
+
+//
+// llama_kv_cache_recurrent
+//
+
+llama_kv_cache_recurrent::llama_kv_cache_recurrent(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
+}
+
+bool llama_kv_cache_recurrent::init(
+        const llama_model & model,
+      const llama_cparams & cparams,
+                ggml_type   type_k,
+                ggml_type   type_v,
+                 uint32_t   kv_size,
+                     bool   offload) {
+    GGML_UNUSED(cparams);
+
+    const int32_t n_layer = hparams.n_layer;
+
+    GGML_ASSERT(llama_model_is_recurrent(&model));
+
+    LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
+            __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
+
+    head = 0;
+    size = kv_size;
+    used = 0;
+
+    this->type_k = type_k;
+    this->type_v = type_v;
+
+    cells.clear();
+    cells.resize(kv_size);
+
+    // create a context for each buffer type
+    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
+        auto it = ctx_map.find(buft);
+        if (it == ctx_map.end()) {
+            ggml_init_params params = {
+                /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
+                /*.mem_buffer =*/ NULL,
+                /*.no_alloc   =*/ true,
+            };
+
+            ggml_context * ctx = ggml_init(params);
+            if (!ctx) {
+                return nullptr;
+            }
+
+            ctx_map[buft] = ctx;
+            ctxs.emplace_back(ctx);
+
+            return ctx;
+        }
+
+        return it->second;
+    };
+
+    k_l.reserve(n_layer);
+    v_l.reserve(n_layer);
+
+    for (int i = 0; i < n_layer; i++) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+
+        const char * dev_name = "CPU";
+
+        ggml_backend_buffer_type_t buft;
+        if (offload) {
+            auto * dev = model.dev_layer(i);
+            buft = ggml_backend_dev_buffer_type(dev);
+
+            dev_name = ggml_backend_dev_name(dev);
+        } else {
+            buft = ggml_backend_cpu_buffer_type();
+        }
+
+        LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__,
+                i, n_embd_k_gqa, n_embd_v_gqa, dev_name);
+
+        ggml_context * ctx = ctx_for_buft(buft);
+        if (!ctx) {
+            LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
+            return false;
+        }
+
+        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        ggml_format_name(k, "cache_k_l%d", i);
+        ggml_format_name(v, "cache_v_l%d", i);
+        k_l.push_back(k);
+        v_l.push_back(v);
+    }
+
+    // allocate tensors and initialize the buffers to avoid NaNs in the padding
+    for (auto it : ctx_map) {
+        auto * buft = it.first;
+        auto * ctx  = it.second;
+
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+        if (!buf) {
+            LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
+            return false;
+        }
+        ggml_backend_buffer_clear(buf, 0);
+        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
+        bufs.emplace_back(buf);
+    }
+
+    return true;
+}
+
+int32_t llama_kv_cache_recurrent::get_n_tokens() const {
+    int32_t result = 0;
+
+    for (uint32_t i = 0; i < size; i++) {
+        result += cells[i].seq_id.size();
+    }
+
+    return result;
+}
+
+int32_t llama_kv_cache_recurrent::get_used_cells() const {
+    return used;
+}
+
+bool llama_kv_cache_recurrent::get_has_shift() const {
+    return false;
+}
+
+bool llama_kv_cache_recurrent::get_do_defrag() const {
+    return false;
+}
+
+size_t llama_kv_cache_recurrent::total_size() const {
+    size_t size = 0;
+    for (const auto & buf : bufs) {
+        size += ggml_backend_buffer_get_size(buf.get());
+    }
+
+    return size;
+}
+
+llama_pos llama_kv_cache_recurrent::get_pos_max() const {
+    llama_pos pos_max = -1;
+    for (const auto & cell : cells) {
+        pos_max = std::max(pos_max, cell.pos);
+    }
+
+    return pos_max;
+}
+
+void llama_kv_cache_recurrent::clear() {
+    for (int32_t i = 0; i < (int32_t) size; ++i) {
+        cells[i].pos = -1;
+        cells[i].seq_id.clear();
+        cells[i].src = -1;
+        cells[i].tail = -1;
+    }
+    head = 0;
+    used = 0;
+
+    for (auto & buf : bufs) {
+        ggml_backend_buffer_clear(buf.get(), 0);
+    }
+}
+
+bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    // models like Mamba or RWKV can't have a state partially erased
+    if (seq_id >= (int64_t) size) {
+        // could be fatal
+        return false;
+    }
+    if (0 <= seq_id) {
+        int32_t & tail_id = cells[seq_id].tail;
+        if (tail_id >= 0) {
+            const llama_kv_cell & cell = cells[tail_id];
+            // partial intersection is invalid
+            if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+                return false;
+            }
+            // invalidate tails which will be cleared
+            if (p0 <= cell.pos && cell.pos < p1) {
+                tail_id = -1;
+            }
+        }
+    } else {
+        // seq_id is negative, then the range should include everything or nothing
+        if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    if (seq_id_src == seq_id_dst) {
+        return;
+    }
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
+        llama_kv_cell & tail_src = cells[seq_id_src];
+        llama_kv_cell & tail_dst = cells[seq_id_dst];
+        if (tail_dst.tail >= 0) {
+            // clear destination seq_id if it wasn't empty
+            llama_kv_cell & cell_dst = cells[tail_dst.tail];
+
+            cell_dst.seq_id.erase(seq_id_dst);
+            tail_dst.tail = -1;
+            if (cell_dst.seq_id.empty()) {
+                cell_dst.pos = -1;
+                cell_dst.delta = -1;
+                cell_dst.src = -1;
+                used -= 1;
+            }
+        }
+        if (tail_src.tail >= 0) {
+            llama_kv_cell & cell_src = cells[tail_src.tail];
+
+            cell_src.seq_id.insert(seq_id_dst);
+            tail_dst.tail = tail_src.tail;
+        }
+    }
+}
+
+void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
+    uint32_t new_head = size;
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if ((llama_seq_id) i != seq_id) {
+            cells[i].tail = -1;
+        }
+
+        if (!cells[i].has_seq_id(seq_id)) {
+            if (cells[i].pos >= 0) {
+                used--;
+            }
+
+            cells[i].pos = -1;
+            cells[i].src = -1;
+            cells[i].seq_id.clear();
+
+            if (new_head == size){
+                new_head = i;
+            }
+        } else {
+            cells[i].seq_id.clear();
+            cells[i].seq_id.insert(seq_id);
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != size && new_head < head) {
+        head = new_head;
+    }
+}
+
+void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+    if (delta == 0) {
+        return;
+    }
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    // If there is no range then return early to avoid looping over the cache.
+    if (p0 == p1) {
+        return;
+    }
+
+    // for Mamba-like or RWKV models, only the pos needs to be shifted
+    if (0 <= seq_id && seq_id < (int64_t) size) {
+        const int32_t tail_id = cells[seq_id].tail;
+        if (tail_id >= 0) {
+            llama_kv_cell & cell = cells[tail_id];
+            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                cell.pos += delta;
+            }
+        }
+    }
+}
+
+void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    if (d == 1) {
+        return;
+    }
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    // If there is no range then return early to avoid looping over the cache.
+    if (p0 == p1) {
+        return;
+    }
+
+    // for Mamba-like or RWKV models, only the pos needs to be changed
+    if (0 <= seq_id && seq_id < (int64_t) size) {
+        const int32_t tail_id = cells[seq_id].tail;
+        if (tail_id >= 0) {
+            llama_kv_cell & cell = cells[tail_id];
+            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                cell.pos /= d;
+            }
+        }
+    }
+}
+
+llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
+    llama_pos result = 0;
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (cells[i].has_seq_id(seq_id)) {
+            result = std::max(result, cells[i].pos);
+        }
+    }
+
+    return result;
+}
+
+void llama_kv_cache_recurrent::defrag() {
+    LLAMA_LOG_ERROR("%s: not supported\n", __func__);
+}
+
+void llama_kv_cache_recurrent::restore() {
+    if (pending.ranges.empty()) {
+        return;
+    }
+
+    seq_rm(-1, -1, -1);
+}
+
+void llama_kv_cache_recurrent::commit() {
+    pending.ranges.clear();
+}
+
+bool llama_kv_cache_recurrent::get_can_shift() const {
+    return false;
+}
+
+bool llama_kv_cache_recurrent::find_slot(
+       const llama_ubatch & ubatch) {
+    const uint32_t n_tokens = ubatch.n_tokens;
+    const uint32_t n_seqs   = ubatch.n_seqs;
+
+    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    // if we have enough unused cells before the current head ->
+    //   better to start searching from the beginning of the cache, hoping to fill it
+    if (head > used + 2*n_tokens) {
+        head = 0;
+    }
+
+    // For recurrent state architectures (like Mamba or RWKV),
+    // each cache cell can store the state for a whole sequence.
+    // A slot should always be contiguous.
+
+    // can only process batches with an equal number of new tokens in each sequence
+    GGML_ASSERT(ubatch.equal_seqs);
+
+    int32_t min = size - 1;
+    int32_t max = 0;
+
+    // everything should fit if all seq_ids are smaller than the max
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        const uint32_t n_seq_id = ubatch.n_seq_id[s];
+        for (uint32_t j = 0; j < n_seq_id; ++j) {
+            const llama_seq_id seq_id = ubatch.seq_id[s][j];
+
+            if (seq_id < 0 || (uint32_t) seq_id >= size) {
+                // too big seq_id
+                // TODO: would it be possible to resize the cache instead?
+                LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
+                return false;
+            }
+            if (j > 0) {
+                llama_kv_cell & seq = cells[seq_id];
+                if (seq.tail >= 0) {
+                    llama_kv_cell & cell = cells[seq.tail];
+                    // clear cells from seq_ids that become shared
+                    // (should not normally happen, but let's handle it anyway)
+                    cell.seq_id.erase(seq_id);
+                    seq.tail = -1;
+                    if (cell.seq_id.empty()) {
+                        cell.pos = -1;
+                        cell.src = -1;
+                        used -= 1;
+                    }
+                }
+            }
+        }
+    }
+
+#ifndef NDEBUG
+    {
+        std::vector<int32_t> tails_verif;
+        tails_verif.assign(size, -1);
+        for (uint32_t i = 0; i < size; ++i) {
+            llama_kv_cell & cell = cells[i];
+            for (llama_seq_id seq_id : cell.seq_id) {
+                if (tails_verif[seq_id] != -1) {
+                    LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
+                }
+                tails_verif[seq_id] = i;
+            }
+        }
+        for (uint32_t i = 0; i < size; ++i) {
+            if (tails_verif[i] != cells[i].tail) {
+                LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
+            }
+        }
+    }
+#endif
+
+    // find next empty cell
+    uint32_t next_empty_cell = head;
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (next_empty_cell >= size) { next_empty_cell -= size; }
+        llama_kv_cell & cell = cells[next_empty_cell];
+        if (cell.is_empty()) { break; }
+        next_empty_cell += 1;
+    }
+
+    // find usable cell range
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        const llama_seq_id seq_id = ubatch.seq_id[s][0];
+        llama_kv_cell & seq_meta = cells[seq_id];
+        bool has_cell = false;
+        if (seq_meta.tail >= 0) {
+            llama_kv_cell & cell = cells[seq_meta.tail];
+            GGML_ASSERT(cell.has_seq_id(seq_id));
+            // does this seq_id "own" the cell?
+            if (cell.seq_id.size() == 1) { has_cell = true; }
+        }
+        if (!has_cell) {
+            llama_kv_cell & empty_cell = cells[next_empty_cell];
+            GGML_ASSERT(empty_cell.is_empty());
+            // copy old tail into the empty cell
+            if (seq_meta.tail >= 0) {
+                llama_kv_cell & orig_cell = cells[seq_meta.tail];
+                empty_cell.pos = orig_cell.pos;
+                empty_cell.src = orig_cell.src;
+                orig_cell.seq_id.erase(seq_id);
+                empty_cell.seq_id.insert(seq_id); // will be overwritten
+            }
+            seq_meta.tail = next_empty_cell;
+            // find next empty cell
+            if (s + 1 < n_seqs) {
+                next_empty_cell += 1;
+                for (uint32_t i = 0; i < size; ++i) {
+                    if (next_empty_cell >= size) { next_empty_cell -= size; }
+                    llama_kv_cell & cell = cells[next_empty_cell];
+                    if (cell.is_empty()) { break; }
+                    next_empty_cell += 1;
+                }
+            }
+        }
+        if (min > seq_meta.tail) { min = seq_meta.tail; }
+        if (max < seq_meta.tail) { max = seq_meta.tail; }
+    }
+
+    // gather and re-order
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        int32_t dst_id = s + min;
+        int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
+        if (dst_id != src_id) {
+            llama_kv_cell & dst_cell = cells[dst_id];
+            llama_kv_cell & src_cell = cells[src_id];
+
+            std::swap(dst_cell.pos, src_cell.pos);
+            std::swap(dst_cell.src, src_cell.src);
+            std::swap(dst_cell.seq_id, src_cell.seq_id);
+
+            // swap tails (assuming they NEVER overlap)
+            for (const llama_seq_id seq_id : src_cell.seq_id) {
+                cells[seq_id].tail = src_id;
+            }
+            for (const llama_seq_id seq_id : dst_cell.seq_id) {
+                cells[seq_id].tail = dst_id;
+            }
+        }
+    }
+
+    // update the pos of the used seqs
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+        int32_t cell_id = s + min;
+        llama_kv_cell & cell = cells[cell_id];
+
+        if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
+            // What should happen when the pos backtracks or skips a value?
+            // Clearing the state mid-batch would require special-casing which isn't done.
+            LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
+                __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
+        }
+        cell.pos = last_pos;
+        cell.seq_id.clear();
+        for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
+            const llama_seq_id seq_id = ubatch.seq_id[s][j];
+            cell.seq_id.insert(seq_id);
+            cells[seq_id].tail = cell_id;
+        }
+    }
+
+    // allow getting the range of used cells, from head to head + n
+    head = min;
+    n    = max - min + 1;
+    used = std::count_if(cells.begin(), cells.end(),
+        [](const llama_kv_cell& cell){ return !cell.is_empty(); });
+
+    // sanity check
+    return n >= n_seqs;
+}
+
+uint32_t llama_kv_cache_recurrent::get_padding(const llama_cparams & cparams) const {
+    // the FA kernels require padding to avoid extra runtime boundary checks
+    return cparams.flash_attn ? 256u : 32u;
+}
+
+uint32_t llama_kv_cache_recurrent::cell_max() const {
+    for (uint32_t i = size; i > 0; --i) {
+        const llama_kv_cell & cell = cells[i - 1];
+
+        if (cell.pos >= 0 && !cell.is_empty()) {
+            return i;
+        }
+    }
+
+    return 0;
+}
+
+void llama_kv_cache_recurrent::set_full() {
+    n = size;
+}
+
+size_t llama_kv_cache_recurrent::size_k_bytes() const {
+    size_t size_k_bytes = 0;
+
+    for (const auto & k : k_l) {
+        size_k_bytes += ggml_nbytes(k);
+    }
+
+    return size_k_bytes;
+}
+
+size_t llama_kv_cache_recurrent::size_v_bytes() const {
+    size_t size_v_bytes = 0;
+
+    for (const auto & v : v_l) {
+        size_v_bytes += ggml_nbytes(v);
+    }
+
+    return size_v_bytes;
+}
+
+void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
+    uint32_t cell_count = 0;
+
+    // Count the number of cells with the specified seq_id
+    // Find all the ranges of cells with this seq id (or all, when -1)
+    uint32_t cell_range_begin = size;
+    for (uint32_t i = 0; i < size; ++i) {
+        const auto & cell = cells[i];
+        if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
+            ++cell_count;
+            if (cell_range_begin == size) {
+                cell_range_begin = i;
+            }
+        } else {
+            if (cell_range_begin != size) {
+                cell_ranges.emplace_back(cell_range_begin, i);
+                cell_range_begin = size;
+            }
+        }
+    }
+    if (cell_range_begin != size) {
+        cell_ranges.emplace_back(cell_range_begin, size);
+    }
+
+    // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
+    uint32_t cell_count_check = 0;
+    for (const auto & range : cell_ranges) {
+        cell_count_check += range.second - range.first;
+    }
+    GGML_ASSERT(cell_count == cell_count_check);
+
+    io.write(&cell_count, sizeof(cell_count));
+
+    state_write_meta(io, cell_ranges, seq_id);
+    state_write_data(io, cell_ranges);
+}
+
+void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    uint32_t cell_count;
+    io.read_to(&cell_count, sizeof(cell_count));
+
+    bool res = true;
+    res = res && state_read_meta(io, cell_count, seq_id);
+    res = res && state_read_data(io, cell_count);
+
+    if (!res) {
+        if (seq_id == -1) {
+            clear();
+        } else {
+            seq_rm(seq_id, -1, -1);
+        }
+        throw std::runtime_error("failed to restore kv cache");
+    }
+}
+
+void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
+    for (const auto & range : cell_ranges) {
+        for (uint32_t i = range.first; i < range.second; ++i) {
+            const auto & cell = cells[i];
+            const llama_pos pos      = cell.pos;
+            const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
+
+            io.write(&pos,      sizeof(pos));
+            io.write(&n_seq_id, sizeof(n_seq_id));
+
+            if (n_seq_id) {
+                for (auto seq_id : cell.seq_id) {
+                    io.write(&seq_id, sizeof(seq_id));
+                }
+            }
+        }
+    }
+}
+
+void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
+    const uint32_t v_trans = 0;
+    const uint32_t n_layer = hparams.n_layer;
+
+    io.write(&v_trans, sizeof(v_trans));
+    io.write(&n_layer, sizeof(n_layer));
+
+    // Iterate and write all the keys first, each row is a cell
+    // Get whole range at a time
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+
+        // Write key type
+        const int32_t k_type_i = (int32_t)k_l[il]->type;
+        io.write(&k_type_i, sizeof(k_type_i));
+
+        // Write row size of key
+        const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        io.write(&k_size_row, sizeof(k_size_row));
+
+        // Write out each contiguous range of cells directly from the tensor, k_size_row bytes per cell
+        for (const auto & range : cell_ranges) {
+            const size_t range_size = range.second - range.first;
+            const size_t buf_size = range_size * k_size_row;
+            io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
+        }
+    }
+
+    if (!v_trans) {
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
+            // Write value type
+            const int32_t v_type_i = (int32_t)v_l[il]->type;
+            io.write(&v_type_i, sizeof(v_type_i));
+
+            // Write row size of value
+            const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+            io.write(&v_size_row, sizeof(v_size_row));
+
+            // Write out each contiguous range of cells directly from the tensor, v_size_row bytes per cell
+            for (const auto & range : cell_ranges) {
+                const size_t range_size = range.second - range.first;
+                const size_t buf_size = range_size * v_size_row;
+                io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
+            }
+        }
+    } else {
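+        // note: v_trans is hard-coded to 0 above, so this branch is not taken for the recurrent cache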
+        // When v is transposed, we also need the element size and get the element ranges from each row
+        const uint32_t kv_size = size;
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
+            // Write value type
+            const int32_t v_type_i = (int32_t)v_l[il]->type;
+            io.write(&v_type_i, sizeof(v_type_i));
+
+            // Write element size
+            const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
+            io.write(&v_size_el, sizeof(v_size_el));
+
+            // Write GQA embedding size
+            io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+
+            // For each row, we get the element values of each cell
+            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                // Write out each range of cells directly from the tensor, v_size_el bytes per element
+                for (const auto & range : cell_ranges) {
+                    const size_t range_size = range.second - range.first;
+                    const size_t src_offset = (range.first + j * kv_size) * v_size_el;
+                    const size_t buf_size = range_size * v_size_el;
+                    io.write_tensor(v_l[il], src_offset, buf_size);
+                }
+            }
+        }
+    }
+}
+
+bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
+    if (dest_seq_id != -1) {
+        // single sequence
+
+        seq_rm(dest_seq_id, -1, -1);
+
+        llama_sbatch sbatch;
+        llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+
+        batch.n_tokens = cell_count;
+        batch.n_seq_tokens = cell_count;
+        batch.n_seqs = 1;
+
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            llama_pos pos;
+            uint32_t n_seq_id;
+
+            io.read_to(&pos,      sizeof(pos));
+            io.read_to(&n_seq_id, sizeof(n_seq_id));
+
+            if (n_seq_id != 0) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
+                return false;
+            }
+
+            batch.pos[i] = pos;
+        }
+        batch.n_seq_id[0] = 1;
+        batch.seq_id[0] = &dest_seq_id;
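+        // the temporary ubatch above exists only so that find_slot can place the restored cells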
+        if (!find_slot(batch)) {
+            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
+            return false;
+        }
+        commit();
+
+        // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
+        // Assume that this is one contiguous block of cells
+        GGML_ASSERT(head + cell_count <= size);
+        GGML_ASSERT(cells[head].pos == batch.pos[0]);
+        GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+        GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
+        GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
+    } else {
+        // whole KV cache restore
+
+        if (cell_count > size) {
+            LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
+            return false;
+        }
+
+        clear();
+
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            llama_kv_cell & cell = cells[i];
+
+            llama_pos pos;
+            uint32_t  n_seq_id;
+
+            io.read_to(&pos,      sizeof(pos));
+            io.read_to(&n_seq_id, sizeof(n_seq_id));
+
+            cell.pos = pos;
+
+            for (uint32_t j = 0; j < n_seq_id; ++j) {
+                llama_seq_id seq_id;
+                io.read_to(&seq_id, sizeof(seq_id));
+
+                // TODO: llama_kv_cache_recurrent should have a notion of max sequences
+                //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
+                if (seq_id < 0) {
+                    //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
+                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
+                    return false;
+                }
+
+                cell.seq_id.insert(seq_id);
+
+                int32_t & tail = cells[seq_id].tail;
+                if (tail != -1) {
+                    LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
+                    return false;
+                }
+                tail = i;
+            }
+        }
+
+        head = 0;
+        used = cell_count;
+    }
+
+    for (uint32_t i = 0; i < cell_count; ++i) {
+        uint32_t cell_id = head + i;
+        // make sure the recurrent states will keep their restored state
+        cells[cell_id].src = cell_id;
+    }
+
+    return true;
+}
+
+bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
+    uint32_t v_trans;
+    uint32_t n_layer;
+    io.read_to(&v_trans, sizeof(v_trans));
+    io.read_to(&n_layer, sizeof(n_layer));
+
+    if (n_layer != hparams.n_layer) {
+        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
+        return false;
+    }
+    if (cell_count > size) {
+        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
+        return false;
+    }
+    if (v_trans != 0) {
         LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
         return false;
     }
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index 56c74035a..a2d88c9cc 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -15,17 +15,44 @@ struct llama_hparams;
 struct llama_ubatch;
 
 struct llama_kv_cache : public llama_memory_i {
+    virtual ~llama_kv_cache() = default;
+
     using llama_memory_i::llama_memory_i;
 
+    // TODO: become constructor
+    virtual bool init(
+            const llama_model & model,   // TODO: do not reference the model
+          const llama_cparams & cparams,
+                    ggml_type   type_k,
+                    ggml_type   type_v,
+                     uint32_t   kv_size,
+                         bool   offload) = 0;
+
     virtual void restore() = 0; // call if batch processing fails - restores the cache state
     virtual void commit() = 0;  // call after successful batch processing - clears any pending state
 
     virtual int32_t get_n_tokens()   const = 0;
     virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
 
+    virtual bool get_has_shift() const = 0;
+    virtual bool get_do_defrag() const = 0;
+
+    virtual llama_pos get_pos_max() const = 0;
+
     virtual bool get_can_shift() const = 0;
 
     bool get_can_edit() const override { return get_can_shift(); }
+
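+    // find a slot in the cache for the ubatch and update the cache head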
+    virtual bool find_slot(const llama_ubatch & batch) = 0;
+
+    // simulate full cache, used for allocating worst-case compute buffers
+    virtual void set_full() = 0;
+
+    virtual size_t size_k_bytes() const = 0;
+    virtual size_t size_v_bytes() const = 0;
+
+    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
+    virtual void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) = 0;
 };
 
 struct llama_kv_cache_guard {
@@ -74,12 +101,6 @@ public:
         std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
     };
 
-    llama_kv_cache_unified(
-            const llama_hparams & hparams,
-            callbacks             cbs);
-
-    virtual ~llama_kv_cache_unified() = default;
-
     // TODO: become constructor
     bool init(
             const llama_model & model,   // TODO: do not reference the model
@@ -87,21 +108,30 @@ public:
                     ggml_type   type_k,
                     ggml_type   type_v,
                      uint32_t   kv_size,
-                         bool   offload);
+                         bool   offload) override;
+
+    llama_kv_cache_unified(
+            const llama_hparams & hparams,
+            callbacks             cbs);
+
+    ~llama_kv_cache_unified() = default;
 
     int32_t get_n_tokens()   const override;
     int32_t get_used_cells() const override;
 
+    bool get_has_shift() const override;
+    bool get_do_defrag() const override;
+
     size_t total_size() const;
 
     // TODO: better data structures to reduce the cost of this operation
-    llama_pos pos_max() const;
+    llama_pos get_pos_max() const override;
 
     void clear() override;
     void defrag() override;
 
-    virtual void restore() override;
-    virtual void commit() override;
+    void restore() override;
+    void commit() override;
 
     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
@@ -117,7 +147,7 @@ public:
     // updates the cache head
     // Note: On success, it's important that cache.head points
     // to the first cell of the slot.
-    bool find_slot(const llama_ubatch & batch);
+    bool find_slot(const llama_ubatch & batch) override;
 
     // TODO: maybe not needed
     uint32_t get_padding(const llama_cparams & cparams) const;
@@ -125,8 +155,10 @@ public:
     // find how many cells are currently in use
     uint32_t cell_max() const;
 
-    size_t size_k_bytes() const;
-    size_t size_v_bytes() const;
+    void set_full() override;
+
+    size_t size_k_bytes() const override;
+    size_t size_v_bytes() const override;
 
     // defrag
 
@@ -151,8 +183,8 @@ public:
 
     // state write/load
 
-    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1);
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
 
     // members
 
@@ -163,9 +195,6 @@ public:
     bool has_shift = false;
     bool do_defrag = false;
 
-    // TODO: remove this and implement llama_kv_cache_recurrent instead
-    bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
-
     bool v_trans   = true;  // the value tensor is transposed
     bool can_shift = false;
 
@@ -198,11 +227,124 @@ private:
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };
 
-// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
-//class llama_kv_cache_recurrent : public llama_kv_cache_unified {
-//public:
-//    using llama_kv_cache_unified::llama_kv_cache_unified;
-//};
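+// llama_kv_cache_recurrent: cache for recurrent state models (e.g. Mamba, RWKV),
+// where a cell holds the running state of a sequence rather than per-token KV data
+//
+// rough usage sketch (assumed driver code, not part of this patch; for recurrent models
+// the cell count tracks the number of sequences and the state tensors are kept in F32):
+//
+//   llama_kv_cache_recurrent kv(hparams, cbs);
+//   kv.init(model, cparams, GGML_TYPE_F32, GGML_TYPE_F32, /*kv_size=*/n_seq_max, /*offload=*/true);
+//
+//   if (kv.find_slot(ubatch)) {
+//       // ... build and compute the graph ...
+//       kv.commit();  // on success: clear the pending state
+//   } else {
+//       kv.restore(); // on failure: roll back the cache state
+//   }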
+class llama_kv_cache_recurrent : public llama_kv_cache {
+public:
+    // can be used to query data from the model if needed
+    struct callbacks {
+        std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
+    };
+
+    llama_kv_cache_recurrent(
+            const llama_hparams & hparams,
+            callbacks             cbs);
+
+    ~llama_kv_cache_recurrent() = default;
+
+    // TODO: become constructor
+    bool init(
+            const llama_model & model,   // TODO: do not reference the model
+          const llama_cparams & cparams,
+                    ggml_type   type_k,
+                    ggml_type   type_v,
+                     uint32_t   kv_size,
+                         bool   offload) override;
+
+    int32_t get_n_tokens()   const override;
+    int32_t get_used_cells() const override;
+
+    bool get_has_shift() const override;
+    bool get_do_defrag() const override;
+
+    size_t total_size() const;
+
+    // TODO: better data structures to reduce the cost of this operation
+    llama_pos get_pos_max() const override;
+
+    void clear() override;
+    void defrag() override;
+
+    void restore() override;
+    void commit() override;
+
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    bool get_can_shift() const override;
+
+    // find cells for the sequences of the ubatch (one cell per sequence)
+    // updates the cache head
+    // Note: On success, it's important that cache.head points
+    // to the first cell of the slot.
+    bool find_slot(const llama_ubatch & batch) override;
+
+    // TODO: maybe not needed
+    uint32_t get_padding(const llama_cparams & cparams) const;
+
+    // find how many cells are currently in use
+    uint32_t cell_max() const;
+
+    void set_full() override;
+
+    size_t size_k_bytes() const override;
+    size_t size_v_bytes() const override;
+
+    // commit/restore cache
+
+    struct slot_range {
+        uint32_t c0 = 0; // note: these are cell indices, not sequence positions
+        uint32_t c1 = 0;
+    };
+
+    // pending cell updates that are not yet committed
+    struct {
+        std::vector<slot_range> ranges;
+    } pending;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
+
+    // members
+
+    const llama_hparams & hparams;
+
+    callbacks cbs;
+
+    // Note: The value of head isn't only used to optimize searching
+    // for a free KV slot. llama_decode_impl also uses it, so it
+    // cannot be freely changed after a slot has been allocated.
+    uint32_t head = 0;
+    uint32_t size = 0;
+    uint32_t used = 0; // used cells (i.e. at least one seq_id)
+
+    // computed before each graph build
+    uint32_t n = 0;
+
+    std::vector<llama_kv_cell> cells;
+
+    std::vector<ggml_tensor *> k_l; // per layer
+    std::vector<ggml_tensor *> v_l;
+
+private:
+    ggml_type type_k = GGML_TYPE_F16;
+    ggml_type type_v = GGML_TYPE_F16;
+
+    std::vector<ggml_context_ptr>        ctxs;
+    std::vector<ggml_backend_buffer_ptr> bufs;
+
+    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
+    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
+
+    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
+};
+
 
 //
 // kv cache view
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index ca6e3ab2c..124cc6797 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -8128,7 +8128,7 @@ struct llm_build_mamba : public llm_graph_context {
              ggml_tensor * state_mask,
       const llama_ubatch & ubatch,
                      int   il) const {
-        const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+        const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
         const auto kv_head = kv_self->head;
 
@@ -10691,7 +10691,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
             ggml_tensor * state_mask,
             const llama_ubatch & ubatch,
             int   il) const {
-        const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+        const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
         const auto n_tokens = ubatch.n_tokens;
         const auto n_seqs = ubatch.n_seqs;
@@ -11087,7 +11087,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
             ggml_tensor *& first_layer_value,
             const llama_ubatch & ubatch,
             int   il) const {
-        const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+        const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
         const auto n_tokens = ubatch.n_tokens;
         const auto n_seqs = ubatch.n_seqs;
@@ -12057,7 +12057,7 @@ llama_memory_i * llama_model::create_memory() const {
         case LLM_ARCH_RWKV7:
         case LLM_ARCH_ARWKV7:
             {
-                res = new llama_kv_cache_unified(hparams, {
+                res = new llama_kv_cache_recurrent(hparams, {
                     /*.get_rope_factors =*/ nullptr
                 });
             } break;