From 4e39a3c332f84b890e91353ec448502cb8373a6f Mon Sep 17 00:00:00 2001
From: Olivier Chafik <ochafik@users.noreply.github.com>
Date: Mon, 10 Mar 2025 10:59:03 +0000
Subject: [PATCH] `server`: extract <think> tags from qwq outputs (#12297)

* extract <think> tags from qwq outputs

* const for all static regexes in chat.cpp
---
 common/chat.cpp     | 282 +++++++++++++++++++++++---------------------
 common/chat.h       |   1 +
 tests/test-chat.cpp |  13 ++
 3 files changed, 162 insertions(+), 134 deletions(-)

diff --git a/common/chat.cpp b/common/chat.cpp
index 1b3f286af..62ca26ad7 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -445,6 +445,7 @@ std::string common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
         case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
+        case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING: return "Hermes 2 Pro (extract reasoning)";
         case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
         case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)";
         default:
@@ -878,9 +879,9 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
     return data;
 }
 static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) {
-    static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)");
-    static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>");
-    static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>");
+    static const std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)");
+    static const std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>");
+    static const std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>");
 
     std::smatch match;
 
@@ -1012,10 +1013,10 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
 }
 static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
     // TODO: tighten & simplify the parser, don't accept leading text context.
-    static std::regex function_regex(
+    static const std::regex function_regex(
         "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
-    static std::regex close_regex("\\}\\s*");
-    static std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)");
+    static const std::regex close_regex("\\}\\s*");
+    static const std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)");
 
     if (with_builtin_tools) {
         std::smatch match;
@@ -1105,34 +1106,42 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
     data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1;
     return data;
 }
-static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) {
-    static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
-    static std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
-    static std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
-    static std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>");
-    common_chat_msg msg;
-    msg.role = "assistant";
+static common_chat_msg handle_think_tag_prelude(const std::string & input, bool extract_reasoning, const std::function<common_chat_msg(const std::string &)> & rest_parser) {
     std::smatch match;
+    static const std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
     if (std::regex_match(input, match, reasoning_content_regex)) {
-        std::string rest;
+        auto rest = match[3].str();
+        auto msg = rest_parser(rest);
+        auto reasoning_content = string_strip(match[2].str());
         if (extract_reasoning) {
-            msg.reasoning_content = string_strip(match[2].str());
-        } else {
-            msg.content = match[1].str();
+            msg.reasoning_content = reasoning_content;
+        } else if (!reasoning_content.empty()) {
+            std::ostringstream content;
+            content << "<think>" << reasoning_content << "</think>" << msg.content;
+            msg.content = content.str();
         }
-        rest = match[3].str();
+        return msg;
+    }
+    return rest_parser(input);
+}
+static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) {
+    return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) {
+        static const std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
+        static const std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
+        static const std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>");
 
-        if (std::regex_search(rest, match, tool_calls_regex)) {
+        common_chat_msg msg;
+        msg.role = "assistant";
+        std::smatch match;
+        if (std::regex_search(input, match, tool_calls_regex)) {
             auto tool_calls = match[1].str();
             auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex);
             msg.tool_calls = std::move(msg2.tool_calls);
         } else {
-            msg.content += std::string(rest.begin() + rest.find_first_not_of(" \r\n"), rest.end());
+            msg.content = input;
         }
-    } else {
-        msg.content = input;
-    }
-    return msg;
+        return msg;
+    });
 }
 
 static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
@@ -1237,8 +1246,8 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
 }
 
 static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) {
-    static std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)");
-    static std::regex close_regex(R"($|(?=>>>))");
+    static const std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)");
+    static const std::regex close_regex(R"($|(?=>>>))");
 
     std::string content;
     auto it = input.begin();
@@ -1327,7 +1336,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
 }
 static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
     // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
-    static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
+    static const std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
     std::smatch match;
     if (std::regex_search(input, match, python_tag_regex)) {
         auto code = match[1].str();
@@ -1341,8 +1350,8 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
         });
         return msg;
     }
-    static std::regex function_regex(R"(<function=(\w+)>)");
-    static std::regex close_regex(R"(</function>)");
+    static const std::regex function_regex(R"(<function=(\w+)>)");
+    static const std::regex close_regex(R"(</function>)");
     // TODO: tighten & simplify.
     return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
 }
@@ -1409,6 +1418,8 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
             "(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?\\s*\\{\\s*\"", //name\"\\s*:\\s*\"" + escaped_name + "\"",
         });
         data.preserved_tokens = {
+            "<think>",
+            "</think>",
             "<tool_call>",
             "</tool_call>",
             "<function",
@@ -1429,122 +1440,123 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
     });
 
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
-    data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
+    data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING : COMMON_CHAT_FORMAT_HERMES_2_PRO;
     return data;
 }
