upgrade to llguidance 0.7.10 (#12576)

This commit is contained in:
Michał Moskal 2025-03-26 11:06:09 -07:00 committed by GitHub
parent 02082f1519
commit 2447ad8a98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 94 additions and 49 deletions

View File

@ -114,8 +114,8 @@ if (LLAMA_LLGUIDANCE)
ExternalProject_Add(llguidance_ext
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
# v0.6.12:
GIT_TAG ced1c9023d47ec194fa977932d35ce65c2ebfc09
# v0.7.10:
GIT_TAG 0309d2a6bf40abda35344a362edc71e06d5009f8
PREFIX ${CMAKE_BINARY_DIR}/llguidance
SOURCE_DIR ${LLGUIDANCE_SRC}
BUILD_IN_SOURCE TRUE

View File

@ -11,25 +11,24 @@ struct llama_sampler_llg {
std::string grammar_kind;
std::string grammar_data;
LlgTokenizer * tokenizer;
LlgConstraint * grammar;
LlgMaskResult llg_res;
bool has_llg_res;
LlgMatcher * grammar;
};
static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
const char * grammar_data) {
static LlgMatcher * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
const char * grammar_data) {
LlgConstraintInit cinit;
llg_constraint_init_set_defaults(&cinit, tokenizer);
const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
if (log_level && *log_level) {
cinit.log_stderr_level = atoi(log_level);
}
auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
if (llg_get_error(c)) {
LOG_ERR("llg error: %s\n", llg_get_error(c));
llg_free_constraint(c);
auto c = llg_new_matcher(&cinit, grammar_kind, grammar_data);
if (llg_matcher_get_error(c)) {
LOG_ERR("llg error: %s\n", llg_matcher_get_error(c));
llg_free_matcher(c);
return nullptr;
}
return c;
}
@ -40,39 +39,29 @@ static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (ctx->grammar) {
LlgCommitResult res;
llg_commit_token(ctx->grammar, token, &res);
ctx->has_llg_res = false;
llg_matcher_consume_token(ctx->grammar, token);
}
}
static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (ctx->grammar) {
if (!ctx->has_llg_res) {
if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
ctx->has_llg_res = true;
const uint32_t * mask = llg_matcher_get_mask(ctx->grammar);
if (mask == nullptr) {
if (llg_matcher_compute_mask(ctx->grammar) == 0) {
mask = llg_matcher_get_mask(ctx->grammar);
} else {
LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar));
llg_free_constraint(ctx->grammar);
LOG_ERR("llg error: %s\n", llg_matcher_get_error(ctx->grammar));
llg_free_matcher(ctx->grammar);
ctx->grammar = nullptr;
return;
}
}
if (ctx->has_llg_res) {
if (ctx->llg_res.is_stop) {
for (size_t i = 0; i < cur_p->size; ++i) {
if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) {
cur_p->data[i].logit = -INFINITY;
}
}
} else {
const uint32_t * mask = ctx->llg_res.sample_mask;
for (size_t i = 0; i < cur_p->size; ++i) {
auto token = cur_p->data[i].id;
if ((mask[token / 32] & (1 << (token % 32))) == 0) {
cur_p->data[i].logit = -INFINITY;
}
}
for (size_t i = 0; i < cur_p->size; ++i) {
auto token = cur_p->data[i].id;
if ((mask[token / 32] & (1 << (token % 32))) == 0) {
cur_p->data[i].logit = -INFINITY;
}
}
}
@ -80,14 +69,9 @@ static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array
static void llama_sampler_llg_reset(llama_sampler * smpl) {
auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (!ctx->grammar) {
return;
if (ctx->grammar) {
llg_matcher_reset(ctx->grammar);
}
auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
llg_free_constraint(ctx->grammar);
ctx->grammar = grammar_new;
ctx->has_llg_res = false;
}
static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
@ -102,7 +86,7 @@ static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
if (ctx->grammar) {
result_ctx->grammar_kind = ctx->grammar_kind;
result_ctx->grammar_data = ctx->grammar_data;
result_ctx->grammar = llg_clone_constraint(ctx->grammar);
result_ctx->grammar = llg_clone_matcher(ctx->grammar);
result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
}
}
@ -114,7 +98,7 @@ static void llama_sampler_llg_free(llama_sampler * smpl) {
const auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (ctx->grammar) {
llg_free_constraint(ctx->grammar);
llg_free_matcher(ctx->grammar);
llg_free_tokenizer(ctx->tokenizer);
}
@ -239,9 +223,11 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g
/* .grammar_data = */ grammar_data,
/* .tokenizer = */ tokenizer,
/* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
/* .llg_res = */ {},
/* .has_llg_res = */ false,
};
if (ctx->grammar) {
GGML_ASSERT(((size_t) llama_vocab_n_tokens(vocab) + 31) / 32 * 4 ==
llg_matcher_get_mask_byte_size(ctx->grammar));
}
} else {
*ctx = {
/* .vocab = */ vocab,
@ -249,15 +235,12 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g
/* .grammar_data = */ {},
/* .tokenizer = */ nullptr,
/* .grammar = */ nullptr,
/* .llg_res = */ {},
/* .has_llg_res = */ false,
};
}
return llama_sampler_init(
/* .iface = */ &llama_sampler_llg_i,
/* .ctx = */ ctx
);
/* .ctx = */ ctx);
}
#else

View File

@ -1086,6 +1086,65 @@ static void test_json_schema() {
});
}
static void one_hot(llama_token_data_array & tok_arr, llama_token selected) {
auto n_vocab = tok_arr.size;
tok_arr.selected = -1;
tok_arr.sorted = false;
for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
tok_arr.data[token_id].id = token_id;
tok_arr.data[token_id].logit = 0.0f;
}
tok_arr.data[selected].logit = 100.0f;
}
static void test_sampler_chain(void) {
auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = false;
llama_sampler * sampler = llama_sampler_chain_init(sparams);
const auto grammar_data = R"(%llguidance {}
start: /[A-Z ]*/)";
llama_sampler_chain_add(sampler, llama_sampler_init_llg(vocab, "lark", grammar_data));
llama_sampler_chain_add(sampler, llama_sampler_init_dist(42));
auto input = "ALL YOUR BASE ARE BELONG TO US";
auto tokens = common_tokenize(vocab, input, false, false);
auto n_vocab = llama_vocab_n_tokens(vocab);
std::vector<llama_token_data> cur;
cur.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
}
auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };
for (const auto token : tokens) {
one_hot(tok_arr, token);
fprintf(stderr, "applying token: %d\n", token);
llama_sampler_apply(sampler, &tok_arr);
auto idx = tok_arr.selected;
fprintf(stderr, " -> %d %f\n", cur[idx].id, cur[idx].logit);
assert(cur[tok_arr.selected].id == token);
llama_sampler_accept(sampler, token);
}
auto tok_eos = llama_vocab_eot(vocab);
if (tok_eos == LLAMA_TOKEN_NULL) {
tok_eos = llama_vocab_eos(vocab);
}
one_hot(tok_arr, tok_eos);
llama_sampler_apply(sampler, &tok_arr);
assert(cur[tok_arr.selected].id == tok_eos);
}
int main(int argc, const char ** argv) {
fprintf(stdout, "Running llguidance integration tests...\n");
@ -1135,6 +1194,9 @@ int main(int argc, const char ** argv) {
test_special_chars();
test_quantifiers();
test_json_schema();
test_sampler_chain();
fprintf(stdout, "All tests passed.\n");
return 0;
}