mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-04-14 10:36:07 +00:00
llama : fix FA when KV cache is not used (i.e. embeddings) (#12825)
* ggml : FA supports F32 V * graph : cast KV to F16 when the KV cache is not used ggml-ci * server : add test that exercises embeddings with FA enabled ggml-ci
This commit is contained in:
parent
78a1ba0a4f
commit
a19b5cef16
@ -49,6 +49,26 @@ def test_embedding_multiple():
|
||||
assert len(d['embedding']) > 1
|
||||
|
||||
|
||||
def test_embedding_multiple_with_fa():
|
||||
server = ServerPreset.bert_bge_small_with_fa()
|
||||
server.pooling = 'last'
|
||||
server.start()
|
||||
# one of these should trigger the FA branch (i.e. context size % 256 == 0)
|
||||
res = server.make_request("POST", "/v1/embeddings", data={
|
||||
"input": [
|
||||
"a "*253,
|
||||
"b "*254,
|
||||
"c "*255,
|
||||
"d "*256,
|
||||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body['data']) == 4
|
||||
for d in res.body['data']:
|
||||
assert 'embedding' in d
|
||||
assert len(d['embedding']) > 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input,is_multi_prompt",
|
||||
[
|
||||
|
@ -323,6 +323,21 @@ class ServerPreset:
|
||||
server.server_embeddings = True
|
||||
return server
|
||||
|
||||
@staticmethod
|
||||
def bert_bge_small_with_fa() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.model_hf_repo = "ggml-org/models"
|
||||
server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
|
||||
server.model_alias = "bert-bge-small"
|
||||
server.n_ctx = 1024
|
||||
server.n_batch = 300
|
||||
server.n_ubatch = 300
|
||||
server.n_slots = 2
|
||||
server.fa = True
|
||||
server.seed = 42
|
||||
server.server_embeddings = True
|
||||
return server
|
||||
|
||||
@staticmethod
|
||||
def tinyllama_infill() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
|
@ -15,7 +15,7 @@ async def main():
|
||||
model_url = "http://127.0.0.1:6900"
|
||||
responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
|
||||
url= f"{model_url}/embedding",
|
||||
json= {"content": str(0)*1024}
|
||||
json= {"content": "a "*1022}
|
||||
) for i in range(n)])
|
||||
|
||||
for response in responses:
|
||||
|
@ -6721,8 +6721,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
|
||||
ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
|
||||
|
||||
GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
|
||||
GGML_ASSERT(v_to_float && "fattn: unsupported V-type");
|
||||
GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
|
||||
GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
|
||||
|
||||
// loop over n_batch and n_head
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
@ -6818,10 +6818,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
vs = expf(s - M);
|
||||
}
|
||||
|
||||
v_to_float(v_data, V32, DV);
|
||||
|
||||
// V += v*expf(s - M)
|
||||
ggml_vec_mad_f32(DV, VKQ32, V32, vs);
|
||||
if (v_to_float) {
|
||||
v_to_float(v_data, V32, DV);
|
||||
ggml_vec_mad_f32(DV, VKQ32, V32, vs);
|
||||
} else {
|
||||
// V is F32
|
||||
ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
|
||||
}
|
||||
}
|
||||
|
||||
S = S*ms + vs; // scale and increment sum with partial sum
|
||||
|
@ -1345,6 +1345,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_OP_ARANGE:
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
if (op->src[0]->ne[0] == 32) {
|
||||
// head size == 32 (e.g. bert-bge-small)
|
||||
// TODO: not sure if it is worth adding kernels for this size
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
return false;
|
||||
}
|
||||
|
@ -1215,6 +1215,15 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
v = ggml_transpose(ctx0, v);
|
||||
}
|
||||
|
||||
// this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
|
||||
if (k->type == GGML_TYPE_F32) {
|
||||
k = ggml_cast(ctx0, k, GGML_TYPE_F16);
|
||||
}
|
||||
|
||||
if (v->type == GGML_TYPE_F32) {
|
||||
v = ggml_cast(ctx0, v, GGML_TYPE_F16);
|
||||
}
|
||||
|
||||
cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
|
||||
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user