-static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input) {
-    const static std::regex open_regex(
-        "(?:"
-        "(```(?:xml|json)?\\n\\s*)?"         // match 1 (block_start)
-        "(<tool_call>"                   // match 2 (open_tag)
-        "|<function_call>"
-        "|<tool>"
-        "|<tools>"
-        "|<response>"
-        "|<json>"
-        "|<xml>"
-        "|<JSON>"
-        ")?"
-        "(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*)"    // match 3 (named tool call + rest)
-        ")"
-        "|"
-        "(?:<function=([^>]+)>"            // match 4 (function name)
-        "|<function name=\"([^\"]+)\">)" // match 5 (function name again)
-        "([\\s\\S]*)"                   // match 6 (function arguments + rest)})"
-    );
+static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input, bool extract_reasoning) {
+    return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) {
+        static const std::regex open_regex(
+            "(?:"
+            "(```(?:xml|json)?\\n\\s*)?"         // match 1 (block_start)
+            "(<tool_call>"                   // match 2 (open_tag)
+            "|<function_call>"
+            "|<tool>"
+            "|<tools>"
+            "|<response>"
+            "|<json>"
+            "|<xml>"
+            "|<JSON>"
+            ")?"
+            "(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*)"    // match 3 (named tool call + rest)
+            ")"
+            "|"
+            "(?:<function=([^>]+)>"            // match 4 (function name)
+            "|<function name=\"([^\"]+)\">)" // match 5 (function name again)
+            "([\\s\\S]*)"                   // match 6 (function arguments + rest)})"
+        );
 
-    try {
+        try {
+            common_chat_msg msg;
+            msg.role = "assistant";
 
-        common_chat_msg msg;
-        msg.role = "assistant";
+            std::string::const_iterator it = input.begin();
+            const std::string::const_iterator end = input.end();
+            std::smatch match;
 
-        std::string::const_iterator it = input.begin();
-        const std::string::const_iterator end = input.end();
-        std::smatch match;
+            while (it != end) {
+                if (std::regex_search(it, end, match, open_regex)) {
+                    // Add content before the match
+                    msg.content += std::string(it, match[0].first);
 
-        while (it != end) {
-            if (std::regex_search(it, end, match, open_regex)) {
-                // Add content before the match
-                msg.content += std::string(it, match[0].first);
+                    auto block_start = match[1].str();
+                    std::string block_end = block_start.empty() ? "" : "```";
 
-                auto block_start = match[1].str();
-                std::string block_end = block_start.empty() ? "" : "```";
+                    auto open_tag = match[2].str();
+                    std::string close_tag;
 
-                auto open_tag = match[2].str();
-                std::string close_tag;
+                    if (match[3].matched) {
+                        close_tag = open_tag.empty() ? "" : "</" + open_tag.substr(1);
+                        auto json_it = match[3].first;
+                        json tool_call;
+                        if (parse_json(json_it, end, tool_call) && tool_call.contains("name") && tool_call.contains("arguments")) {
 
-                if (match[3].matched) {
-                    close_tag = open_tag.empty() ? "" : "</" + open_tag.substr(1);
-                    auto json_it = match[3].first;
-                    json tool_call;
-                    if (parse_json(json_it, end, tool_call) && tool_call.contains("name") && tool_call.contains("arguments")) {
+                            msg.tool_calls.emplace_back(process_tool_call(tool_call));
+                            it = json_it;  // Move iterator past parsed JSON
 
-                        msg.tool_calls.emplace_back(process_tool_call(tool_call));
-                        it = json_it;  // Move iterator past parsed JSON
-
-                        // Handle close tags
-                        consume_spaces(it, end);
-                        if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
-                            throw std::runtime_error("Failed to parse closing tag");
+                            // Handle close tags
+                            consume_spaces(it, end);
+                            if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
+                                throw std::runtime_error("Failed to parse closing tag");
+                            }
+                            consume_spaces(it, end);
+                            if (!block_end.empty() && !parse_literal(it, end, block_end)) {
+                                throw std::runtime_error("Failed to parse block end");
+                            }
+                            consume_spaces(it, end);
+                        } else {
+                            // Not a valid tool call, treat as content
+                            msg.content += std::string(match[0].first, match[0].second);
+                            it = match[0].second;
                         }
-                        consume_spaces(it, end);
-                        if (!block_end.empty() && !parse_literal(it, end, block_end)) {
-                            throw std::runtime_error("Failed to parse block end");
-                        }
-                        consume_spaces(it, end);
                     } else {
-                        // Not a valid tool call, treat as content
-                        msg.content += std::string(match[0].first, match[0].second);
-                        it = match[0].second;
+                        auto function_name = match[4].str();
+                        if (function_name.empty()) {
+                            function_name = match[5].str();
+                        }
+                        GGML_ASSERT(!function_name.empty());
+
+                        close_tag = "</function>";
+                        // Start parsing from after the opening tags
+                        auto json_it = match[6].first;
+                        json arguments;
+                        if (parse_json(json_it, end, arguments)) {
+                            msg.tool_calls.emplace_back(process_tool_call({
+                                {"name", function_name},
+                                {"arguments", arguments},
+                            }));
+                            it = json_it;  // Move iterator past parsed JSON
+
+                            // Handle close tags
+                            consume_spaces(it, end);
+                            if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
+                                throw std::runtime_error("Failed to parse closing tag");
+                            }
+                            consume_spaces(it, end);
+                            if (!block_end.empty() && !parse_literal(it, end, block_end)) {
+                                throw std::runtime_error("Failed to parse block end");
+                            }
+                            consume_spaces(it, end);
+                        } else {
+                            // Not a valid tool call, treat as content
+                            msg.content += std::string(match[0].first, match[0].second);
+                            it = match[0].second;
+                        }
                     }
                 } else {
-                    auto function_name = match[4].str();
-                    if (function_name.empty()) {
-                        function_name = match[5].str();
-                    }
-                    GGML_ASSERT(!function_name.empty());
-
-                    close_tag = "</function>";
-                    // Start parsing from after the opening tags
-                    auto json_it = match[6].first;
-                    json arguments;
-                    if (parse_json(json_it, end, arguments)) {
-                        msg.tool_calls.emplace_back(process_tool_call({
-                            {"name", function_name},
-                            {"arguments", arguments},
-                        }));
-                        it = json_it;  // Move iterator past parsed JSON
-
-                        // Handle close tags
-                        consume_spaces(it, end);
-                        if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
-                            throw std::runtime_error("Failed to parse closing tag");
-                        }
-                        consume_spaces(it, end);
-                        if (!block_end.empty() && !parse_literal(it, end, block_end)) {
-                            throw std::runtime_error("Failed to parse block end");
-                        }
-                        consume_spaces(it, end);
-                    } else {
-                        // Not a valid tool call, treat as content
-                        msg.content += std::string(match[0].first, match[0].second);
-                        it = match[0].second;
-                    }
+                    // Add remaining content
+                    msg.content += std::string(it, end);
+                    break;
                 }
-            } else {
-                // Add remaining content
-                msg.content += std::string(it, end);
-                break;
             }
+            return msg;
+        } catch (const std::exception & e) {
+            LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
+            common_chat_msg msg;
+            msg.role = "assistant";
+            msg.content = input;
+            return msg;
         }
-        return msg;
-    } catch (const std::exception & e) {
-        LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
-        common_chat_msg msg;
-        msg.role = "assistant";
-        msg.content = input;
-        return msg;
-    }
+    });
 }
 
 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
