// nem1 must be a multiple of GGML_KQ_MASK_PAD, and GGML_KQ_MASK_PAD is a
// multiple of the number of rows in the matrix. The KV dim is a multiple
// of the number of columns for the aligned shader.
#version 450

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require

#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_buffer_reference : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_vote : enable
#extension GL_EXT_null_initializer : enable

#include "types.comp"
#include "dequant_funcs_cm2.comp"

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
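
// Tile shape comes in as specialization constants: Br x Bc is the block of
// Q rows x KV columns processed per loop iteration, and D is the head size.
// Clamp selects out-of-bounds behavior for the K/V tensor loads: the
// unaligned shader variant clamps (gl_CooperativeMatrixClampModeConstantNV),
// while the aligned variant can assume accesses stay in bounds.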

layout (constant_id = 1) const uint32_t Br = 32;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t D = 32;
layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;
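
// Push constants mirror the ggml tensor shapes: N is the number of Q rows,
// KV the number of key/value positions, ne*/nb* the sizes and strides of
// the q/k/v/mask tensors, and the remaining fields drive the scaling,
// ALiBi, softcap, mask and split_k paths below.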

layout (push_constant) uniform parameter {
    uint32_t N;
    uint32_t KV;

    uint32_t ne1;
    uint32_t ne2;
    uint32_t ne3;

    uint32_t neq2;
    uint32_t neq3;
    uint32_t nek2;
    uint32_t nek3;
    uint32_t nev2;
    uint32_t nev3;
    uint32_t nem1;

    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb03;
    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;
    uint32_t nb21;
    uint32_t nb22;
    uint32_t nb23;
    uint32_t nb31;

    float scale;
    float max_bias;
    float logit_softcap;

    uint32_t mask;
    uint32_t n_head_log2;
    float m0;
    float m1;

    uint32_t gqa_ratio;
    uint32_t split_kv;
    uint32_t k_num;
} p;

layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
layout (binding = 2) readonly buffer V {uint8_t data_v[];};
layout (binding = 3) readonly buffer M {uint8_t data_m[];};
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
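
// Reduction callbacks for coopMatReduceNV. maxReduce computes a row max;
// smearReduce just keeps one operand, so a row "reduce" replicates a
// per-row value across all columns, which is used below to resize
// Br x Bc accumulators to Br x D.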

ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
    return max(x, y);
}

ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
    return x;
}

// Replace matrix elements >= numRows or numCols with 'replace'
ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) {
    if (row >= numRows || col >= numCols) {
        return replace;
    }
    return elem;
}

ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem)
{
    return exp(elem);
}

ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1)
{
    return max(elem0, elem1);
}

#if defined(BLOCK_SIZE)
#define DECODEFUNC , DEQUANTFUNC
#else
#define DECODEFUNC
#endif
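
// When K/V are quantized (BLOCK_SIZE is defined), the tensor loads take an
// extra dequantization callback from dequant_funcs_cm2.comp; DECODEFUNC
// splices it into the coopMatLoadTensorNV argument lists below.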

// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    if (r < N && c < D) {
        uint32_t offset = (iq2 + r) * D + c;
        data_o[o_offset + offset] = D_TYPE(elem);
    }
    return elem;
}

// Store column zero. This is used to save per-row m and L values for split_k.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    if (r < N && c == 0) {
        uint32_t offset = iq2 + r;
        data_o[o_offset + offset] = D_TYPE(elem);
    }
    return elem;
}

// Load the slope matrix, indexed by Q's dimension 2.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
    const uint32_t h = iq2 + (r & (p.gqa_ratio - 1));

    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
    const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);

    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
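
// The schedule above is ggml's usual ALiBi slope computation: head h uses
// m0^(h+1) below n_head_log2 and m1^(2*(h - n_head_log2) + 1) above it.
// Note that r & (gqa_ratio - 1) relies on gqa_ratio being a power of two.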

void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

    const uint32_t N = p.N;
    const uint32_t KV = p.KV;

    uint32_t i = gl_WorkGroupID.x;
    uint32_t split_k_index = 0;

    if (p.k_num > 1) {
        i = 0;
        split_k_index = gl_WorkGroupID.x;
    }
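
    // With split_k (k_num > 1) every workgroup covers the same row block but
    // a different slice of the KV range, so gl_WorkGroupID.x picks the slice
    // instead of the row block.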

    const uint32_t Tr = CEIL_DIV(N, Br);

    const uint32_t start_j = split_k_index * p.split_kv / Bc;
    const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);

    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
    const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
    const uint32_t iq3 = gl_WorkGroupID.z;

    // broadcast factors
    const uint32_t rk2 = p.neq2/p.nek2;
    const uint32_t rk3 = p.neq3/p.nek3;

    const uint32_t rv2 = p.neq2/p.nev2;
    const uint32_t rv3 = p.neq3/p.nev3;

    // k indices
    const uint32_t ik3 = iq3 / rk3;
    const uint32_t ik2 = iq2 / rk2;

    // v indices
    const uint32_t iv3 = iq3 / rv3;
    const uint32_t iv2 = iq2 / rv2;
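
    // K/V can have fewer heads than Q (grouped-query attention); rk*/rv* are
    // the broadcast ratios, and dividing the Q indices maps each Q head to
    // the K/V head it shares.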

    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
    tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp);

    tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);

#if defined(BLOCK_SIZE)
    tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
    tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif

    tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D);
    tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
    tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);

    // nb?1 are already divided by the type size and are in units of elements.
    // When using grouped query attention, Q is indexed by iq2, so the stride
    // should be nb02 (which is in bytes; Q is F32, so dividing by 4 converts
    // it to elements).
    uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
    uint32_t k_stride = p.nb11;
    uint32_t v_stride = p.nb21;
    // hint to the compiler that strides are aligned for the aligned variant of the shader
    if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
    {
        q_stride &= ~7;
#if !defined(BLOCK_SIZE)
        k_stride &= ~7;
        v_stride &= ~7;
#endif
    }
    tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
    tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
    tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);

    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
    coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;

    uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D));

    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q);
    Qf16 *= float16_t(p.scale);
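
    // Folding the softmax scale (typically 1/sqrt(D)) into Q here means
    // every Q*K^T tile below comes out pre-scaled.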

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;

    L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
    M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);
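
    // O, L and M above are the flash-attention running state: O accumulates
    // the unnormalized output, L the per-row sum of exponentials, and M the
    // per-row running max, starting at -inf so the first tile always wins.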

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);

    // ALiBi
    if (p.max_bias > 0.0f) {
        coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
    }

    [[dont_unroll]]
    for (uint32_t j = start_j; j < end_j; ++j) {

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);

        coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T;

        uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC);
        S = coopMatMulAdd(Qf16, K_T, S);

        if (p.logit_softcap != 0.0f) {
            [[unroll]]
            for (int k = 0; k < S.length(); ++k) {
                S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
            }
        }
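
        // Softcapping squashes the scores into (-logit_softcap, logit_softcap).
        // The matching 1/logit_softcap factor is expected to be folded into
        // p.scale by the host, as ggml does for its other backends.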

        if (p.mask != 0) {
            tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
            tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
            // When using grouped query attention, all rows use the same mask.
            if (p.gqa_ratio > 1) {
                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
            }

            coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;

            coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));

            S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
        }
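
        // The mask tile is added after scaling by the ALiBi slopes (all 1.0
        // when ALiBi is off); -inf mask entries knock out e.g. causal
        // positions before the softmax.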

        // Clear padding elements to -inf, so they don't contribute to rowmax
        if (Clamp != 0 &&
            ((j + 1) * Bc > KV ||
             (i + 1) * Br > N)) {

            uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
            uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;

            coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C);
        }

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;

        coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M;

        // M = max(rowmax, Mold)
        // P = e^(S - M)
        // eM = e^(Mold - M)
        coopMatPerElementNV(M, rowmax, Max, Mold);
        coopMatPerElementNV(P, S - M, Exp);
        coopMatPerElementNV(eM, Mold - M, Exp);
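
        // The three per-element ops above are the online-softmax update:
        // subtracting the new running max M before exp() keeps the
        // exponentials in range, and eM is the correction factor that
        // rescales the previous tiles' partial sums to the new max.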

        // Clear padding elements to 0, so they don't contribute to rowsum
        if (Clamp != 0 &&
            ((j + 1) * Bc > KV ||
             (i + 1) * Br > N)) {

            uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
            uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;

            coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C);
        }

        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);

        // compute rowsum by multiplying by matrix of all ones.
        coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);

        rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
        rowsum = coopMatMulAdd(P_A, One, rowsum);

        coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V;
        uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
        coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC);

        L = eM*L + rowsum;

        // This is the "diagonal" matrix in the paper, but since we do componentwise
        // multiply rather than matrix multiply it has the diagonal element smeared
        // across the row
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag;

        // resize eM by using smear/reduce
        coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);

        O = eMdiag * O;

        O = coopMatMulAdd(P_A, V, O);
    }
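
    // After the loop, O holds the sum over KV tiles of e^(S - M) * V and L
    // the matching per-row softmax denominators; the final result is O / L.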

    // If there is split_k, then the split_k resolve shader does the final
    // division by L. Store the intermediate O value and per-row m and L values.
    if (p.k_num > 1) {
        coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);

        uint32_t o_offset = D * p.ne1 * split_k_index;
        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);

        o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
        coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
        coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
        return;
    }
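
    // Split_k scratch layout: k_num partial O blocks of D * ne1 elements,
    // followed by per-split pairs of L and M row vectors (ne1 each) for the
    // resolve shader to combine.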

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;

    // resize L by using smear/reduce
    coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);

    [[unroll]]
    for (int k = 0; k < Ldiag.length(); ++k) {
        Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
    }

    O = Ldiag*O;

    // o_offset is in D_TYPE elements; each iq3 slice holds ne2*ne1 rows of D.
    uint32_t o_offset = iq3*p.ne2*p.ne1*D;

    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
    if (p.gqa_ratio > 1) {
        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
    } else {
        tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
        tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);

        // permute dimensions
        tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);

        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
    }
}