Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-04-16 03:26:08 +00:00)
context : always use non-causal attention for encoder graphs (#12447)
* context : always use non-causal attention for encoder graphs

  ggml-ci

* context : move the change to llama_context::encode()

  ggml-ci
parent 35cae5ba05
commit 8551c44d84
@@ -1057,6 +1057,13 @@ int llama_context::encode(llama_batch & inp_batch) {
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
+    const auto causal_attn_org = cparams.causal_attn;
+
+    // always use non-causal attention for encoder graphs
+    // TODO: this is a tmp solution until we have a proper way to support enc-dec models
+    // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
+    cparams.causal_attn = false;
+
     auto * gf = graph_init();
     auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
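For readers unfamiliar with the flag being flipped in the hunk above: causal attention restricts each token to attending only to itself and earlier positions, whereas an encoder needs the full bidirectional view of the sequence. Below is a minimal, hypothetical C++ sketch of that difference in mask terms; the function and variable names are illustrative and are not the actual llama.cpp helpers.

// Illustrative sketch only (not from this commit): how a causal vs. non-causal
// attention mask typically differs. 0.0f = may attend, -INFINITY = masked out.
#include <cmath>
#include <vector>

std::vector<float> build_attn_mask_example(int n_tokens, bool causal_attn) {
    std::vector<float> mask(n_tokens * n_tokens, 0.0f);
    if (causal_attn) {
        // decoder-style: token i may only attend to positions j <= i
        for (int i = 0; i < n_tokens; ++i) {
            for (int j = i + 1; j < n_tokens; ++j) {
                mask[i*n_tokens + j] = -INFINITY;
            }
        }
    }
    // non-causal (encoder-style): the mask stays all zeros, so every token
    // can attend to the full sequence
    return mask;
}

This is why the encoder path forces cparams.causal_attn to false regardless of how the context was configured.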
@@ -1064,6 +1071,8 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     res->set_inputs(&ubatch);
 
+    cparams.causal_attn = causal_attn_org;
+
     const auto compute_status = graph_compute(gf, n_tokens > 1);
     switch (compute_status) {
         case GGML_STATUS_SUCCESS: