mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-04-16 03:26:08 +00:00
examples : fix add_special conditions (#11311)
This commit is contained in:
parent
90d987b105
commit
9f7add1cde
@ -729,10 +729,12 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) {
|
||||
|
||||
// Function to tokenize the prompt
|
||||
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
|
||||
std::vector<llama_token> & prompt_tokens) {
|
||||
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
|
||||
std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
|
||||
const bool is_first = llama_get_kv_cache_used_cells(llama_data.context.get()) == 0;
|
||||
|
||||
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
|
||||
prompt_tokens.resize(n_prompt_tokens);
|
||||
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
|
||||
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
|
||||
true) < 0) {
|
||||
printe("failed to tokenize the prompt\n");
|
||||
return -1;
|
||||
@ -778,7 +780,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
|
||||
const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get());
|
||||
|
||||
std::vector<llama_token> tokens;
|
||||
if (tokenize_prompt(vocab, prompt, tokens) < 0) {
|
||||
if (tokenize_prompt(vocab, prompt, tokens, llama_data) < 0) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
@ -95,13 +95,15 @@ int main(int argc, char ** argv) {
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
|
||||
|
||||
// helper function to evaluate a prompt and generate a response
|
||||
auto generate = [&](const std::string & prompt, bool is_first) {
|
||||
auto generate = [&](const std::string & prompt) {
|
||||
std::string response;
|
||||
|
||||
const bool is_first = llama_get_kv_cache_used_cells(ctx) == 0;
|
||||
|
||||
// tokenize the prompt
|
||||
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
|
||||
std::vector<llama_token> prompt_tokens(n_prompt_tokens);
|
||||
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) {
|
||||
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0) {
|
||||
GGML_ABORT("failed to tokenize the prompt\n");
|
||||
}
|
||||
|
||||
@ -180,7 +182,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// generate a response
|
||||
printf("\033[33m");
|
||||
std::string response = generate(prompt, prev_len == 0);
|
||||
std::string response = generate(prompt);
|
||||
printf("\n\033[0m");
|
||||
|
||||
// add the response to the messages
|
||||
|
Loading…
x
Reference in New Issue
Block a user