tool-call: support Command R7B (+ return tool_plan "thoughts" in API) (#11585)

* `tool-call`: support Command R7B (w/ tool_plan return)

* `tool-call`: cleaner preservation of tokens + warn when likely bad chat template override

* `tool-call`: test cleanup / handle lazy grammar triggers
Olivier Chafik 2025-02-02 09:25:38 +00:00 committed by GitHub
parent 69804487e0
commit bfcce4d693
8 changed files with 420 additions and 56 deletions
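
Before the per-file diffs, here is the user-visible API change in a nutshell: when a parsed Command R7B generation contains a `<|START_THINKING|>...<|END_THINKING|>` block, the server now surfaces it as a `tool_plan` field on the OpenAI-compatible assistant message (see `to_json_oaicompat_chat` in the server.cpp hunk below). A minimal sketch of the resulting message shape; the tool call values are hypothetical:

```cpp
#include <nlohmann/json.hpp>
using json = nlohmann::ordered_json;

// Sketch only: mirrors the `json message` built in server.cpp below.
const json example_message {
    {"content", ""}, // empty when the model only emitted tool calls
    {"tool_calls", json::array({{
        {"type", "function"},
        {"function", {
            {"name", "get_weather"},                // hypothetical tool
            {"arguments", "{\"city\": \"Paris\"}"},
        }},
        {"id", "0"},
    }})},
    {"role", "assistant"},
    // New in this commit; only present when the model emitted a thinking block:
    {"tool_plan", "I will call get_weather for Paris, then summarize."},
};
```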

common/chat.cpp

@ -16,6 +16,7 @@ std::string common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
default:
throw std::runtime_error("Unknown chat format");
}
@ -317,6 +318,79 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
}
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
common_chat_params data;
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
schemas.push_back({
{"type", "object"},
{"properties", {
{"tool_call_id", {
{"type", "string"},
// Command-R's template expects an integer string.
{"pattern", "^[0-9]{1,10}$"},
}},
{"tool_name", {
{"type", "string"},
{"const", function["name"]},
}},
{"parameters", function["parameters"]},
}},
{"required", json::array({"tool_call_id", "tool_name", "parameters"})},
});
});
auto schema = json {
{"type", "array"},
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
{"minItems", 1},
};
if (!inputs.parallel_tool_calls) {
schema["maxItems"] = 1;
}
builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
}, grammar_options);
data.grammar_triggers.push_back({"<|START_ACTION|>", /* .at_start = */ false});
data.preserved_tokens = {
"<|START_RESPONSE|>",
"<|END_RESPONSE|>",
"<|START_THINKING|>",
"<|END_THINKING|>",
"<|END_ACTION|>",
};
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
return data;
}
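With `grammar_lazy` set, decoding starts unconstrained (the model can write its `<|START_THINKING|>` plan freely); the grammar only kicks in once the `<|START_ACTION|>` trigger word is sampled, after which it forces a well-formed action array per the schema above. A sketch of a completion the triggered grammar would accept (hypothetical tool; note the `tool_call_id` matching `^[0-9]{1,10}$`):

```cpp
// Sketch: unconstrained prefix, then grammar-constrained suffix.
static const std::string accepted_completion =
    "<|START_THINKING|>I'll look up the weather first.<|END_THINKING|>" // free-form text
    "<|START_ACTION|>"                                                  // lazy trigger fires here
    "[{\"tool_call_id\": \"0\", \"tool_name\": \"get_weather\","
    " \"parameters\": {\"city\": \"Paris\"}}]"
    "<|END_ACTION|>";                                                   // closes the "root" rule
```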
static common_chat_msg common_chat_parse_command_r7b(const std::string & input) {
static std::regex response_regex("<\\|START_RESPONSE\\|>(.*?)<\\|END_RESPONSE\\|>");
static std::regex thought_action_regex("<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|><\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
std::smatch match;
common_chat_msg result;
result.role = "assistant";
if (std::regex_match(input, match, response_regex)) {
result.content = match[1].str();
} else if (std::regex_match(input, match, thought_action_regex)) {
result.tool_plan = match[1].str();
auto actions_str = match[2].str();
auto actions = json::parse(actions_str);
for (const auto & action : actions) {
result.tool_calls.push_back({
/* .name = */ action["tool_name"],
/* .arguments = */ action["parameters"].dump(),
/* .id = */ action["tool_call_id"],
});
}
} else {
LOG_ERR("Failed to parse command_r output");
result.content = input;
}
return result;
}
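To make the two branches above concrete, here is a sketch of what the parser extracts from each kind of Command R7B output (illustration only; the tool name is hypothetical):

```cpp
// Sketch: feeding typical Command R7B generations through the parser above.
void example_command_r7b_parsing() {
    // Thinking + action: tool_plan and tool_calls get populated.
    common_chat_msg msg = common_chat_parse(
        "<|START_THINKING|>I should call get_weather.<|END_THINKING|>"
        "<|START_ACTION|>[{\"tool_call_id\": \"0\", \"tool_name\": \"get_weather\","
        " \"parameters\": {\"city\": \"Paris\"}}]<|END_ACTION|>",
        COMMON_CHAT_FORMAT_COMMAND_R7B);
    // msg.tool_plan          == "I should call get_weather."
    // msg.tool_calls[0].name == "get_weather", .arguments == "{\"city\":\"Paris\"}", .id == "0"

    // Plain response: only content gets populated.
    msg = common_chat_parse("<|START_RESPONSE|>Hello!<|END_RESPONSE|>",
                            COMMON_CHAT_FORMAT_COMMAND_R7B);
    // msg.content == "Hello!"
}
```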
static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
@ -462,6 +536,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
"\"<tool▁call▁begin>function<tool▁sep>" + name + "\\n```json\\n\" " + args_rule + " \"```<tool▁call▁end>\""));
});
data.grammar_triggers.push_back({"<tool▁calls▁begin>", /* .at_start = */ false});
data.preserved_tokens = {
"<tool▁sep>",
"<tool▁call▁end>",
};
builder.add_rule("root", "\"<tool▁calls▁begin>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
}, grammar_options);
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
@ -704,8 +782,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
data.preserved_tokens = { "</tool_call>" };
}, grammar_options);
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
@ -822,6 +899,9 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
if (src.find("[TOOL_CALLS]") != std::string::npos) {
return common_chat_params_init_mistral_nemo(tmpl, inputs);
}
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos) {
return common_chat_params_init_command_r7b(tmpl, inputs);
}
return common_chat_params_init_generic(tmpl, inputs);
}
@ -855,6 +935,8 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
return common_chat_parse_hermes_2_pro(input);
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
return common_chat_parse_firefunction_v2(input);
case COMMON_CHAT_FORMAT_COMMAND_R7B:
return common_chat_parse_command_r7b(input);
default:
throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
}

common/chat.hpp

@ -32,6 +32,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
@ -42,6 +43,7 @@ struct common_chat_params {
std::string grammar;
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
};

common/common.h

@ -4,6 +4,7 @@
#include "llama-cpp.h"
#include <set>
#include <string>
#include <vector>
#include <sstream>
@ -163,6 +164,7 @@ struct common_params_sampling {
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to trigger lazy grammar
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
std::set<llama_token> preserved_tokens;
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
@ -621,6 +623,7 @@ struct common_chat_msg {
std::string role;
std::string content;
std::vector<common_tool_call> tool_calls;
std::string tool_plan = "";
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid

examples/server/README.md

@ -1128,6 +1128,7 @@ curl http://localhost:8080/v1/chat/completions \
- Hermes 2/3, Qwen 2.5
- Mistral Nemo
- Firefunction v2
- Command R7B
- DeepSeek R1 (WIP / seems reluctant to call any tools?)
<details>
@ -1202,21 +1203,28 @@ curl http://localhost:8080/v1/chat/completions \
```shell
# Native support:
llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/Llama-3.2-3B-Instruct-GGUF:Q6_K
llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L
llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M
# Native support requires the right template for these GGUFs:
llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use )
llama-server --jinja -fa -hf bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M \
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use )
llama-server --jinja -fa -hf bartowski/firefunction-v2-GGUF -hff firefunction-v2-IQ1_M.gguf \
--chat-template-file <( python scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 tool_use )
llama-server --jinja -fa -hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L \
--chat-template-file <( python scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use )
# Generic format support
llama-server --jinja -fa -hf bartowski/phi-4-GGUF:Q4_0
llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q8_0
llama-server --jinja -fa -hf bartowski/c4ai-command-r-v01-GGUF:Q2_K
```
- Test in CLI:

examples/server/server.cpp

@ -131,6 +131,11 @@ struct slot_params {
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
}
std::vector<std::string> grammar_trigger_words;
for (const auto & trigger : sampling.grammar_trigger_words) {
grammar_trigger_words.push_back(trigger.word);
}
return json {
{"n_predict", n_predict}, // Server configured n_predict
{"seed", sampling.seed},
@ -165,8 +170,9 @@ struct slot_params {
{"n_probs", sampling.n_probs},
{"min_keep", sampling.min_keep},
{"grammar", sampling.grammar},
{"grammar_trigger_words", grammar_trigger_words},
{"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
{"preserved_tokens", sampling.preserved_tokens},
{"samplers", samplers},
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},
@ -363,12 +369,26 @@ struct server_task {
if (ids.size() == 1) {
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
continue;
}
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
params.sampling.grammar_trigger_words.push_back(trigger);
}
}
const auto preserved_tokens = data.find("preserved_tokens");
if (preserved_tokens != data.end()) {
for (const auto & t : *preserved_tokens) {
auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
LOG_DBG("Preserved token: %d\n", ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
} else {
// This can happen when a tool-call style meant for a model with special tokens to preserve is used with a model that lacks them.
LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
}
}
}
if (params.sampling.grammar_lazy) {
GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0);
}
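The single-token check above is what powers the new "wrong chat template override?" warning: a marker like `<|END_ACTION|>` only tokenizes to one id when the model actually defines it as a special token. A minimal sketch of that decision, factored into a hypothetical helper (same `common_tokenize` call as the loop above):

```cpp
// Hypothetical helper: true iff `marker` is a single special token in this
// model's vocab and can therefore be preserved (and rendered) as one token.
static bool can_preserve_as_single_token(const llama_vocab * vocab, const std::string & marker) {
    const auto ids = common_tokenize(vocab, marker, /* add_special= */ false, /* parse_special= */ true);
    // On a model lacking the special token (e.g. a chat template override meant
    // for a different model), the marker falls apart into several text pieces.
    return ids.size() == 1;
}
```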
@ -695,19 +715,19 @@ struct server_task_result_cmpl_final : server_task_result {
json to_json_oaicompat_chat() {
std::string finish_reason = "length";
common_chat_msg msg;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
LOG_DBG("Parsing chat message: %s\n", content.c_str());
msg = common_chat_parse(content, oaicompat_chat_format);
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
} else {
msg.content = content;
}
json tool_calls;
if (!msg.tool_calls.empty()) {
tool_calls = json::array();
for (const auto & tc : msg.tool_calls) {
tool_calls.push_back({
{"type", "function"},
{"function", {
@ -719,14 +739,19 @@ struct server_task_result_cmpl_final : server_task_result {
}
}
json message {
{"content", msg.content},
{"tool_calls", tool_calls},
{"role", "assistant"},
};
if (!msg.tool_plan.empty()) {
message["tool_plan"] = msg.tool_plan;
}
json choice {
{"finish_reason", finish_reason},
{"index", 0},
{"message", json {
{"content", message.content},
{"tool_calls", tool_calls},
{"role", "assistant"},
}},
{"message", message},
};
if (!stream && probs_output.size() > 0) {
@ -2833,8 +2858,7 @@ struct server_context {
server_slot * slot_batched = nullptr;
auto accept_special_token = [&](server_slot & slot, llama_token token) {
return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
};
// first, add sampled tokens from any ongoing sequences

examples/server/utils.hpp

@ -662,6 +662,7 @@ static json oaicompat_completion_params_parse(
});
}
llama_params["grammar_triggers"] = grammar_triggers;
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
for (const auto & stop : chat_params.additional_stops) {
llama_params["stop"].push_back(stop);
}

models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja

@ -0,0 +1,156 @@
{{ bos_token }}{%- macro document_turn(documents) -%}
{# format documents into chat turn #}
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[
{"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}}
]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[
{
"tool_call_id": "0",
"results": {
{% for doc in documents %}
"{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %},
{% endif %}
{% endfor %}
},
"is_error": null
}
]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}
{%- macro tool_call_id_to_int(messages, tool_call_id) %}
{%- set counter = namespace(value=0) %}
{%- set tool_call_id_seen = namespace(value=false) %}
{%- for msg in messages %}
{%- if msg.tool_calls %}
{%- for tool_call in msg.tool_calls %}
{%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}
{{ counter.value }}
{%- set tool_call_id_seen.value = true %}
{%- endif %}
{%- set counter.value = counter.value + 1 %}
{%- endfor %}
{%- endif %}
{%- endfor %}
{%- endmacro %}
{%- macro format_tool_message(messages, tool_msg) -%}
{# format tool message #}
{
"tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}",
"results": {
"0": {{ tool_msg.content|tojson }}
},
"is_error": null
}
{%- endmacro -%}
{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %}
{%- set tool_idx = namespace(value=0) %}
{%- set tool_ids_seen = namespace(value=[]) %}
{%- set sent_documents = namespace(value=false) %}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble
You are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.
Your information cutoff date is June 2024.
You have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.
{% if tools or documents %}
You have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user's requests.
## Tool Use
Think about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.
0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.
You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.
NOTE: You MUST skip this step when you are directly responding to the user's request without using any tools.
Then carry out your plan by repeatedly executing the following steps.
1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields.
When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.
2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.
Every tool call produces a list of results (when a tool call produces no result or a single result, it'll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id".
3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.
You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.
NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.
You can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it's time to finally respond to the user.
4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user's last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.
{% if enable_citations %}
## Grounding
Importantly, note that "Reflection" and "Response" above can be grounded.
Grounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "<co>" and "</co>" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "<co>span</co: 0:[1,2],1:[0]>" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1".
{% endif %}
## Available Tools
Here is the list of tools that you have available to you.
You can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.
Each tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema).
```json
[
{% if documents %}
{"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %}
{% endif %}
{% for tool in tools %}
{"name": "{{ tool['function']['name'] }}", "description": "{{tool['function']['description']}}", "parameters": {{ tool['function']['parameters']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %}
{% endfor %}
]
```
{% endif %}
# Default Preamble
The following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.
- Your name is Command.
- You are a large language model built by Cohere.
- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.
- If the input is ambiguous, ask clarifying follow-up questions.
- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).
- Use LaTeX to generate mathematical notation for complex equations.
- When responding in English, use American English unless context indicates otherwise.
- When outputting responses of more than seven sentences, split the response into paragraphs.
- Prefer the active voice.
- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.
- Use gender-neutral pronouns for unspecified persons.
- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.
- Use the third person when asked to write a summary.
- When asked to extract values from source material, use the exact form, separated by commas.
- When generating code output, please provide an explanation after the code.
- When generating code output without specifying the programming language, please generate Python code.
- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.
{%- if developer_preamble %}
# Developer Preamble
The following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.
{{ developer_preamble }}
{%- endif -%}
<|END_OF_TURN_TOKEN|>
{%- for message in messages %}
{%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>
{%- elif message.role|lower == 'user' %}
<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}
{%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %}
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[
{% for tc in message.tool_calls %}
{"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc['function']['name'] }}", "parameters": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %}
{% set tool_idx.value = tool_idx.value + 1 %}
{% endfor %}
]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}
{% elif message.role|lower == 'tool' and message.tool_call_id not in tool_ids_seen.value %}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[
{{ format_tool_message(messages, message) }}
{%- for msg in messages[loop.index0 + 1:] %}
{%- if msg.role|lower == 'tool' %},
{{ format_tool_message(messages, msg) }}
{%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}
{%- else %}
{%- break %}
{%- endif %}
{%- endfor %}
]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>
{%- endif %}
{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
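For reference, the assistant tool-call branch of this template renders to the exact marker layout that the grammar and parser in common/chat.cpp expect; a sketch of one such turn as a C++ literal (values taken from the test expectations below):

```cpp
// Sketch: output of the assistant turn above for a single tool call.
static const std::string rendered_assistant_turn =
    "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
    "<|START_THINKING|>I'm not so sure<|END_THINKING|>"
    "<|START_ACTION|>[\n"
    "    {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
    "]<|END_ACTION|><|END_OF_TURN_TOKEN|>";
```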

tests/test-chat.cpp

@ -22,9 +22,13 @@ static common_chat_msg msg_from_json(const json & message) {
"assistant",
"",
{},
/* .tool_plan = */ "",
};
if (message.contains("content") && !message.at("content").is_null()) {
ret.content = message.at("content").get<std::string>();
ret.content = message.at("content");
}
if (message.contains("tool_plan")) {
ret.tool_plan = message.at("tool_plan");
}
auto has_tool_calls = message.contains("tool_calls");
if (has_tool_calls) {
@ -171,8 +175,7 @@ const json llama_3_1_tools = { special_function_tool, code_interpreter_too
struct delta_data {
std::string delta;
common_chat_params params;
};
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
@ -214,7 +217,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
break;
}
}
return { delta, params_full };
}
/*
@ -224,7 +227,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
*/
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
bool expect_grammar_triggered = true) {
common_chat_msg expected_msg = msg_from_json(test_message);
auto user_message = json{
@ -238,45 +241,110 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
assert_equals(expected_delta, data.delta);
}
if (expect_grammar_triggered) {
const auto msg = common_chat_parse(data.delta, data.params.format);
assert_msg_equals(expected_msg, msg);
}
if (!expected_msg.tool_calls.empty()) {
GGML_ASSERT(!data.params.grammar.empty());
}
if (!data.params.grammar.empty()) {
auto grammar = build_grammar(data.params.grammar);
if (!grammar) {
throw std::runtime_error("Failed to build grammar");
}
auto earliest_trigger_pos = std::string::npos;
auto constrained = data.delta;
for (const auto & trigger : data.params.grammar_triggers) {
auto pos = constrained.find(trigger.word);
if (pos == std::string::npos) {
continue;
}
if (pos > 0 && trigger.at_start) {
fprintf(stderr, "Trigger %s not at start of message, skipping:\n\n%s\n\n", trigger.word.c_str(), constrained.c_str());
continue;
}
if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
earliest_trigger_pos = pos;
}
}
auto grammar_triggered = false;
if (earliest_trigger_pos != std::string::npos) {
constrained = constrained.substr(earliest_trigger_pos);
grammar_triggered = true;
}
if (data.params.grammar_lazy) {
assert_equals(expect_grammar_triggered, grammar_triggered);
}
if (grammar_triggered && !match_string(constrained, grammar.get())) {
throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
"\n\nGrammar: " + data.params.grammar);
}
}
}
}
static void test_template_output_parsers() {
json text_message {
{ "role", "assistant" },
{ "content", "Hello, world!" },
};
json tool_calls = json::array({{
{ "type", "function" },
{ "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
}});
json tool_call_message {
{ "role", "assistant"},
{ "content", {}},
{ "tool_calls", tool_calls }
};
json tool_call_message_with_id {
{ "role", "assistant"},
{ "content", {}},
{ "tool_calls", {
{
{ "type", "function" },
{ "function", {
{ "name", "special_function" },
{ "arguments", "{\"arg1\": 1}" },
}},
{"id", "123456789"},
},
}},
};
json tool_call_plan_message_with_idx {
{ "role", "assistant"},
{ "content", {}},
{ "tool_plan", "I'm not so sure"},
{ "tool_calls", {
{
{ "type", "function" },
{ "function", {
{ "name", "special_function" },
{ "arguments", "{\"arg1\": 1}" },
}},
// Index of the tool call in the tool_calls array
{"id", "0"},
},
}},
};
auto python_tool_call_message = json{
{ "role", "assistant" },
@ -322,6 +390,27 @@ static void test_template_output_parsers() {
inputs_tools_builtin.tools = json::array();
inputs_tools_builtin.tools.push_back(python_tool);
{
// Not supported yet
const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "<s>", "</s>");
assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
}
{
const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, tool_call_plan_message_with_idx, tools,
"<|START_THINKING|>I'm not so sure<|END_THINKING|>"
"<|START_ACTION|>[\n"
" {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
"]<|END_ACTION|>");
test_template(tmpl, end_tokens, text_message, tools,
"<|START_RESPONSE|>Hello, world!<|END_RESPONSE|>",
/* expect_grammar_triggered= */ false);
}
{
const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens{ "<end_of_turn>" };
@ -362,11 +451,10 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
test_template(
tmpl, end_tokens, tool_call_message_with_id, tools,
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
/* skip_grammar_test= */ true);
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
}
{
const common_chat_template tmpl(
@ -388,7 +476,7 @@ static void test_template_output_parsers() {
inputs_tools)
.format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<tool_call>\n"
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
@ -413,7 +501,7 @@ static void test_template_output_parsers() {
inputs_tools_builtin)
.format);
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
test_template(tmpl, end_tokens, python_tool_call_message, tools,
@ -428,7 +516,7 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
}
@ -440,7 +528,7 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<function=special_function>{\"arg1\": 1}</function>");
}
@ -455,7 +543,7 @@ static void test_template_output_parsers() {
test_template(tmpl, end_tokens, text_message, {},
"all\n"
"Hello, world!",
/* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"special_function\n"
"{\"arg1\": 1}");
@ -467,7 +555,7 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
}
@ -478,7 +566,7 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>special_function\n"
"```json\n"