Mirror of https://github.com/ggerganov/llama.cpp.git, last synced 2025-04-26 08:36:07 +00:00.
vulkan: Use fp16 for the flash attention P*V multiplication (#12783)
This is consistent with the ggml-cuda behavior and the mul_mat fallback.
This commit is contained in:
parent
7538246e7c
commit
7ecd780b1a
@@ -330,9 +330,11 @@ void main() {
         // resize eM by using smear/reduce
         coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);

-        O = eMdiag * O;
-
-        O = coopMatMulAdd(P_A, V, O);
+        // multiply with fp16 accumulation, then add to O.
+        coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+        PV = coopMatMulAdd(P_A, V, PV);
+
+        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
     }

 // If there is split_k, then the split_k resolve shader does the final
Loading…
x
Reference in New Issue
Block a user