diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index e0e6da631..80698518e 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -106,6 +106,8 @@ int main(int argc, char ** argv) { common_params params; + params.n_predict = 128; + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) { return 1; } diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 3479a8cca..7d067afbe 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1201,33 +1201,7 @@ int llama_context::decode(llama_batch & inp_batch) { const int64_t n_tokens_all = batch.n_tokens; const int64_t n_embd = hparams.n_embd; - // TODO: remove this stuff - class batch_guard { - public: - batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) { - } - - ~batch_guard() { - if (!is_done) { - kv_slot_restorer.restore(); - } - } - - void done() { - is_done = true; - } - - void save(const llama_kv_cache_slot_info & slot_info) { - kv_slot_restorer.save(slot_info); - } - - private: - bool is_done = false; - - llama_kv_slot_restorer kv_slot_restorer; - }; - - batch_guard bg(*kv_self); + llama_kv_cache_guard kv_guard(kv_self.get()); GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT @@ -1280,6 +1254,9 @@ int llama_context::decode(llama_batch & inp_batch) { return -2; }; + // handle any pending defrags/shifts + kv_self_update(); + int64_t n_outputs_prev = 0; while (sbatch.n_tokens > 0) { @@ -1319,22 +1296,12 @@ int llama_context::decode(llama_batch & inp_batch) { // find KV slot { - kv_self_update(); + if (!kv_self->find_slot(ubatch)) { + LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens); - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) { - kv_self->head = 0; + return 1; } - const auto slot_info = kv_self->find_slot(ubatch); - if (!slot_info) { - LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__); - return -3; - } - - bg.save(slot_info); - if (!kv_self->recurrent) { // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears @@ -1371,16 +1338,6 @@ int llama_context::decode(llama_batch & inp_batch) { } } - // update the kv ring buffer - { - kv_self->head += ubatch.n_tokens; - - // Ensure kv cache head points to a valid index. - if (kv_self->head >= kv_self->size) { - kv_self->head = 0; - } - } - // plot the computation graph in dot format (for debugging purposes) //if (n_past%100 == 0) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); @@ -1467,7 +1424,7 @@ int llama_context::decode(llama_batch & inp_batch) { } // finalize the batch processing - bg.done(); + kv_guard.commit(); // set output mappings { diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 14c8933b4..7ba546c10 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -11,8 +11,6 @@ #include #include -static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false}; - llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) { } @@ -206,6 +204,8 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos return false; } } + + return true; } for (uint32_t i = 0; i < size; ++i) { @@ -446,16 +446,66 @@ void llama_kv_cache_unified::defrag() { } } +void llama_kv_cache_unified::restore() { + if (pending.ranges.empty()) { + return; + } + + // TODO: tmp - move to llama_kv_cache_recurrent + if (recurrent) { + seq_rm(-1, -1, -1); + return; + } + + uint32_t new_head = size; + + for (auto & range : pending.ranges) { + for (uint32_t i = range.c0; i < range.c1; ++i) { + cells[i].seq_id.clear(); + + // keep count of the number of used cells + if (cells[i].pos >= 0) { + used--; + } + + cells[i].pos = -1; + cells[i].src = -1; + } + + new_head = std::min(new_head, range.c0); + } + + if (new_head != size && new_head < head) { + head = new_head; + } +} + +void llama_kv_cache_unified::commit() { + if (pending.ranges.empty()) { + LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n", + __func__, "https://github.com/ggml-org/llama.cpp/pull/12695"); + return; + } + + pending.ranges.clear(); +} + bool llama_kv_cache_unified::get_can_shift() const { return can_shift; } -llama_kv_cache_slot_info llama_kv_cache_unified::find_slot( +bool llama_kv_cache_unified::find_slot( const llama_ubatch & ubatch) { const uint32_t n_tokens = ubatch.n_tokens; const uint32_t n_seqs = ubatch.n_seqs; const uint32_t n_seq_tokens = ubatch.n_seq_tokens; + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head > used + 2*ubatch.n_tokens) { + head = 0; + } + if (recurrent) { // For recurrent state architectures (like Mamba or RWKV), // each cache cell can store the state for a whole sequence. @@ -477,7 +527,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot( // too big seq_id // TODO: would it be possible to resize the cache instead? LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size); - return llama_kv_cache_slot_info_failed; + return false; } if (j > 0) { llama_kv_cell & seq = cells[seq_id]; @@ -616,14 +666,14 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot( [](const llama_kv_cell& cell){ return !cell.is_empty(); }); // sanity check - return llama_kv_cache_slot_info(n >= n_seqs); + return n >= n_seqs; } // otherwise, one cell per token. if (n_tokens > size) { LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size); - return llama_kv_cache_slot_info_failed; + return false; } uint32_t n_tested = 0; @@ -651,7 +701,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot( if (n_tested >= size) { //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return llama_kv_cache_slot_info_failed; + return false; } } @@ -668,7 +718,9 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot( used += n_tokens; - return llama_kv_cache_slot_info(head, head + n_tokens); + pending.ranges.push_back({head, head + n_tokens}); + + return true; } uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const { @@ -1033,6 +1085,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return false; } + commit(); // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) // Assume that this is one contiguous block of cells diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 0a7ff8a4e..ff0ba3540 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -17,6 +17,9 @@ struct llama_ubatch; struct llama_kv_cache : public llama_memory_i { using llama_memory_i::llama_memory_i; + virtual void restore() = 0; // call if batch processing fails - restores the cache state + virtual void commit() = 0; // call after successful batch processing - clears any pending state + virtual int32_t get_n_tokens() const = 0; virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache @@ -25,9 +28,24 @@ struct llama_kv_cache : public llama_memory_i { bool get_can_edit() const override { return get_can_shift(); } }; +struct llama_kv_cache_guard { + llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {} + + ~llama_kv_cache_guard() { + kv->restore(); + } + + void commit() { + kv->commit(); + } + +private: + llama_kv_cache * kv; +}; + struct llama_kv_cell { llama_pos pos = -1; - llama_pos delta = 0; + llama_pos delta = 0; int32_t src = -1; // used by recurrent state models to copy states int32_t tail = -1; @@ -46,17 +64,6 @@ struct llama_kv_cell { } }; -// a structure holds information about the slot found in llama_kv_cache_find_slot -struct llama_kv_cache_slot_info { - std::pair boundaries; // slot boundaries [begin, end) - bool found = false; // the slot was found - - explicit llama_kv_cache_slot_info(bool found_) : found{found_} {} - llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {} - - operator bool() const { return found; } -}; - // ring-buffer of cached KV data // TODO: pimpl // TODO: add notion of max sequences @@ -93,6 +100,9 @@ public: void clear() override; void defrag() override; + virtual void restore() override; + virtual void commit() override; + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; void seq_keep(llama_seq_id seq_id) override; @@ -105,10 +115,9 @@ public: // find an empty slot of size "n_tokens" in the cache // updates the cache head - // returns a structure holding information about the slot found // Note: On success, it's important that cache.head points // to the first cell of the slot. - llama_kv_cache_slot_info find_slot(const llama_ubatch & batch); + bool find_slot(const llama_ubatch & batch); // TODO: maybe not needed uint32_t get_padding(const llama_cparams & cparams) const; @@ -128,7 +137,19 @@ public: // return true if cells have been moved bool defrag_prepare(int32_t n_max_nodes); - // state save/load + // commit/restore cache + + struct slot_range { + uint32_t c0 = 0; // note: these are cell indices, not sequence positions + uint32_t c1 = 0; + }; + + // pending cell updates that are not yet committed + struct { + std::vector ranges; + } pending; + + // state write/load void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const; void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1); @@ -183,59 +204,6 @@ private: // using llama_kv_cache_unified::llama_kv_cache_unified; //}; -// -// kv cache restore -// - -// saves the kv_cache state for future recovery. -// used to rollback llama_kv_cache_find_slot changes. -struct llama_kv_slot_restorer { - struct llama_kv_cache_state { - uint32_t head = 0; - uint32_t n = 0; - } old_state; - - // for non-recurrent models only - // list of slots to restore - std::vector> slot_boundaries; - - bool do_restore = false; - - llama_kv_cache_unified & cache; - - explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) { - old_state.head = cache.head; - old_state.n = cache.n; - } - - // saves a slot information for future restoration - void save(const llama_kv_cache_slot_info & slot) { - if (slot) { - do_restore = true; - if (slot.boundaries.first != slot.boundaries.second) { - slot_boundaries.push_back(slot.boundaries); - } - } - } - - // must be explicitly called to restore the kv_cache state - // and rollback changes from all llama_kv_cache_find_slot calls - void restore() { - if (do_restore) { - cache.head = old_state.head; - cache.n = old_state.n; - - if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased - cache.seq_rm(-1, -1, -1); - } else { - for (auto & slot : slot_boundaries) { - cache.seq_rm(-1, slot.first, slot.second); - } - } - } - } -}; - // TODO: maybe become part of the public llama_kv_cache in the future int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);