From 4e39a3c332f84b890e91353ec448502cb8373a6f Mon Sep 17 00:00:00 2001 From: Olivier Chafik <ochafik@users.noreply.github.com> Date: Mon, 10 Mar 2025 10:59:03 +0000 Subject: [PATCH] `server`: extract <think> tags from qwq outputs (#12297) * extract <think> tags from qwq outputs * const for all static regexes in chat.cpp --- common/chat.cpp | 282 +++++++++++++++++++++++--------------------- common/chat.h | 1 + tests/test-chat.cpp | 13 ++ 3 files changed, 162 insertions(+), 134 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index 1b3f286af..62ca26ad7 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -445,6 +445,7 @@ std::string common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2"; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1"; case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro"; + case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING: return "Hermes 2 Pro (extract reasoning)"; case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B"; case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)"; default: @@ -878,9 +879,9 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ return data; } static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) { - static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)"); - static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>"); - static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>"); + static const std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)"); + static const std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>"); + static const std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>"); std::smatch match; @@ -1012,10 +1013,10 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com } static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) { // TODO: tighten & simplify the parser, don't accept leading text context. - static std::regex function_regex( + static const std::regex function_regex( "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); - static std::regex close_regex("\\}\\s*"); - static std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)"); + static const std::regex close_regex("\\}\\s*"); + static const std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)"); if (with_builtin_tools) { std::smatch match; @@ -1105,34 +1106,42 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1; return data; } -static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) { - static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n"); - static std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); - static std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)"); - static std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>"); - common_chat_msg msg; - msg.role = "assistant"; +static common_chat_msg handle_think_tag_prelude(const std::string & input, bool extract_reasoning, const std::function<common_chat_msg(const std::string &)> & rest_parser) { std::smatch match; + static const std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)"); if (std::regex_match(input, match, reasoning_content_regex)) { - std::string rest; + auto rest = match[3].str(); + auto msg = rest_parser(rest); + auto reasoning_content = string_strip(match[2].str()); if (extract_reasoning) { - msg.reasoning_content = string_strip(match[2].str()); - } else { - msg.content = match[1].str(); + msg.reasoning_content = reasoning_content; + } else if (!reasoning_content.empty()) { + std::ostringstream content; + content << "<think>" << reasoning_content << "</think>" << msg.content; + msg.content = content.str(); } - rest = match[3].str(); + return msg; + } + return rest_parser(input); +} +static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) { + return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) { + static const std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n"); + static const std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); + static const std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>"); - if (std::regex_search(rest, match, tool_calls_regex)) { + common_chat_msg msg; + msg.role = "assistant"; + std::smatch match; + if (std::regex_search(input, match, tool_calls_regex)) { auto tool_calls = match[1].str(); auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex); msg.tool_calls = std::move(msg2.tool_calls); } else { - msg.content += std::string(rest.begin() + rest.find_first_not_of(" \r\n"), rest.end()); + msg.content = input; } - } else { - msg.content = input; - } - return msg; + return msg; + }); } static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { @@ -1237,8 +1246,8 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ } static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) { - static std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)"); - static std::regex close_regex(R"($|(?=>>>))"); + static const std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)"); + static const std::regex close_regex(R"($|(?=>>>))"); std::string content; auto it = input.begin(); @@ -1327,7 +1336,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con } static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) { // This version of Functionary still supports the llama 3.1 tool call format for the python tool. - static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); + static const std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); std::smatch match; if (std::regex_search(input, match, python_tag_regex)) { auto code = match[1].str(); @@ -1341,8 +1350,8 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s }); return msg; } - static std::regex function_regex(R"(<function=(\w+)>)"); - static std::regex close_regex(R"(</function>)"); + static const std::regex function_regex(R"(<function=(\w+)>)"); + static const std::regex close_regex(R"(</function>)"); // TODO: tighten & simplify. return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); } @@ -1409,6 +1418,8 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat "(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?\\s*\\{\\s*\"", //name\"\\s*:\\s*\"" + escaped_name + "\"", }); data.preserved_tokens = { + "<think>", + "</think>", "<tool_call>", "</tool_call>", "<function", @@ -1429,122 +1440,123 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat }); data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); - data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO; + data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING : COMMON_CHAT_FORMAT_HERMES_2_PRO; return data; } -static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input) { - const static std::regex open_regex( - "(?:" - "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) - "(<tool_call>" // match 2 (open_tag) - "|<function_call>" - "|<tool>" - "|<tools>" - "|<response>" - "|<json>" - "|<xml>" - "|<JSON>" - ")?" - "(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*)" // match 3 (named tool call + rest) - ")" - "|" - "(?:<function=([^>]+)>" // match 4 (function name) - "|<function name=\"([^\"]+)\">)" // match 5 (function name again) - "([\\s\\S]*)" // match 6 (function arguments + rest)})" - ); +static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input, bool extract_reasoning) { + return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) { + static const std::regex open_regex( + "(?:" + "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) + "(<tool_call>" // match 2 (open_tag) + "|<function_call>" + "|<tool>" + "|<tools>" + "|<response>" + "|<json>" + "|<xml>" + "|<JSON>" + ")?" + "(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*)" // match 3 (named tool call + rest) + ")" + "|" + "(?:<function=([^>]+)>" // match 4 (function name) + "|<function name=\"([^\"]+)\">)" // match 5 (function name again) + "([\\s\\S]*)" // match 6 (function arguments + rest)})" + ); - try { + try { + common_chat_msg msg; + msg.role = "assistant"; - common_chat_msg msg; - msg.role = "assistant"; + std::string::const_iterator it = input.begin(); + const std::string::const_iterator end = input.end(); + std::smatch match; - std::string::const_iterator it = input.begin(); - const std::string::const_iterator end = input.end(); - std::smatch match; + while (it != end) { + if (std::regex_search(it, end, match, open_regex)) { + // Add content before the match + msg.content += std::string(it, match[0].first); - while (it != end) { - if (std::regex_search(it, end, match, open_regex)) { - // Add content before the match - msg.content += std::string(it, match[0].first); + auto block_start = match[1].str(); + std::string block_end = block_start.empty() ? "" : "```"; - auto block_start = match[1].str(); - std::string block_end = block_start.empty() ? "" : "```"; + auto open_tag = match[2].str(); + std::string close_tag; - auto open_tag = match[2].str(); - std::string close_tag; + if (match[3].matched) { + close_tag = open_tag.empty() ? "" : "</" + open_tag.substr(1); + auto json_it = match[3].first; + json tool_call; + if (parse_json(json_it, end, tool_call) && tool_call.contains("name") && tool_call.contains("arguments")) { - if (match[3].matched) { - close_tag = open_tag.empty() ? "" : "</" + open_tag.substr(1); - auto json_it = match[3].first; - json tool_call; - if (parse_json(json_it, end, tool_call) && tool_call.contains("name") && tool_call.contains("arguments")) { + msg.tool_calls.emplace_back(process_tool_call(tool_call)); + it = json_it; // Move iterator past parsed JSON - msg.tool_calls.emplace_back(process_tool_call(tool_call)); - it = json_it; // Move iterator past parsed JSON - - // Handle close tags - consume_spaces(it, end); - if (!close_tag.empty() && !parse_literal(it, end, close_tag)) { - throw std::runtime_error("Failed to parse closing tag"); + // Handle close tags + consume_spaces(it, end); + if (!close_tag.empty() && !parse_literal(it, end, close_tag)) { + throw std::runtime_error("Failed to parse closing tag"); + } + consume_spaces(it, end); + if (!block_end.empty() && !parse_literal(it, end, block_end)) { + throw std::runtime_error("Failed to parse block end"); + } + consume_spaces(it, end); + } else { + // Not a valid tool call, treat as content + msg.content += std::string(match[0].first, match[0].second); + it = match[0].second; } - consume_spaces(it, end); - if (!block_end.empty() && !parse_literal(it, end, block_end)) { - throw std::runtime_error("Failed to parse block end"); - } - consume_spaces(it, end); } else { - // Not a valid tool call, treat as content - msg.content += std::string(match[0].first, match[0].second); - it = match[0].second; + auto function_name = match[4].str(); + if (function_name.empty()) { + function_name = match[5].str(); + } + GGML_ASSERT(!function_name.empty()); + + close_tag = "</function>"; + // Start parsing from after the opening tags + auto json_it = match[6].first; + json arguments; + if (parse_json(json_it, end, arguments)) { + msg.tool_calls.emplace_back(process_tool_call({ + {"name", function_name}, + {"arguments", arguments}, + })); + it = json_it; // Move iterator past parsed JSON + + // Handle close tags + consume_spaces(it, end); + if (!close_tag.empty() && !parse_literal(it, end, close_tag)) { + throw std::runtime_error("Failed to parse closing tag"); + } + consume_spaces(it, end); + if (!block_end.empty() && !parse_literal(it, end, block_end)) { + throw std::runtime_error("Failed to parse block end"); + } + consume_spaces(it, end); + } else { + // Not a valid tool call, treat as content + msg.content += std::string(match[0].first, match[0].second); + it = match[0].second; + } } } else { - auto function_name = match[4].str(); - if (function_name.empty()) { - function_name = match[5].str(); - } - GGML_ASSERT(!function_name.empty()); - - close_tag = "</function>"; - // Start parsing from after the opening tags - auto json_it = match[6].first; - json arguments; - if (parse_json(json_it, end, arguments)) { - msg.tool_calls.emplace_back(process_tool_call({ - {"name", function_name}, - {"arguments", arguments}, - })); - it = json_it; // Move iterator past parsed JSON - - // Handle close tags - consume_spaces(it, end); - if (!close_tag.empty() && !parse_literal(it, end, close_tag)) { - throw std::runtime_error("Failed to parse closing tag"); - } - consume_spaces(it, end); - if (!block_end.empty() && !parse_literal(it, end, block_end)) { - throw std::runtime_error("Failed to parse block end"); - } - consume_spaces(it, end); - } else { - // Not a valid tool call, treat as content - msg.content += std::string(match[0].first, match[0].second); - it = match[0].second; - } + // Add remaining content + msg.content += std::string(it, end); + break; } - } else { - // Add remaining content - msg.content += std::string(it, end); - break; } + return msg; + } catch (const std::exception & e) { + LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what()); + common_chat_msg msg; + msg.role = "assistant"; + msg.content = input; + return msg; } - return msg; - } catch (const std::exception & e) { - LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what()); - common_chat_msg msg; - msg.role = "assistant"; - msg.content = input; - return msg; - } + }); } static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { @@ -1609,6 +1621,11 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_command_r7b(tmpl, params); } + // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools) + if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) { + return common_chat_params_init_hermes_2_pro(tmpl, params); + } + // Use generic handler when mixing tools + JSON schema. // TODO: support that mix in handlers below. if ((params.tools.is_array() && params.json_schema.is_object())) { @@ -1630,11 +1647,6 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_without_tools(tmpl, params); } - // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools) - if (src.find("<tool_call>") != std::string::npos) { - return common_chat_params_init_hermes_2_pro(tmpl, params); - } - // Functionary v3.1 (w/ tools) if (src.find("<|start_header_id|>") != std::string::npos && src.find("<function=") != std::string::npos) { @@ -1752,7 +1764,9 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return common_chat_parse_functionary_v3_1_llama_3_1(input); case COMMON_CHAT_FORMAT_HERMES_2_PRO: - return common_chat_parse_hermes_2_pro(input); + return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ false); + case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING: + return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ true); case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return common_chat_parse_firefunction_v2(input); case COMMON_CHAT_FORMAT_COMMAND_R7B: diff --git a/common/chat.h b/common/chat.h index e77bef82b..9aad84e88 100644 --- a/common/chat.h +++ b/common/chat.h @@ -53,6 +53,7 @@ enum common_chat_format { COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, COMMON_CHAT_FORMAT_HERMES_2_PRO, + COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING, COMMON_CHAT_FORMAT_COMMAND_R7B, COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 35c7ee34e..a1034b1a4 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -766,6 +766,19 @@ static void test_template_output_parsers() { "{\n \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_thoughts_unparsed_think, + common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_thoughts_unparsed_think, + common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_thoughts, + common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", + COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING)); + assert_msg_equals(message_assist_thoughts, + common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?", + COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING)); + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "<tool_call>\n"