@@ -1609,6 +1621,11 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_command_r7b(tmpl, params);
     }
 
+    // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
+    if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
+        return common_chat_params_init_hermes_2_pro(tmpl, params);
+    }
+
     // Use generic handler when mixing tools + JSON schema.
     // TODO: support that mix in handlers below.
     if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -1630,11 +1647,6 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_without_tools(tmpl, params);
     }
 
-    // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
-    if (src.find("<tool_call>") != std::string::npos) {
-        return common_chat_params_init_hermes_2_pro(tmpl, params);
-    }
-
     // Functionary v3.1 (w/ tools)
     if (src.find("<|start_header_id|>") != std::string::npos
         && src.find("<function=") != std::string::npos) {
@@ -1752,7 +1764,9 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
             return common_chat_parse_functionary_v3_1_llama_3_1(input);
         case COMMON_CHAT_FORMAT_HERMES_2_PRO:
-            return common_chat_parse_hermes_2_pro(input);
+            return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ false);
+        case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING:
+            return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ true);
         case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
             return common_chat_parse_firefunction_v2(input);
         case COMMON_CHAT_FORMAT_COMMAND_R7B:
diff --git a/common/chat.h b/common/chat.h
index e77bef82b..9aad84e88 100644
--- a/common/chat.h
+++ b/common/chat.h
@@ -53,6 +53,7 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
     COMMON_CHAT_FORMAT_HERMES_2_PRO,
+    COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING,
     COMMON_CHAT_FORMAT_COMMAND_R7B,
     COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
 
diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp
index 35c7ee34e..a1034b1a4 100644
--- a/tests/test-chat.cpp
+++ b/tests/test-chat.cpp
@@ -766,6 +766,19 @@ static void test_template_output_parsers() {
             "{\n  \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
             COMMON_CHAT_FORMAT_HERMES_2_PRO));
 
+        assert_msg_equals(message_assist_thoughts_unparsed_think,
+            common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_thoughts_unparsed_think,
+            common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_thoughts,
+            common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING));
+        assert_msg_equals(message_assist_thoughts,
+            common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING));
+
         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "<tool_call>\n"