llama : refactor kv cache guard (#12695)

* llama : refactor kv cache guard

ggml-ci

* cont : fix comment [no ci]

* llama : fix kv_cache restore logic

ggml-ci

* context : simplify kv cache updates

ggml-ci

* cont : better name [no ci]

* llama : fix llama_decode return code when a KV slot could not be found

ggml-ci

* context : change log err -> warn [no ci]

* kv-cache : add comment + warning
Georgi Gerganov 2025-04-02 14:32:59 +03:00 committed by GitHub
parent 83a88bd6af
commit a10b36c91a
4 changed files with 107 additions and 127 deletions
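
The core of the refactor: the ad-hoc batch_guard that llama_context::decode() defined inline (together with llama_kv_slot_restorer) is replaced by a small RAII guard, llama_kv_cache_guard, and the slot bookkeeping moves into llama_kv_cache_unified itself. A minimal sketch of how decode() drives the new guard, simplified from the src/llama-context.cpp diff below (only the missing-slot error path is shown):

    // sketch of the decode() flow after this commit (simplified from the diff below)
    llama_kv_cache_guard kv_guard(kv_self.get()); // destructor calls kv_self->restore()

    kv_self_update();                             // apply any pending defrags/shifts up front

    while (sbatch.n_tokens > 0) {
        // ... split off the next ubatch ...

        if (!kv_self->find_slot(ubatch)) {
            // no free slot: cells already taken for earlier ubatches of this batch
            // are rolled back by the guard's destructor; the caller sees return code 1
            return 1;
        }

        // ... build the graph, compute, extract logits/embeddings ...
    }

    kv_guard.commit(); // success: clear the pending ranges so restore() becomes a no-op
    return 0;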

examples/parallel/parallel.cpp

@@ -106,6 +106,8 @@ int main(int argc, char ** argv) {
common_params params;
params.n_predict = 128;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
return 1;
}

src/llama-context.cpp

@@ -1201,33 +1201,7 @@ int llama_context::decode(llama_batch & inp_batch) {
const int64_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd;
// TODO: remove this stuff
class batch_guard {
public:
batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
}
~batch_guard() {
if (!is_done) {
kv_slot_restorer.restore();
}
}
void done() {
is_done = true;
}
void save(const llama_kv_cache_slot_info & slot_info) {
kv_slot_restorer.save(slot_info);
}
private:
bool is_done = false;
llama_kv_slot_restorer kv_slot_restorer;
};
batch_guard bg(*kv_self);
llama_kv_cache_guard kv_guard(kv_self.get());
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
@@ -1280,6 +1254,9 @@ int llama_context::decode(llama_batch & inp_batch) {
return -2;
};
// handle any pending defrags/shifts
kv_self_update();
int64_t n_outputs_prev = 0;
while (sbatch.n_tokens > 0) {
@@ -1319,22 +1296,12 @@ int llama_context::decode(llama_batch & inp_batch) {
// find KV slot
{
kv_self_update();
if (!kv_self->find_slot(ubatch)) {
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
kv_self->head = 0;
return 1;
}
const auto slot_info = kv_self->find_slot(ubatch);
if (!slot_info) {
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
return -3;
}
bg.save(slot_info);
if (!kv_self->recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
@@ -1371,16 +1338,6 @@ int llama_context::decode(llama_batch & inp_batch) {
}
}
// update the kv ring buffer
{
kv_self->head += ubatch.n_tokens;
// Ensure kv cache head points to a valid index.
if (kv_self->head >= kv_self->size) {
kv_self->head = 0;
}
}
// plot the computation graph in dot format (for debugging purposes)
//if (n_past%100 == 0) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
@@ -1467,7 +1424,7 @@ int llama_context::decode(llama_batch & inp_batch) {
}
// finalize the batch processing
bg.done();
kv_guard.commit();
// set output mappings
{

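For callers, the visible change is the return code: when no KV slot can be found, llama_decode() now returns 1 (a recoverable condition, logged as a warning) instead of a negative error, and the guard guarantees the cache is left as it was before the call. A hedged caller-side sketch (ctx and batch are assumed to be an initialized llama_context and llama_batch):

    // sketch: handling llama_decode() return codes after this change
    const int ret = llama_decode(ctx, batch);

    if (ret == 0) {
        // success: the KV cells used by this batch are committed
    } else if (ret > 0) {
        // could not find a KV slot for the batch; the cache was restored by the guard,
        // so it is safe to retry, e.g. with a smaller batch or after freeing sequences
    } else {
        // ret < 0: hard error for this call (invalid input, failed allocation/compute, ...)
    }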
src/llama-kv-cache.cpp

@@ -11,8 +11,6 @@
#include <map>
#include <stdexcept>
static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
}
@@ -206,6 +204,8 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
return false;
}
}
return true;
}
for (uint32_t i = 0; i < size; ++i) {
@@ -446,16 +446,66 @@ void llama_kv_cache_unified::defrag() {
}
}
void llama_kv_cache_unified::restore() {
if (pending.ranges.empty()) {
return;
}
// TODO: tmp - move to llama_kv_cache_recurrent
if (recurrent) {
seq_rm(-1, -1, -1);
return;
}
uint32_t new_head = size;
for (auto & range : pending.ranges) {
for (uint32_t i = range.c0; i < range.c1; ++i) {
cells[i].seq_id.clear();
// keep count of the number of used cells
if (cells[i].pos >= 0) {
used--;
}
cells[i].pos = -1;
cells[i].src = -1;
}
new_head = std::min(new_head, range.c0);
}
if (new_head != size && new_head < head) {
head = new_head;
}
}
void llama_kv_cache_unified::commit() {
if (pending.ranges.empty()) {
LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
__func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
return;
}
pending.ranges.clear();
}
bool llama_kv_cache_unified::get_can_shift() const {
return can_shift;
}
llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
bool llama_kv_cache_unified::find_slot(
const llama_ubatch & ubatch) {
const uint32_t n_tokens = ubatch.n_tokens;
const uint32_t n_seqs = ubatch.n_seqs;
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (head > used + 2*ubatch.n_tokens) {
head = 0;
}
if (recurrent) {
// For recurrent state architectures (like Mamba or RWKV),
// each cache cell can store the state for a whole sequence.
@@ -477,7 +527,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
// too big seq_id
// TODO: would it be possible to resize the cache instead?
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
return llama_kv_cache_slot_info_failed;
return false;
}
if (j > 0) {
llama_kv_cell & seq = cells[seq_id];
@@ -616,14 +666,14 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
[](const llama_kv_cell& cell){ return !cell.is_empty(); });
// sanity check
return llama_kv_cache_slot_info(n >= n_seqs);
return n >= n_seqs;
}
// otherwise, one cell per token.
if (n_tokens > size) {
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
return llama_kv_cache_slot_info_failed;
return false;
}
uint32_t n_tested = 0;
@@ -651,7 +701,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
if (n_tested >= size) {
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
return llama_kv_cache_slot_info_failed;
return false;
}
}
@@ -668,7 +718,9 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
used += n_tokens;
return llama_kv_cache_slot_info(head, head + n_tokens);
pending.ranges.push_back({head, head + n_tokens});
return true;
}
uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
@@ -1033,6 +1085,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
return false;
}
commit();
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
// Assume that this is one contiguous block of cells

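Rollback bookkeeping now lives in the cache itself: find_slot() records every cell range it hands out in pending.ranges, commit() discards those records on success, and restore() clears the corresponding cells again on failure. A self-contained sketch of that mechanism (simplified stand-in types, not the real llama.cpp classes):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // simplified stand-in for the pending-range bookkeeping added to llama_kv_cache_unified
    struct kv_cells_sketch {
        struct cell  { int32_t  pos = -1; };
        struct range { uint32_t c0 = 0, c1 = 0; }; // cell indices, [c0, c1)

        std::vector<cell>  cells;
        std::vector<range> pending; // ranges written since the last commit
        uint32_t head = 0;
        uint32_t used = 0;

        explicit kv_cells_sketch(uint32_t size) : cells(size) {}

        // analogous to find_slot(): take n contiguous cells and remember the range
        bool take(uint32_t n, int32_t pos0) {
            if (head + n > cells.size()) {
                return false; // the real code searches for a gap and can wrap around
            }
            for (uint32_t i = 0; i < n; ++i) {
                cells[head + i].pos = pos0 + (int32_t) i;
            }
            pending.push_back({head, head + n});
            used += n;
            head += n;
            return true;
        }

        // analogous to commit(): the batch succeeded, keep the cells
        void commit() {
            pending.clear();
        }

        // analogous to restore(): the batch failed, clear every uncommitted cell
        void restore() {
            uint32_t new_head = (uint32_t) cells.size();
            for (const auto & r : pending) {
                for (uint32_t i = r.c0; i < r.c1; ++i) {
                    if (cells[i].pos >= 0) {
                        used--;
                    }
                    cells[i].pos = -1;
                }
                new_head = std::min(new_head, r.c0);
            }
            if (new_head != cells.size() && new_head < head) {
                head = new_head;
            }
            pending.clear();
        }
    };

Unlike the removed llama_kv_slot_restorer, callers no longer need to snapshot head and n themselves; the cache undoes exactly the ranges it recorded.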
src/llama-kv-cache.h

@@ -17,6 +17,9 @@ struct llama_ubatch;
struct llama_kv_cache : public llama_memory_i {
using llama_memory_i::llama_memory_i;
virtual void restore() = 0; // call if batch processing fails - restores the cache state
virtual void commit() = 0; // call after successful batch processing - clears any pending state
virtual int32_t get_n_tokens() const = 0;
virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
@@ -25,9 +28,24 @@ struct llama_kv_cache : public llama_memory_i {
bool get_can_edit() const override { return get_can_shift(); }
};
struct llama_kv_cache_guard {
llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
~llama_kv_cache_guard() {
kv->restore();
}
void commit() {
kv->commit();
}
private:
llama_kv_cache * kv;
};
struct llama_kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
llama_pos delta = 0;
int32_t src = -1; // used by recurrent state models to copy states
int32_t tail = -1;
@@ -46,17 +64,6 @@ struct llama_kv_cell {
}
};
// a structure holds information about the slot found in llama_kv_cache_find_slot
struct llama_kv_cache_slot_info {
std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
bool found = false; // the slot was found
explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
operator bool() const { return found; }
};
// ring-buffer of cached KV data
// TODO: pimpl
// TODO: add notion of max sequences
@@ -93,6 +100,9 @@ public:
void clear() override;
void defrag() override;
virtual void restore() override;
virtual void commit() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
@@ -105,10 +115,9 @@ public:
// find an empty slot of size "n_tokens" in the cache
// updates the cache head
// returns a structure holding information about the slot found
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
llama_kv_cache_slot_info find_slot(const llama_ubatch & batch);
bool find_slot(const llama_ubatch & batch);
// TODO: maybe not needed
uint32_t get_padding(const llama_cparams & cparams) const;
@@ -128,7 +137,19 @@ public:
// return true if cells have been moved
bool defrag_prepare(int32_t n_max_nodes);
// state save/load
// commit/restore cache
struct slot_range {
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
uint32_t c1 = 0;
};
// pending cell updates that are not yet committed
struct {
std::vector<slot_range> ranges;
} pending;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
@@ -183,59 +204,6 @@ private:
// using llama_kv_cache_unified::llama_kv_cache_unified;
//};
//
// kv cache restore
//
// saves the kv_cache state for future recovery.
// used to rollback llama_kv_cache_find_slot changes.
struct llama_kv_slot_restorer {
struct llama_kv_cache_state {
uint32_t head = 0;
uint32_t n = 0;
} old_state;
// for non-recurrent models only
// list of slots to restore
std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
bool do_restore = false;
llama_kv_cache_unified & cache;
explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) {
old_state.head = cache.head;
old_state.n = cache.n;
}
// saves a slot information for future restoration
void save(const llama_kv_cache_slot_info & slot) {
if (slot) {
do_restore = true;
if (slot.boundaries.first != slot.boundaries.second) {
slot_boundaries.push_back(slot.boundaries);
}
}
}
// must be explicitly called to restore the kv_cache state
// and rollback changes from all llama_kv_cache_find_slot calls
void restore() {
if (do_restore) {
cache.head = old_state.head;
cache.n = old_state.n;
if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
cache.seq_rm(-1, -1, -1);
} else {
for (auto & slot : slot_boundaries) {
cache.seq_rm(-1, slot.first, slot.second);
}
}
}
}
};
// TODO: maybe become part of the public llama_kv_cache in the future
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);