From 267c1399f15a278ec8c3cdcf9c90dc94151fbc38 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Tue, 1 Apr 2025 23:44:05 +0200 Subject: [PATCH] common : refactor downloading system, handle mmproj with -hf option (#12694) * (wip) refactor downloading system [no ci] * fix all examples * fix mmproj with -hf * gemma3: update readme * only handle mmproj in llava example * fix multi-shard download * windows: fix problem with std::min and std::max * fix 2 --- common/arg.cpp | 675 ++++++++++++++++-- common/common.cpp | 495 +------------ common/common.h | 44 +- examples/batched-bench/batched-bench.cpp | 2 +- examples/batched/batched.cpp | 2 +- examples/export-lora/export-lora.cpp | 2 +- examples/gritlm/gritlm.cpp | 2 +- examples/llava/README-gemma3.md | 20 + examples/llava/gemma3-cli.cpp | 6 +- examples/llava/llava-cli.cpp | 6 +- examples/llava/minicpmv-cli.cpp | 6 +- examples/llava/qwen2vl-cli.cpp | 6 +- examples/parallel/parallel.cpp | 2 +- examples/passkey/passkey.cpp | 2 +- examples/server/server.cpp | 19 +- .../speculative-simple/speculative-simple.cpp | 2 +- examples/speculative/speculative.cpp | 2 +- examples/tts/tts.cpp | 7 +- tests/test-arg-parser.cpp | 8 +- 19 files changed, 673 insertions(+), 635 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 8292adaac..47c26955e 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1,9 +1,19 @@ +#include "gguf.h" // for reading GGUF splits #include "arg.h" #include "log.h" #include "sampling.h" #include "chat.h" +// fix problem with std::min and std::max +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include +#endif + #include #include #include @@ -14,6 +24,14 @@ #include #include +//#define LLAMA_USE_CURL + +#if defined(LLAMA_USE_CURL) +#include +#include +#include +#endif + #include "json-schema-to-grammar.h" using json = nlohmann::ordered_json; @@ -125,47 +143,549 @@ std::string common_arg::to_string() { return ss.str(); } +// +// downloader +// + +struct common_hf_file_res { + std::string repo; // repo name with ":tag" removed + std::string ggufFile; + std::string mmprojFile; +}; + +#ifdef LLAMA_USE_CURL + +#ifdef __linux__ +#include +#elif defined(_WIN32) +# if !defined(PATH_MAX) +# define PATH_MAX MAX_PATH +# endif +#else +#include +#endif +#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 + +// +// CURL utils +// + +using curl_ptr = std::unique_ptr; + +// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one +struct curl_slist_ptr { + struct curl_slist * ptr = nullptr; + ~curl_slist_ptr() { + if (ptr) { + curl_slist_free_all(ptr); + } + } +}; + +#define CURL_MAX_RETRY 3 +#define CURL_RETRY_DELAY_SECONDS 2 + +static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) { + int remaining_attempts = max_attempts; + + while (remaining_attempts > 0) { + LOG_INF("%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts); + + CURLcode res = curl_easy_perform(curl); + if (res == CURLE_OK) { + return true; + } + + int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000; + LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay); + + remaining_attempts--; + std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay)); + } + + LOG_ERR("%s: 
curl_easy_perform() failed after %d attempts\n", __func__, max_attempts); + + return false; +} + +// download one single file from remote URL to local path +static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token) { + // Initialize libcurl + curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_slist_ptr http_headers; + if (!curl) { + LOG_ERR("%s: error initializing libcurl\n", __func__); + return false; + } + + bool force_download = false; + + // Set the URL, allow to follow http redirection + curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); + + // Check if hf-token or bearer-token was specified + if (!bearer_token.empty()) { + std::string auth_header = "Authorization: Bearer " + bearer_token; + http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); + } + +#if defined(_WIN32) + // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of + // operating system. Currently implemented under MS-Windows. + curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); +#endif + + // Check if the file already exists locally + auto file_exists = std::filesystem::exists(path); + + // If the file exists, check its JSON metadata companion file. + std::string metadata_path = path + ".json"; + nlohmann::json metadata; + std::string etag; + std::string last_modified; + + if (file_exists) { + // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block). + std::ifstream metadata_in(metadata_path); + if (metadata_in.good()) { + try { + metadata_in >> metadata; + LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str()); + if (metadata.contains("url") && metadata.at("url").is_string()) { + auto previous_url = metadata.at("url").get(); + if (previous_url != url) { + LOG_ERR("%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str()); + return false; + } + } + if (metadata.contains("etag") && metadata.at("etag").is_string()) { + etag = metadata.at("etag"); + } + if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) { + last_modified = metadata.at("lastModified"); + } + } catch (const nlohmann::json::exception & e) { + LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); + return false; + } + } + } else { + LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); + } + + // Send a HEAD request to retrieve the etag and last-modified headers + struct common_load_model_from_url_headers { + std::string etag; + std::string last_modified; + }; + + common_load_model_from_url_headers headers; + + { + typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *); + auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { + common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata; + + static std::regex header_regex("([^:]+): (.*)\r\n"); + static std::regex etag_regex("ETag", std::regex_constants::icase); + static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase); + + std::string header(buffer, n_items); + std::smatch match; + if (std::regex_match(header, match, header_regex)) { + const std::string & key = match[1]; + const std::string & value = match[2]; 
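+                // the ETag / Last-Modified regexes above are built with std::regex_constants::icase,
+                // so the header-name checks below are case-insensitive; the captured values are
+                // later compared against the cached .json metadata to decide whether to re-download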
+ if (std::regex_match(key, match, etag_regex)) { + headers->etag = value; + } else if (std::regex_match(key, match, last_modified_regex)) { + headers->last_modified = value; + } + } + return n_items; + }; + + curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress + curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast(header_callback)); + curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); + + bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS); + if (!was_perform_successful) { + return false; + } + + long http_code = 0; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); + if (http_code != 200) { + // HEAD not supported, we don't know if the file has changed + // force trigger downloading + force_download = true; + LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code); + } + } + + bool should_download = !file_exists || force_download; + if (!should_download) { + if (!etag.empty() && etag != headers.etag) { + LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str()); + should_download = true; + } else if (!last_modified.empty() && last_modified != headers.last_modified) { + LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str()); + should_download = true; + } + } + if (should_download) { + std::string path_temporary = path + ".downloadInProgress"; + if (file_exists) { + LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); + if (remove(path.c_str()) != 0) { + LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); + return false; + } + } + + // Set the output file + + struct FILE_deleter { + void operator()(FILE * f) const { + fclose(f); + } + }; + + std::unique_ptr outfile(fopen(path_temporary.c_str(), "wb")); + if (!outfile) { + LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str()); + return false; + } + + typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd); + auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t { + return fwrite(data, size, nmemb, (FILE *)fd); + }; + curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L); + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get()); + + // display download progress + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L); + + // helper function to hide password in URL + auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string { + std::size_t protocol_pos = url.find("://"); + if (protocol_pos == std::string::npos) { + return url; // Malformed URL + } + + std::size_t at_pos = url.find('@', protocol_pos + 3); + if (at_pos == std::string::npos) { + return url; // No password in URL + } + + return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos); + }; + + // start the download + LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__, + llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str()); + bool was_perform_successful = curl_perform_with_retry(url, curl.get(), 
CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS); + if (!was_perform_successful) { + return false; + } + + long http_code = 0; + curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code); + if (http_code < 200 || http_code >= 400) { + LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code); + return false; + } + + // Causes file to be closed explicitly here before we rename it. + outfile.reset(); + + // Write the updated JSON metadata file. + metadata.update({ + {"url", url}, + {"etag", headers.etag}, + {"lastModified", headers.last_modified} + }); + std::ofstream(metadata_path) << metadata.dump(4); + LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); + + if (rename(path_temporary.c_str(), path.c_str()) != 0) { + LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); + return false; + } + } + + return true; +} + +// download multiple files from remote URLs to local paths +// the input is a vector of pairs +static bool common_download_file_multiple(const std::vector> & urls, const std::string & bearer_token) { + // Prepare download in parallel + std::vector> futures_download; + for (auto const & item : urls) { + futures_download.push_back(std::async(std::launch::async, [bearer_token](const std::pair & it) -> bool { + return common_download_file_single(it.first, it.second, bearer_token); + }, item)); + } + + // Wait for all downloads to complete + for (auto & f : futures_download) { + if (!f.get()) { + return false; + } + } + + return true; +} + +static bool common_download_model( + const common_params_model & model, + const std::string & bearer_token) { + // Basic validation of the model.url + if (model.url.empty()) { + LOG_ERR("%s: invalid model url\n", __func__); + return false; + } + + if (!common_download_file_single(model.url, model.path, bearer_token)) { + return false; + } + + // check for additional GGUFs split to download + int n_split = 0; + { + struct gguf_init_params gguf_params = { + /*.no_alloc = */ true, + /*.ctx = */ NULL, + }; + auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params); + if (!ctx_gguf) { + LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, model.path.c_str()); + return false; + } + + auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT); + if (key_n_split >= 0) { + n_split = gguf_get_val_u16(ctx_gguf, key_n_split); + } + + gguf_free(ctx_gguf); + } + + if (n_split > 1) { + char split_prefix[PATH_MAX] = {0}; + char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0}; + + // Verify the first split file format + // and extract split URL and PATH prefixes + { + if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) { + LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split); + return false; + } + + if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model.url.c_str(), 0, n_split)) { + LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model.url.c_str(), n_split); + return false; + } + } + + std::vector> urls; + for (int idx = 1; idx < n_split; idx++) { + char split_path[PATH_MAX] = {0}; + llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split); + + char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0}; + llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split); + + if (std::string(split_path) == model.path) { + continue; // skip the already downloaded file + } + + urls.push_back({split_url, 
split_path}); + } + + // Download in parallel + common_download_file_multiple(urls, bearer_token); + } + + return true; +} + +/** + * Allow getting the HF file from the HF repo with tag (like ollama), for example: + * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 + * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M + * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s + * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo) + * + * Return pair of (with "repo" already having tag removed) + * + * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files. + */ +static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token) { + auto parts = string_split(hf_repo_with_tag, ':'); + std::string tag = parts.size() > 1 ? parts.back() : "latest"; + std::string hf_repo = parts[0]; + if (string_split(hf_repo, '/').size() != 2) { + throw std::invalid_argument("error: invalid HF repo format, expected /[:quant]\n"); + } + + // fetch model info from Hugging Face Hub API + curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_slist_ptr http_headers; + std::string res_str; + std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag; + curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); + typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data); + auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t { + static_cast(data)->append((char * ) ptr, size * nmemb); + return size * nmemb; + }; + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str); +#if defined(_WIN32) + curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); +#endif + if (!bearer_token.empty()) { + std::string auth_header = "Authorization: Bearer " + bearer_token; + http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); + } + // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response + http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); + http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json"); + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); + + CURLcode res = curl_easy_perform(curl.get()); + + if (res != CURLE_OK) { + throw std::runtime_error("error: cannot make GET request to HF API"); + } + + long res_code; + std::string ggufFile = ""; + std::string mmprojFile = ""; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code); + if (res_code == 200) { + // extract ggufFile.rfilename in json, using regex + { + std::regex pattern("\"ggufFile\"[\\s\\S]*?\"rfilename\"\\s*:\\s*\"([^\"]+)\""); + std::smatch match; + if (std::regex_search(res_str, match, pattern)) { + ggufFile = match[1].str(); + } + } + // extract mmprojFile.rfilename in json, using regex + { + std::regex pattern("\"mmprojFile\"[\\s\\S]*?\"rfilename\"\\s*:\\s*\"([^\"]+)\""); + std::smatch match; + if (std::regex_search(res_str, match, pattern)) { + mmprojFile = match[1].str(); + } + } + } else if (res_code == 401) { + throw std::runtime_error("error: model is private or does not exist; if you are 
accessing a gated model, please provide a valid HF token"); + } else { + throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str())); + } + + // check response + if (ggufFile.empty()) { + throw std::runtime_error("error: model does not have ggufFile"); + } + + return { hf_repo, ggufFile, mmprojFile }; +} + +#else + +static bool common_download_file_single(const std::string &, const std::string &, const std::string &) { + LOG_ERR("error: built without CURL, cannot download model from internet\n"); + return false; +} + +static bool common_download_file_multiple(const std::vector> &, const std::string &) { + LOG_ERR("error: built without CURL, cannot download model from the internet\n"); + return false; +} + +static bool common_download_model( + const common_params_model &, + const std::string &) { + LOG_ERR("error: built without CURL, cannot download model from the internet\n"); + return false; +} + +static struct common_hf_file_res common_get_hf_file(const std::string &, const std::string &) { + LOG_ERR("error: built without CURL, cannot download model from the internet\n"); + return {}; +} + +#endif // LLAMA_USE_CURL + // // utils // -static void common_params_handle_model_default( - std::string & model, - const std::string & model_url, - std::string & hf_repo, - std::string & hf_file, - const std::string & hf_token, - const std::string & model_default) { - if (!hf_repo.empty()) { - // short-hand to avoid specifying --hf-file -> default it to --model - if (hf_file.empty()) { - if (model.empty()) { - auto auto_detected = common_get_hf_file(hf_repo, hf_token); - if (auto_detected.first.empty() || auto_detected.second.empty()) { - exit(1); // built without CURL, error message already printed +static void common_params_handle_model( + struct common_params_model & model, + const std::string & bearer_token, + const std::string & model_path_default, + bool is_mmproj = false) { // TODO: move is_mmproj to an enum when we have more files? + // handle pre-fill default model path and url based on hf_repo and hf_file + { + if (!model.hf_repo.empty()) { + // short-hand to avoid specifying --hf-file -> default it to --model + if (model.hf_file.empty()) { + if (model.path.empty()) { + auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token); + if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) { + exit(1); // built without CURL, error message already printed + } + model.hf_repo = auto_detected.repo; + model.hf_file = is_mmproj ? 
auto_detected.mmprojFile : auto_detected.ggufFile; + } else { + model.hf_file = model.path; } - hf_repo = auto_detected.first; - hf_file = auto_detected.second; - } else { - hf_file = model; } + + // TODO: allow custom host + model.url = "https://huggingface.co/" + model.hf_repo + "/resolve/main/" + model.hf_file; + + // make sure model path is present (for caching purposes) + if (model.path.empty()) { + // this is to avoid different repo having same file name, or same file name in different subdirs + std::string filename = model.hf_repo + "_" + model.hf_file; + // to make sure we don't have any slashes in the filename + string_replace_all(filename, "/", "_"); + model.path = fs_get_cache_file(filename); + } + + } else if (!model.url.empty()) { + if (model.path.empty()) { + auto f = string_split(model.url, '#').front(); + f = string_split(f, '?').front(); + model.path = fs_get_cache_file(string_split(f, '/').back()); + } + + } else if (model.path.empty()) { + model.path = model_path_default; } - // make sure model path is present (for caching purposes) - if (model.empty()) { - // this is to avoid different repo having same file name, or same file name in different subdirs - std::string filename = hf_repo + "_" + hf_file; - // to make sure we don't have any slashes in the filename - string_replace_all(filename, "/", "_"); - model = fs_get_cache_file(filename); + } + + // then, download it if needed + if (!model.url.empty()) { + bool ok = common_download_model(model, bearer_token); + if (!ok) { + LOG_ERR("error: failed to download model from %s\n", model.url.c_str()); + exit(1); } - } else if (!model_url.empty()) { - if (model.empty()) { - auto f = string_split(model_url, '#').front(); - f = string_split(f, '?').front(); - model = fs_get_cache_file(string_split(f, '/').back()); - } - } else if (model.empty()) { - model = model_default; } } @@ -300,10 +820,16 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); } - // TODO: refactor model params in a common struct - common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token, DEFAULT_MODEL_PATH); - common_params_handle_model_default(params.speculative.model, params.speculative.model_url, params.speculative.hf_repo, params.speculative.hf_file, params.hf_token, ""); - common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token, ""); + common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH); + common_params_handle_model(params.speculative.model, params.hf_token, ""); + common_params_handle_model(params.vocoder.model, params.hf_token, ""); + + // allow --mmproj to be set from -hf + // assuming that mmproj is always in the same repo as text model + if (!params.model.hf_repo.empty() && ctx_arg.ex == LLAMA_EXAMPLE_LLAVA) { + params.mmproj.hf_repo = params.model.hf_repo; + } + common_params_handle_model(params.mmproj, params.hf_token, "", true); if (params.escape) { string_process_escapes(params.prompt); @@ -1561,7 +2087,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--mmproj"}, "FILE", "path to a multimodal projector file for LLaVA. 
see examples/llava/README.md", [](common_params & params, const std::string & value) { - params.mmproj = value; + params.mmproj.path = value; + } + ).set_examples({LLAMA_EXAMPLE_LLAVA})); + add_opt(common_arg( + {"--mmproj-url"}, "URL", + "URL to a multimodal projector file for LLaVA. see examples/llava/README.md", + [](common_params & params, const std::string & value) { + params.mmproj.url = value; } ).set_examples({LLAMA_EXAMPLE_LLAVA})); add_opt(common_arg( @@ -1790,14 +2323,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "or `--model-url` if set, otherwise %s)", DEFAULT_MODEL_PATH ), [](common_params & params, const std::string & value) { - params.model = value; + params.model.path = value; } ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}).set_env("LLAMA_ARG_MODEL")); add_opt(common_arg( {"-mu", "--model-url"}, "MODEL_URL", "model download url (default: unused)", [](common_params & params, const std::string & value) { - params.model_url = value; + params.model.url = value; } ).set_env("LLAMA_ARG_MODEL_URL")); add_opt(common_arg( @@ -1806,35 +2339,35 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "example: unsloth/phi-4-GGUF:q4_k_m\n" "(default: unused)", [](common_params & params, const std::string & value) { - params.hf_repo = value; + params.model.hf_repo = value; } ).set_env("LLAMA_ARG_HF_REPO")); add_opt(common_arg( {"-hfd", "-hfrd", "--hf-repo-draft"}, "/[:quant]", "Same as --hf-repo, but for the draft model (default: unused)", [](common_params & params, const std::string & value) { - params.speculative.hf_repo = value; + params.speculative.model.hf_repo = value; } ).set_env("LLAMA_ARG_HFD_REPO")); add_opt(common_arg( {"-hff", "--hf-file"}, "FILE", "Hugging Face model file. 
If specified, it will override the quant in --hf-repo (default: unused)", [](common_params & params, const std::string & value) { - params.hf_file = value; + params.model.hf_file = value; } ).set_env("LLAMA_ARG_HF_FILE")); add_opt(common_arg( {"-hfv", "-hfrv", "--hf-repo-v"}, "/[:quant]", "Hugging Face model repository for the vocoder model (default: unused)", [](common_params & params, const std::string & value) { - params.vocoder.hf_repo = value; + params.vocoder.model.hf_repo = value; } ).set_env("LLAMA_ARG_HF_REPO_V")); add_opt(common_arg( {"-hffv", "--hf-file-v"}, "FILE", "Hugging Face model file for the vocoder model (default: unused)", [](common_params & params, const std::string & value) { - params.vocoder.hf_file = value; + params.vocoder.model.hf_file = value; } ).set_env("LLAMA_ARG_HF_FILE_V")); add_opt(common_arg( @@ -2454,7 +2987,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"-md", "--model-draft"}, "FNAME", "draft model for speculative decoding (default: unused)", [](common_params & params, const std::string & value) { - params.speculative.model = value; + params.speculative.model.path = value; } ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT")); @@ -2462,7 +2995,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"-mv", "--model-vocoder"}, "FNAME", "vocoder model for audio generation (default: unused)", [](common_params & params, const std::string & value) { - params.vocoder.model = value; + params.vocoder.model.path = value; } ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( @@ -2485,10 +3018,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--tts-oute-default"}, string_format("use default OuteTTS models (note: can download weights from the internet)"), [](common_params & params) { - params.hf_repo = "OuteAI/OuteTTS-0.2-500M-GGUF"; - params.hf_file = "OuteTTS-0.2-500M-Q8_0.gguf"; - params.vocoder.hf_repo = "ggml-org/WavTokenizer"; - params.vocoder.hf_file = "WavTokenizer-Large-75-F16.gguf"; + params.model.hf_repo = "OuteAI/OuteTTS-0.2-500M-GGUF"; + params.model.hf_file = "OuteTTS-0.2-500M-Q8_0.gguf"; + params.vocoder.model.hf_repo = "ggml-org/WavTokenizer"; + params.vocoder.model.hf_file = "WavTokenizer-Large-75-F16.gguf"; } ).set_examples({LLAMA_EXAMPLE_TTS})); @@ -2496,8 +3029,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--embd-bge-small-en-default"}, string_format("use default bge-small-en-v1.5 model (note: can download weights from the internet)"), [](common_params & params) { - params.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF"; - params.hf_file = "bge-small-en-v1.5-q8_0.gguf"; + params.model.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF"; + params.model.hf_file = "bge-small-en-v1.5-q8_0.gguf"; params.pooling_type = LLAMA_POOLING_TYPE_NONE; params.embd_normalize = 2; params.n_ctx = 512; @@ -2510,8 +3043,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--embd-e5-small-en-default"}, string_format("use default e5-small-v2 model (note: can download weights from the internet)"), [](common_params & params) { - params.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF"; - params.hf_file = "e5-small-v2-q8_0.gguf"; + params.model.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF"; + params.model.hf_file = "e5-small-v2-q8_0.gguf"; params.pooling_type = LLAMA_POOLING_TYPE_NONE; params.embd_normalize = 2; params.n_ctx 
= 512; @@ -2524,8 +3057,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--embd-gte-small-default"}, string_format("use default gte-small model (note: can download weights from the internet)"), [](common_params & params) { - params.hf_repo = "ggml-org/gte-small-Q8_0-GGUF"; - params.hf_file = "gte-small-q8_0.gguf"; + params.model.hf_repo = "ggml-org/gte-small-Q8_0-GGUF"; + params.model.hf_file = "gte-small-q8_0.gguf"; params.pooling_type = LLAMA_POOLING_TYPE_NONE; params.embd_normalize = 2; params.n_ctx = 512; @@ -2538,8 +3071,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--fim-qwen-1.5b-default"}, string_format("use default Qwen 2.5 Coder 1.5B (note: can download weights from the internet)"), [](common_params & params) { - params.hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-Q8_0-GGUF"; - params.hf_file = "qwen2.5-coder-1.5b-q8_0.gguf"; + params.model.hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-Q8_0-GGUF"; + params.model.hf_file = "qwen2.5-coder-1.5b-q8_0.gguf"; params.port = 8012; params.n_gpu_layers = 99; params.flash_attn = true; @@ -2554,8 +3087,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--fim-qwen-3b-default"}, string_format("use default Qwen 2.5 Coder 3B (note: can download weights from the internet)"), [](common_params & params) { - params.hf_repo = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF"; - params.hf_file = "qwen2.5-coder-3b-q8_0.gguf"; + params.model.hf_repo = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF"; + params.model.hf_file = "qwen2.5-coder-3b-q8_0.gguf"; params.port = 8012; params.n_gpu_layers = 99; params.flash_attn = true; @@ -2570,8 +3103,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--fim-qwen-7b-default"}, string_format("use default Qwen 2.5 Coder 7B (note: can download weights from the internet)"), [](common_params & params) { - params.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF"; - params.hf_file = "qwen2.5-coder-7b-q8_0.gguf"; + params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF"; + params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf"; params.port = 8012; params.n_gpu_layers = 99; params.flash_attn = true; @@ -2586,10 +3119,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--fim-qwen-7b-spec"}, string_format("use Qwen 2.5 Coder 7B + 0.5B draft for speculative decoding (note: can download weights from the internet)"), [](common_params & params) { - params.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF"; - params.hf_file = "qwen2.5-coder-7b-q8_0.gguf"; - params.speculative.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; - params.speculative.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; + params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF"; + params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf"; + params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; + params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; params.speculative.n_gpu_layers = 99; params.port = 8012; params.n_gpu_layers = 99; @@ -2605,10 +3138,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--fim-qwen-14b-spec"}, string_format("use Qwen 2.5 Coder 14B + 0.5B draft for speculative decoding (note: can download weights from the internet)"), [](common_params & params) { - params.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF"; - params.hf_file = "qwen2.5-coder-14b-q8_0.gguf"; - params.speculative.hf_repo = 
"ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; - params.speculative.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; + params.model.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF"; + params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf"; + params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; + params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; params.speculative.n_gpu_layers = 99; params.port = 8012; params.n_gpu_layers = 99; diff --git a/common/common.cpp b/common/common.cpp index 18ffb4e73..22642c84a 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -51,45 +51,11 @@ #include #include #endif -#if defined(LLAMA_USE_CURL) -#include -#include -#include -#endif #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data #endif -#if defined(LLAMA_USE_CURL) -#ifdef __linux__ -#include -#elif defined(_WIN32) -# if !defined(PATH_MAX) -# define PATH_MAX MAX_PATH -# endif -#else -#include -#endif -#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 - -// -// CURL utils -// - -using curl_ptr = std::unique_ptr; - -// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one -struct curl_slist_ptr { - struct curl_slist * ptr = nullptr; - ~curl_slist_ptr() { - if (ptr) { - curl_slist_free_all(ptr); - } - } -}; -#endif // LLAMA_USE_CURL - using json = nlohmann::ordered_json; // @@ -900,22 +866,14 @@ std::string fs_get_cache_file(const std::string & filename) { // // Model utils // + struct common_init_result common_init_from_params(common_params & params) { common_init_result iparams; auto mparams = common_model_params_to_llama(params); - llama_model * model = nullptr; - - if (!params.hf_repo.empty() && !params.hf_file.empty()) { - model = common_load_model_from_hf(params.hf_repo, params.hf_file, params.model, params.hf_token, mparams); - } else if (!params.model_url.empty()) { - model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams); - } else { - model = llama_model_load_from_file(params.model.c_str(), mparams); - } - + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams); if (model == NULL) { - LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.c_str()); + LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str()); return iparams; } @@ -950,7 +908,7 @@ struct common_init_result common_init_from_params(common_params & params) { llama_context * lctx = llama_init_from_model(model, cparams); if (lctx == NULL) { - LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str()); + LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str()); llama_model_free(model); return iparams; } @@ -1164,451 +1122,6 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p return tpp; } -#ifdef LLAMA_USE_CURL - -#define CURL_MAX_RETRY 3 -#define CURL_RETRY_DELAY_SECONDS 2 - -static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) { - int remaining_attempts = max_attempts; - - while (remaining_attempts > 0) { - LOG_INF("%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts); - - CURLcode res = curl_easy_perform(curl); - if (res == CURLE_OK) { - return true; - } - - int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000; - 
LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay); - - remaining_attempts--; - std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay)); - } - - LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts); - - return false; -} - -static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) { - // Initialize libcurl - curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); - curl_slist_ptr http_headers; - if (!curl) { - LOG_ERR("%s: error initializing libcurl\n", __func__); - return false; - } - - bool force_download = false; - - // Set the URL, allow to follow http redirection - curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); - - // Check if hf-token or bearer-token was specified - if (!hf_token.empty()) { - std::string auth_header = "Authorization: Bearer " + hf_token; - http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); - curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); - } - -#if defined(_WIN32) - // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of - // operating system. Currently implemented under MS-Windows. - curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); -#endif - - // Check if the file already exists locally - auto file_exists = std::filesystem::exists(path); - - // If the file exists, check its JSON metadata companion file. - std::string metadata_path = path + ".json"; - nlohmann::json metadata; - std::string etag; - std::string last_modified; - - if (file_exists) { - // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block). 
- std::ifstream metadata_in(metadata_path); - if (metadata_in.good()) { - try { - metadata_in >> metadata; - LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str()); - if (metadata.contains("url") && metadata.at("url").is_string()) { - auto previous_url = metadata.at("url").get(); - if (previous_url != url) { - LOG_ERR("%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str()); - return false; - } - } - if (metadata.contains("etag") && metadata.at("etag").is_string()) { - etag = metadata.at("etag"); - } - if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) { - last_modified = metadata.at("lastModified"); - } - } catch (const nlohmann::json::exception & e) { - LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); - return false; - } - } - } else { - LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); - } - - // Send a HEAD request to retrieve the etag and last-modified headers - struct common_load_model_from_url_headers { - std::string etag; - std::string last_modified; - }; - - common_load_model_from_url_headers headers; - - { - typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *); - auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { - common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata; - - static std::regex header_regex("([^:]+): (.*)\r\n"); - static std::regex etag_regex("ETag", std::regex_constants::icase); - static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase); - - std::string header(buffer, n_items); - std::smatch match; - if (std::regex_match(header, match, header_regex)) { - const std::string & key = match[1]; - const std::string & value = match[2]; - if (std::regex_match(key, match, etag_regex)) { - headers->etag = value; - } else if (std::regex_match(key, match, last_modified_regex)) { - headers->last_modified = value; - } - } - return n_items; - }; - - curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb - curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress - curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast(header_callback)); - curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); - - bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS); - if (!was_perform_successful) { - return false; - } - - long http_code = 0; - curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); - if (http_code != 200) { - // HEAD not supported, we don't know if the file has changed - // force trigger downloading - force_download = true; - LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code); - } - } - - bool should_download = !file_exists || force_download; - if (!should_download) { - if (!etag.empty() && etag != headers.etag) { - LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str()); - should_download = true; - } else if (!last_modified.empty() && last_modified != headers.last_modified) { - LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str()); - should_download = true; - } - } - if (should_download) { - std::string path_temporary = path + 
".downloadInProgress"; - if (file_exists) { - LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); - if (remove(path.c_str()) != 0) { - LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); - return false; - } - } - - // Set the output file - - struct FILE_deleter { - void operator()(FILE * f) const { - fclose(f); - } - }; - - std::unique_ptr outfile(fopen(path_temporary.c_str(), "wb")); - if (!outfile) { - LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str()); - return false; - } - - typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd); - auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t { - return fwrite(data, size, nmemb, (FILE *)fd); - }; - curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L); - curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); - curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get()); - - // display download progress - curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L); - - // helper function to hide password in URL - auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string { - std::size_t protocol_pos = url.find("://"); - if (protocol_pos == std::string::npos) { - return url; // Malformed URL - } - - std::size_t at_pos = url.find('@', protocol_pos + 3); - if (at_pos == std::string::npos) { - return url; // No password in URL - } - - return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos); - }; - - // start the download - LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__, - llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str()); - bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS); - if (!was_perform_successful) { - return false; - } - - long http_code = 0; - curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code); - if (http_code < 200 || http_code >= 400) { - LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code); - return false; - } - - // Causes file to be closed explicitly here before we rename it. - outfile.reset(); - - // Write the updated JSON metadata file. 
- metadata.update({ - {"url", url}, - {"etag", headers.etag}, - {"lastModified", headers.last_modified} - }); - std::ofstream(metadata_path) << metadata.dump(4); - LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); - - if (rename(path_temporary.c_str(), path.c_str()) != 0) { - LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); - return false; - } - } - - return true; -} - -struct llama_model * common_load_model_from_url( - const std::string & model_url, - const std::string & local_path, - const std::string & hf_token, - const struct llama_model_params & params) { - // Basic validation of the model_url - if (model_url.empty()) { - LOG_ERR("%s: invalid model_url\n", __func__); - return NULL; - } - - if (!common_download_file(model_url, local_path, hf_token)) { - return NULL; - } - - // check for additional GGUFs split to download - int n_split = 0; - { - struct gguf_init_params gguf_params = { - /*.no_alloc = */ true, - /*.ctx = */ NULL, - }; - auto * ctx_gguf = gguf_init_from_file(local_path.c_str(), gguf_params); - if (!ctx_gguf) { - LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, local_path.c_str()); - return NULL; - } - - auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT); - if (key_n_split >= 0) { - n_split = gguf_get_val_u16(ctx_gguf, key_n_split); - } - - gguf_free(ctx_gguf); - } - - if (n_split > 1) { - char split_prefix[PATH_MAX] = {0}; - char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0}; - - // Verify the first split file format - // and extract split URL and PATH prefixes - { - if (!llama_split_prefix(split_prefix, sizeof(split_prefix), local_path.c_str(), 0, n_split)) { - LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, local_path.c_str(), n_split); - return NULL; - } - - if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url.c_str(), 0, n_split)) { - LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model_url.c_str(), n_split); - return NULL; - } - } - - // Prepare download in parallel - std::vector> futures_download; - for (int idx = 1; idx < n_split; idx++) { - futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool { - char split_path[PATH_MAX] = {0}; - llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split); - - char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0}; - llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split); - - return common_download_file(split_url, split_path, hf_token); - }, idx)); - } - - // Wait for all downloads to complete - for (auto & f : futures_download) { - if (!f.get()) { - return NULL; - } - } - } - - return llama_model_load_from_file(local_path.c_str(), params); -} - -struct llama_model * common_load_model_from_hf( - const std::string & repo, - const std::string & remote_path, - const std::string & local_path, - const std::string & hf_token, - const struct llama_model_params & params) { - // construct hugging face model url: - // - // --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf - // https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf - // - // --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf - // https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf - // - - std::string model_url = "https://huggingface.co/"; - model_url += 
repo; - model_url += "/resolve/main/"; - model_url += remote_path; - - return common_load_model_from_url(model_url, local_path, hf_token, params); -} - -/** - * Allow getting the HF file from the HF repo with tag (like ollama), for example: - * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 - * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M - * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s - * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo) - * - * Return pair of (with "repo" already having tag removed) - * - * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files. - */ -std::pair common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) { - auto parts = string_split(hf_repo_with_tag, ':'); - std::string tag = parts.size() > 1 ? parts.back() : "latest"; - std::string hf_repo = parts[0]; - if (string_split(hf_repo, '/').size() != 2) { - throw std::invalid_argument("error: invalid HF repo format, expected /[:quant]\n"); - } - - // fetch model info from Hugging Face Hub API - json model_info; - curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); - curl_slist_ptr http_headers; - std::string res_str; - std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag; - curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); - typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data); - auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t { - static_cast(data)->append((char * ) ptr, size * nmemb); - return size * nmemb; - }; - curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); - curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str); -#if defined(_WIN32) - curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); -#endif - if (!hf_token.empty()) { - std::string auth_header = "Authorization: Bearer " + hf_token; - http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); - } - // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response - http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); - http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json"); - curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); - - CURLcode res = curl_easy_perform(curl.get()); - - if (res != CURLE_OK) { - throw std::runtime_error("error: cannot make GET request to HF API"); - } - - long res_code; - curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code); - if (res_code == 200) { - model_info = json::parse(res_str); - } else if (res_code == 401) { - throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token"); - } else { - throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str())); - } - - // check response - if (!model_info.contains("ggufFile")) { - throw std::runtime_error("error: model does not have ggufFile"); - } - json & gguf_file = model_info.at("ggufFile"); - if (!gguf_file.contains("rfilename")) { - throw std::runtime_error("error: ggufFile does not have rfilename"); - } - - return 
std::make_pair(hf_repo, gguf_file.at("rfilename")); -} - -#else - -struct llama_model * common_load_model_from_url( - const std::string & /*model_url*/, - const std::string & /*local_path*/, - const std::string & /*hf_token*/, - const struct llama_model_params & /*params*/) { - LOG_WRN("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); - return nullptr; -} - -struct llama_model * common_load_model_from_hf( - const std::string & /*repo*/, - const std::string & /*remote_path*/, - const std::string & /*local_path*/, - const std::string & /*hf_token*/, - const struct llama_model_params & /*params*/) { - LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__); - return nullptr; -} - -std::pair common_get_hf_file(const std::string &, const std::string &) { - LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__); - return std::make_pair("", ""); -} - -#endif // LLAMA_USE_CURL - // // Batch utils // diff --git a/common/common.h b/common/common.h index 1c0f19977..41ff9905e 100644 --- a/common/common.h +++ b/common/common.h @@ -184,6 +184,13 @@ struct common_params_sampling { std::string print() const; }; +struct common_params_model { + std::string path = ""; // model local path // NOLINT + std::string url = ""; // model url to download // NOLINT + std::string hf_repo = ""; // HF repo // NOLINT + std::string hf_file = ""; // HF file // NOLINT +}; + struct common_params_speculative { std::vector devices; // devices to use for offloading @@ -197,19 +204,11 @@ struct common_params_speculative { struct cpu_params cpuparams; struct cpu_params cpuparams_batch; - std::string hf_repo = ""; // HF repo // NOLINT - std::string hf_file = ""; // HF file // NOLINT - - std::string model = ""; // draft model for speculative decoding // NOLINT - std::string model_url = ""; // model url to download // NOLINT + struct common_params_model model; }; struct common_params_vocoder { - std::string hf_repo = ""; // HF repo // NOLINT - std::string hf_file = ""; // HF file // NOLINT - - std::string model = ""; // model path // NOLINT - std::string model_url = ""; // model url to download // NOLINT + struct common_params_model model; std::string speaker_file = ""; // speaker file path // NOLINT @@ -267,12 +266,10 @@ struct common_params { struct common_params_speculative speculative; struct common_params_vocoder vocoder; - std::string model = ""; // model path // NOLINT + struct common_params_model model; + std::string model_alias = ""; // model alias // NOLINT - std::string model_url = ""; // model url to download // NOLINT std::string hf_token = ""; // HF token // NOLINT - std::string hf_repo = ""; // HF repo // NOLINT - std::string hf_file = ""; // HF file // NOLINT std::string prompt = ""; // NOLINT std::string system_prompt = ""; // NOLINT std::string prompt_file = ""; // store the external prompt file name // NOLINT @@ -347,7 +344,7 @@ struct common_params { common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO; // multimodal models (see examples/llava) - std::string mmproj = ""; // path to multimodal projector // NOLINT + struct common_params_model mmproj; std::vector image; // path to image file(s) // embedding @@ -546,23 +543,6 @@ struct llama_model_params common_model_params_to_llama ( common_params struct llama_context_params common_context_params_to_llama(const common_params & params); struct ggml_threadpool_params 
ggml_threadpool_params_from_cpu_params(const cpu_params & params); -struct llama_model * common_load_model_from_url( - const std::string & model_url, - const std::string & local_path, - const std::string & hf_token, - const struct llama_model_params & params); - -struct llama_model * common_load_model_from_hf( - const std::string & repo, - const std::string & remote_path, - const std::string & local_path, - const std::string & hf_token, - const struct llama_model_params & params); - -std::pair common_get_hf_file( - const std::string & hf_repo_with_tag, - const std::string & hf_token); - // clear LoRA adapters from context, then apply new list of adapters void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora); diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 430e8be51..0f4019293 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -38,7 +38,7 @@ int main(int argc, char ** argv) { llama_model_params model_params = common_model_params_to_llama(params); - llama_model * model = llama_model_load_from_file(params.model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params); if (model == NULL) { fprintf(stderr , "%s: error: unable to load model\n" , __func__); diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 21b95ef5e..1a5de5928 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -41,7 +41,7 @@ int main(int argc, char ** argv) { llama_model_params model_params = common_model_params_to_llama(params); - llama_model * model = llama_model_load_from_file(params.model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params); if (model == NULL) { LOG_ERR("%s: error: unable to load model\n" , __func__); diff --git a/examples/export-lora/export-lora.cpp b/examples/export-lora/export-lora.cpp index e7d0fbfff..24dc85cf2 100644 --- a/examples/export-lora/export-lora.cpp +++ b/examples/export-lora/export-lora.cpp @@ -421,7 +421,7 @@ int main(int argc, char ** argv) { g_verbose = (params.verbosity > 1); try { - lora_merge_ctx ctx(params.model, params.lora_adapters, params.out_file, params.cpuparams.n_threads); + lora_merge_ctx ctx(params.model.path, params.lora_adapters, params.out_file, params.cpuparams.n_threads); ctx.run_merge(); } catch (const std::exception & err) { fprintf(stderr, "%s\n", err.what()); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index f7db7861c..539bc4d60 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -168,7 +168,7 @@ int main(int argc, char * argv[]) { llama_backend_init(); - llama_model * model = llama_model_load_from_file(params.model.c_str(), mparams); + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams); // create generation context llama_context * ctx = llama_init_from_model(model, cparams); diff --git a/examples/llava/README-gemma3.md b/examples/llava/README-gemma3.md index 20bf73fb5..3c25ee258 100644 --- a/examples/llava/README-gemma3.md +++ b/examples/llava/README-gemma3.md @@ -4,6 +4,26 @@ > > This is very experimental, only used for demo purpose. 
 
+## Quick start
+
+You can use pre-quantized models from [ggml-org](https://huggingface.co/ggml-org)'s Hugging Face account:
+
+```bash
+# build
+cmake -B build
+cmake --build build --target llama-gemma3-cli
+
+# alternatively, install from brew (macOS)
+brew install llama.cpp
+
+# run it
+llama-gemma3-cli -hf ggml-org/gemma-3-4b-it-GGUF
+llama-gemma3-cli -hf ggml-org/gemma-3-12b-it-GGUF
+llama-gemma3-cli -hf ggml-org/gemma-3-27b-it-GGUF
+
+# note: 1B model does not support vision
+```
+
 ## How to get mmproj.gguf?
 
 ```bash
diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp
index c36bb2eda..7813ac19f 100644
--- a/examples/llava/gemma3-cli.cpp
+++ b/examples/llava/gemma3-cli.cpp
@@ -78,7 +78,7 @@ struct gemma3_context {
     }
 
     void init_clip_model(common_params & params) {
-        const char * clip_path = params.mmproj.c_str();
+        const char * clip_path = params.mmproj.path.c_str();
         ctx_clip = clip_model_load(clip_path, params.verbosity > 1);
     }
 
@@ -232,13 +232,13 @@ int main(int argc, char ** argv) {
 
     common_init();
 
-    if (params.mmproj.empty()) {
+    if (params.mmproj.path.empty()) {
         show_additional_info(argc, argv);
         return 1;
     }
 
     gemma3_context ctx(params);
-    printf("%s: %s\n", __func__, params.model.c_str());
+    printf("%s: %s\n", __func__, params.model.path.c_str());
 
     bool is_single_turn = !params.prompt.empty() && !params.image.empty();
 
diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp
index 40aa0876f..a15131343 100644
--- a/examples/llava/llava-cli.cpp
+++ b/examples/llava/llava-cli.cpp
@@ -225,7 +225,7 @@ static struct llama_model * llava_init(common_params * params) {
 
     llama_model_params model_params = common_model_params_to_llama(*params);
 
-    llama_model * model = llama_model_load_from_file(params->model.c_str(), model_params);
+    llama_model * model = llama_model_load_from_file(params->model.path.c_str(), model_params);
     if (model == NULL) {
         LOG_ERR("%s: unable to load model\n" , __func__);
         return NULL;
@@ -234,7 +234,7 @@ static struct llava_context * llava_init_context(common_params * params, llama_model * model) {
-    const char * clip_path = params->mmproj.c_str();
+    const char * clip_path = params->mmproj.path.c_str();
 
     auto prompt = params->prompt;
     if (prompt.empty()) {
@@ -283,7 +283,7 @@ int main(int argc, char ** argv) {
 
     common_init();
 
-    if (params.mmproj.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) {
+    if (params.mmproj.path.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) {
         print_usage(argc, argv);
         return 1;
     }
diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp
index 12f536cf5..48fddeaa4 100644
--- a/examples/llava/minicpmv-cli.cpp
+++ b/examples/llava/minicpmv-cli.cpp
@@ -31,7 +31,7 @@ static struct llama_model * llava_init(common_params * params) {
 
     llama_model_params model_params = common_model_params_to_llama(*params);
 
-    llama_model * model = llama_model_load_from_file(params->model.c_str(), model_params);
+    llama_model * model = llama_model_load_from_file(params->model.path.c_str(), model_params);
     if (model == NULL) {
         LOG_ERR("%s: unable to load model\n" , __func__);
         return NULL;
@@ -80,7 +80,7 @@ static void llava_free(struct llava_context * ctx_llava) {
 }
 
 static struct clip_ctx * clip_init_context(common_params * params) {
-    const char * clip_path = params->mmproj.c_str();
+    const char * clip_path = params->mmproj.path.c_str();
 
     auto prompt = params->prompt;
     if (prompt.empty()) {
@@ -290,7 +290,7 @@ int main(int argc, char ** argv) {
 
     common_init();
 
-    if (params.mmproj.empty() || (params.image.empty())) {
+    if (params.mmproj.path.empty() || (params.image.empty())) {
         show_additional_info(argc, argv);
         return 1;
     }
diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp
index 132a7da54..c6481e482 100644
--- a/examples/llava/qwen2vl-cli.cpp
+++ b/examples/llava/qwen2vl-cli.cpp
@@ -314,7 +314,7 @@ static struct llama_model * llava_init(common_params * params) {
 
     llama_model_params model_params = common_model_params_to_llama(*params);
 
-    llama_model * model = llama_model_load_from_file(params->model.c_str(), model_params);
+    llama_model * model = llama_model_load_from_file(params->model.path.c_str(), model_params);
     if (model == NULL) {
         LOG_ERR("%s: unable to load model\n" , __func__);
         return NULL;
@@ -323,7 +323,7 @@ static struct llava_context * llava_init_context(common_params * params, llama_model * model) {
-    const char * clip_path = params->mmproj.c_str();
+    const char * clip_path = params->mmproj.path.c_str();
 
     auto prompt = params->prompt;
     if (prompt.empty()) {
@@ -524,7 +524,7 @@ int main(int argc, char ** argv) {
 
     common_init();
 
-    if (params.mmproj.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) {
+    if (params.mmproj.path.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) {
         print_usage(argc, argv);
         return 1;
     }
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 588632f04..e0e6da631 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -405,7 +405,7 @@ int main(int argc, char ** argv) {
         params.prompt_file = "used built-in defaults";
     }
     LOG_INF("External prompt file: \033[32m%s\033[0m\n", params.prompt_file.c_str());
-    LOG_INF("Model and path used: \033[32m%s\033[0m\n\n", params.model.c_str());
+    LOG_INF("Model and path used: \033[32m%s\033[0m\n\n", params.model.path.c_str());
 
     LOG_INF("Total prompt tokens: %6d, speed: %5.2f t/s\n", n_total_prompt, (double) (n_total_prompt ) / (t_main_end - t_main_start) * 1e6);
     LOG_INF("Total gen tokens:    %6d, speed: %5.2f t/s\n", n_total_gen,    (double) (n_total_gen    ) / (t_main_end - t_main_start) * 1e6);
diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp
index ea3a6c1fc..347ea4a69 100644
--- a/examples/passkey/passkey.cpp
+++ b/examples/passkey/passkey.cpp
@@ -64,7 +64,7 @@ int main(int argc, char ** argv) {
 
     llama_model_params model_params = common_model_params_to_llama(params);
 
-    llama_model * model = llama_model_load_from_file(params.model.c_str(), model_params);
+    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
 
     if (model == NULL) {
         LOG_ERR("%s: unable to load model\n" , __func__);
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 17a292da1..d140f8c44 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1876,7 +1876,7 @@ struct server_context {
     }
 
     bool load_model(const common_params & params) {
-        SRV_INF("loading model '%s'\n", params.model.c_str());
+        SRV_INF("loading model '%s'\n", params.model.path.c_str());
 
         params_base = params;
 
@@ -1886,7 +1886,7 @@ struct server_context {
         ctx = llama_init.context.get();
 
         if (model == nullptr) {
-            SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
+            SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
             return false;
         }
 
@@ -1897,16 +1897,13 @@ struct server_context {
         add_bos_token = llama_vocab_get_add_bos(vocab);
         has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
 
-        if (!params_base.speculative.model.empty() || !params_base.speculative.hf_repo.empty()) {
-            SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
+        if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) {
+            SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
 
             auto params_dft = params_base;
 
             params_dft.devices = params_base.speculative.devices;
-            params_dft.hf_file = params_base.speculative.hf_file;
-            params_dft.hf_repo = params_base.speculative.hf_repo;
             params_dft.model = params_base.speculative.model;
-            params_dft.model_url = params_base.speculative.model_url;
             params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
             params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
             params_dft.n_parallel = 1;
@@ -1920,12 +1917,12 @@ struct server_context {
         model_dft = llama_init_dft.model.get();
 
         if (model_dft == nullptr) {
-            SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str());
+            SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str());
             return false;
         }
 
         if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) {
-            SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str());
+            SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
             return false;
         }
 
@@ -3865,7 +3862,7 @@ int main(int argc, char ** argv) {
         json data = {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots", ctx_server.params_base.n_parallel },
-            { "model_path", ctx_server.params_base.model },
+            { "model_path", ctx_server.params_base.model.path },
             { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
             { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
             { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
@@ -4131,7 +4128,7 @@ int main(int argc, char ** argv) {
             {"object", "list"},
             {"data", {
                 {
-                    {"id", params.model_alias.empty() ? params.model : params.model_alias},
+                    {"id", params.model_alias.empty() ? params.model.path : params.model_alias},
                     {"object", "model"},
                     {"created", std::time(0)},
                     {"owned_by", "llamacpp"},
diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp
index a5d2bc9d0..0783ed4a4 100644
--- a/examples/speculative-simple/speculative-simple.cpp
+++ b/examples/speculative-simple/speculative-simple.cpp
@@ -24,7 +24,7 @@ int main(int argc, char ** argv) {
 
     common_init();
 
-    if (params.speculative.model.empty()) {
+    if (params.speculative.model.path.empty()) {
         LOG_ERR("%s: --model-draft is required\n", __func__);
         return 1;
     }
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 627d01bbc..561c30883 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -46,7 +46,7 @@ int main(int argc, char ** argv) {
 
     common_init();
 
-    if (params.speculative.model.empty()) {
+    if (params.speculative.model.path.empty()) {
         LOG_ERR("%s: --model-draft is required\n", __func__);
         return 1;
     }
diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp
index c7ac94cc5..0f0479869 100644
--- a/examples/tts/tts.cpp
+++ b/examples/tts/tts.cpp
@@ -577,12 +577,7 @@ int main(int argc, char ** argv) {
 
     const llama_vocab * vocab = llama_model_get_vocab(model_ttc);
 
-    // TODO: refactor in a common struct
-    params.model = params.vocoder.model;
-    params.model_url = params.vocoder.model_url;
-    params.hf_repo = params.vocoder.hf_repo;
-    params.hf_file = params.vocoder.hf_file;
-
+    params.model = params.vocoder.model;
     params.embedding = true;
 
     common_init_result llama_init_cts = common_init_from_params(params);
diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp
index 69604b87c..537fc63a4 100644
--- a/tests/test-arg-parser.cpp
+++ b/tests/test-arg-parser.cpp
@@ -77,7 +77,7 @@ int main(void) {
 
     argv = {"binary_name", "-m", "model_file.gguf"};
     assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
-    assert(params.model == "model_file.gguf");
+    assert(params.model.path == "model_file.gguf");
 
     argv = {"binary_name", "-t", "1234"};
     assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
@@ -89,7 +89,7 @@ int main(void) {
 
     argv = {"binary_name", "-m", "abc.gguf", "--predict", "6789", "--batch-size", "9090"};
     assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
-    assert(params.model == "abc.gguf");
+    assert(params.model.path == "abc.gguf");
     assert(params.n_predict == 6789);
     assert(params.n_batch == 9090);
 
@@ -112,7 +112,7 @@ int main(void) {
     setenv("LLAMA_ARG_THREADS", "1010", true);
     argv = {"binary_name"};
     assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
-    assert(params.model == "blah.gguf");
+    assert(params.model.path == "blah.gguf");
     assert(params.cpuparams.n_threads == 1010);
 
@@ -122,7 +122,7 @@ int main(void) {
     setenv("LLAMA_ARG_THREADS", "1010", true);
     argv = {"binary_name", "-m", "overwritten.gguf"};
     assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
-    assert(params.model == "overwritten.gguf");
+    assert(params.model.path == "overwritten.gguf");
     assert(params.cpuparams.n_threads == 1010);
 #endif // _WIN32
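
Not part of the patch: a minimal sketch of the call pattern the refactor produces, using only names that appear in the diff above (`common_params`, `common_params_model`, `params.model.path`, `common_model_params_to_llama`, `llama_model_load_from_file`). Model locations now live in the nested `common_params_model` struct, so examples read `params.model.path` (and `params.mmproj.path`, `params.speculative.model.path`) instead of flat string fields.

```cpp
#include "common.h"
#include "llama.h"

// Illustrative only: mirrors how the updated examples load a model after this refactor.
static llama_model * load_from_params(common_params & params) {
    // the local file path is now params.model.path (a common_params_model field);
    // params.model.hf_repo / hf_file / url describe where to download it from
    llama_model_params mparams = common_model_params_to_llama(params);
    return llama_model_load_from_file(params.model.path.c_str(), mparams);
}
```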