mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-04-20 21:46:07 +00:00
Use dot instruction to multiply two values at once, to enable dual issue instructions
This commit is contained in:
parent
039cc0745a
commit
ab27292b26
ggml/src/ggml-vulkan/vulkan-shaders
@ -98,11 +98,11 @@ layout (constant_id = 12) const uint LOAD_VEC_B_SHIFT = 0;
|
||||
#ifdef COOPMAT
|
||||
#define SHMEM_STRIDE (BK + 8)
|
||||
#else
|
||||
#define SHMEM_STRIDE (BK + 1)
|
||||
#define SHMEM_STRIDE (BK / 2 + 1)
|
||||
#endif
|
||||
|
||||
shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
|
||||
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
|
||||
shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE];
|
||||
shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[3072];
|
||||
@ -223,8 +223,8 @@ void main() {
|
||||
}
|
||||
#else
|
||||
ACC_TYPE sums[WMITER * TM * WNITER * TN];
|
||||
FLOAT_TYPE cache_a[WMITER * TM];
|
||||
FLOAT_TYPE cache_b[TN];
|
||||
FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
|
||||
FLOAT_TYPE_VEC2 cache_b[TN];
|
||||
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = ACC_TYPE(0.0f);
|
||||
@ -262,7 +262,7 @@ void main() {
|
||||
}
|
||||
}
|
||||
#else
|
||||
[[unroll]] for (uint i = 0; i < BK; i++) {
|
||||
[[unroll]] for (uint i = 0; i < BK / 2; i++) {
|
||||
// Load from shared into cache
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint j = 0; j < TM; j++) {
|
||||
@ -278,7 +278,7 @@ void main() {
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
|
||||
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]);
|
||||
sums[sums_idx] += dot(ACC_TYPE_VEC2(cache_a[wsir * TM + cr]), ACC_TYPE_VEC2(cache_b[cc]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4,32 +4,31 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
#if defined(FLOAT16) && defined(A_TYPE_VEC8)
|
||||
if (LOAD_VEC_A_SHIFT == 3) {
|
||||
const uint idx = pos_a + col * ((p.stride_a >> LOAD_VEC_A_SHIFT)) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + ((row << LOAD_VEC_A_SHIFT));
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(data_a_vec8[idx][0].x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a_vec8[idx][0].y);
|
||||
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a_vec8[idx][0].z);
|
||||
buf_a[buf_idx + 3] = FLOAT_TYPE(data_a_vec8[idx][0].w);
|
||||
buf_a[buf_idx + 4] = FLOAT_TYPE(data_a_vec8[idx][1].x);
|
||||
buf_a[buf_idx + 5] = FLOAT_TYPE(data_a_vec8[idx][1].y);
|
||||
buf_a[buf_idx + 6] = FLOAT_TYPE(data_a_vec8[idx][1].z);
|
||||
buf_a[buf_idx + 7] = FLOAT_TYPE(data_a_vec8[idx][1].w);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
|
||||
const FLOAT_TYPE_VEC8 vals = FLOAT_TYPE_VEC8(data_a_vec8[idx]);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(vals[0].xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(vals[0].zw);
|
||||
buf_a[buf_idx + 2] = FLOAT_TYPE_VEC2(vals[1].xy);
|
||||
buf_a[buf_idx + 3] = FLOAT_TYPE_VEC2(vals[1].zw);
|
||||
} else
|
||||
#endif
|
||||
if (LOAD_VEC_A_SHIFT == 2) {
|
||||
const uint idx = pos_a + col * ((p.stride_a >> LOAD_VEC_A_SHIFT)) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + ((row << LOAD_VEC_A_SHIFT));
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(data_a_vec4[idx].x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a_vec4[idx].y);
|
||||
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a_vec4[idx].z);
|
||||
buf_a[buf_idx + 3] = FLOAT_TYPE(data_a_vec4[idx].w);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
|
||||
const FLOAT_TYPE_VEC4 vals = FLOAT_TYPE_VEC4(data_a_vec4[idx]);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(vals.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(vals.zw);
|
||||
} else if (idx_m < p.M && idx_k + 1 < end_k) {
|
||||
buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(data_a[pos_a + col * p.stride_a + row ],
|
||||
data_a[pos_a + col * p.stride_a + row + 1]);
|
||||
} else if (idx_m < p.M && idx_k < end_k) {
|
||||
buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE(data_a[pos_a + col * p.stride_a + row]);
|
||||
buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(data_a[pos_a + col * p.stride_a + row], 0.0f);
|
||||
} else {
|
||||
buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE(0.0f);
|
||||
buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(0.0f);
|
||||
}
|
||||
#elif defined(DATA_A_Q4_0)
|
||||
const uint idx = pos_a + col * ((p.stride_a >> LOAD_VEC_A_SHIFT)) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + 4 * row;
|
||||
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (4 * row) / 2;
|
||||
|
||||
const uint ib = idx / 4;
|
||||
const uint iqs = idx & 0x03;
|
||||
@ -39,17 +38,13 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
|
||||
const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
|
||||
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
|
||||
buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
|
||||
buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
|
||||
buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
|
||||
buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
|
||||
buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
|
||||
buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy);
|
||||
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
|
||||
#elif defined(DATA_A_Q4_1)
|
||||
const uint idx = pos_a + col * ((p.stride_a >> LOAD_VEC_A_SHIFT)) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + 4 * row;
|
||||
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (4 * row) / 2;
|
||||
|
||||
const uint ib = idx / 4;
|
||||
const uint iqs = idx & 0x03;
|
||||
@ -60,17 +55,13 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
|
||||
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
|
||||
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
|
||||
buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
|
||||
buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
|
||||
buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
|
||||
buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
|
||||
buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
|
||||
buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy);
|
||||
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
const uint idx = pos_a + col * ((p.stride_a >> LOAD_VEC_A_SHIFT)) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
|
||||
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (2 * row) / 2;
|
||||
|
||||
const uint ib = idx / 8;
|
||||
const uint iqs = idx & 0x07;
|
||||
@ -83,13 +74,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
|
||||
const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
|
||||
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (2 * row) / 2;
|
||||
|
||||
const uint ib = idx / 8;
|
||||
const uint iqs = idx & 0x07;
|
||||
@ -103,13 +92,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
|
||||
const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
|
||||
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
|
||||
#elif defined(DATA_A_Q8_0)
|
||||
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
|
||||
|
||||
const uint ib = idx / 8;
|
||||
const uint iqs = idx & 0x07;
|
||||
@ -119,13 +106,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
|
||||
const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx + 2] = FLOAT_TYPE(v.z);
|
||||
buf_a[buf_idx + 3] = FLOAT_TYPE(v.w);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
|
||||
#elif defined(DATA_A_Q2_K)
|
||||
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = idx % 128; // 0..127
|
||||
@ -140,11 +125,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
|
||||
const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
|
||||
#elif defined(DATA_A_Q3_K)
|
||||
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = idx % 128; // 0..127
|
||||
@ -162,11 +146,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
| (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
|
||||
const float dl = float(data_a[ib].d) * float(us - 32);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)));
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)),
|
||||
dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
|
||||
#elif defined(DATA_A_Q4_K)
|
||||
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = idx % 128; // 0..127
|
||||
@ -195,11 +179,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const float d = loadd.x * sc;
|
||||
const float m = -loadd.y * mbyte;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m));
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m),
|
||||
fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
|
||||
#elif defined(DATA_A_Q5_K)
|
||||
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = idx % 128; // 0..127
|
||||
@ -231,11 +215,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const float d = loadd.x * sc;
|
||||
const float m = -loadd.y * mbyte;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m));
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m),
|
||||
fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
|
||||
#elif defined(DATA_A_Q6_K)
|
||||
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = idx % 128; // 0..127
|
||||
@ -250,11 +234,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
|
||||
const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32),
|
||||
dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
|
||||
#elif defined(DATA_A_IQ1_S)
|
||||
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||
@ -274,11 +258,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
);
|
||||
const vec2 v = dl * (vec2(gvec) + delta);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
|
||||
#elif defined(DATA_A_IQ1_M)
|
||||
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib8 = (idx % 128) / 4;
|
||||
@ -300,11 +283,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
);
|
||||
const vec2 v = dl * (vec2(gvec) + delta);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
|
||||
#elif defined(DATA_A_IQ2_XXS)
|
||||
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||
@ -325,11 +307,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
|
||||
#elif defined(DATA_A_IQ2_XS)
|
||||
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||
@ -345,11 +326,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
|
||||
#elif defined(DATA_A_IQ2_S)
|
||||
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib8 = (idx % 128) / 4; // 0..31
|
||||
@ -367,11 +347,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
|
||||
#elif defined(DATA_A_IQ3_XXS)
|
||||
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = (idx % 128) / 2; // 0..63
|
||||
@ -392,11 +371,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
|
||||
#elif defined(DATA_A_IQ3_S)
|
||||
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = (idx % 128) / 2; // 0..63
|
||||
@ -412,11 +390,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
|
||||
#elif defined(DATA_A_IQ4_XS)
|
||||
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||
@ -431,11 +408,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const float d = float(data_a[ib].d);
|
||||
const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (2 * row) / 2;
|
||||
|
||||
const uint ib = idx / 8;
|
||||
const uint iqs = idx & 0x07;
|
||||
@ -443,10 +419,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
|
||||
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d;
|
||||
buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
|
||||
buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
|
||||
buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF] * d,
|
||||
kvalues_iq4nl[bitfieldExtract(vui, 8, 4)] * d);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)] * d,
|
||||
kvalues_iq4nl[vui >> 12] * d);
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -455,28 +431,27 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
|
||||
#if defined(B_TYPE_VEC8)
|
||||
if (LOAD_VEC_B_SHIFT == 3) {
|
||||
const uint idx = pos_b + col * (p.stride_b >> LOAD_VEC_B_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT);
|
||||
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b_vec8[idx][0].x);
|
||||
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b_vec8[idx][0].y);
|
||||
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b_vec8[idx][0].z);
|
||||
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b_vec8[idx][0].w);
|
||||
buf_b[buf_idx + 4] = FLOAT_TYPE(data_b_vec8[idx][1].x);
|
||||
buf_b[buf_idx + 5] = FLOAT_TYPE(data_b_vec8[idx][1].y);
|
||||
buf_b[buf_idx + 6] = FLOAT_TYPE(data_b_vec8[idx][1].z);
|
||||
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b_vec8[idx][1].w);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT) / 2;
|
||||
const FLOAT_TYPE_VEC8 vals = FLOAT_TYPE_VEC8(data_b_vec8[idx]);
|
||||
buf_b[buf_idx + 0] = FLOAT_TYPE_VEC2(vals[0].xy);
|
||||
buf_b[buf_idx + 1] = FLOAT_TYPE_VEC2(vals[0].zw);
|
||||
buf_b[buf_idx + 2] = FLOAT_TYPE_VEC2(vals[1].xy);
|
||||
buf_b[buf_idx + 3] = FLOAT_TYPE_VEC2(vals[1].zw);
|
||||
} else
|
||||
#endif
|
||||
if (LOAD_VEC_B_SHIFT == 2) {
|
||||
const uint idx = pos_b + col * (p.stride_b >> LOAD_VEC_B_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT);
|
||||
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b_vec4[idx].x);
|
||||
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b_vec4[idx].y);
|
||||
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b_vec4[idx].z);
|
||||
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b_vec4[idx].w);
|
||||
} else if (idx_n < p.N && idx_k < end_k) {
|
||||
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE(data_b[pos_b + col * p.stride_b + row]);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT) / 2;
|
||||
const FLOAT_TYPE_VEC4 vals = FLOAT_TYPE_VEC4(data_b_vec4[idx]);
|
||||
buf_b[buf_idx + 0] = FLOAT_TYPE_VEC2(vals.xy);
|
||||
buf_b[buf_idx + 1] = FLOAT_TYPE_VEC2(vals.zw);
|
||||
} else if (idx_n < p.N && idx_k + 1 < end_k) {
|
||||
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(data_b[pos_b + col * p.stride_b + row ],
|
||||
data_b[pos_b + col * p.stride_b + row + 1]);
|
||||
} else if (idx_n < p.N && idx_k + 1 < end_k) {
|
||||
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(data_b[pos_b + col * p.stride_b + row], 0.0f);
|
||||
} else {
|
||||
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE(0.0f);
|
||||
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(0.0f);
|
||||
}
|
||||
}
|
||||
#else
|
||||
@ -485,32 +460,34 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
|
||||
if (LOAD_VEC_B_SHIFT == 3) {
|
||||
const u16vec2 row_idx = row_ids[ic * BN + col];
|
||||
const uint idx = pos_b + row_idx.y * (p.batch_stride_b >> LOAD_VEC_B_SHIFT) + (row_idx.x % p.ne11) * (p.stride_b >> LOAD_VEC_B_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT);
|
||||
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b_vec8[idx][0].x);
|
||||
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b_vec8[idx][0].y);
|
||||
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b_vec8[idx][0].z);
|
||||
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b_vec8[idx][0].w);
|
||||
buf_b[buf_idx + 4] = FLOAT_TYPE(data_b_vec8[idx][1].x);
|
||||
buf_b[buf_idx + 5] = FLOAT_TYPE(data_b_vec8[idx][1].y);
|
||||
buf_b[buf_idx + 6] = FLOAT_TYPE(data_b_vec8[idx][1].z);
|
||||
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b_vec8[idx][1].w);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT) / 2;
|
||||
const FLOAT_TYPE_VEC8 vals = FLOAT_TYPE_VEC8(data_b_vec8[idx]);
|
||||
buf_b[buf_idx + 0] = FLOAT_TYPE_VEC2(vals[0].xy);
|
||||
buf_b[buf_idx + 1] = FLOAT_TYPE_VEC2(vals[0].zw);
|
||||
buf_b[buf_idx + 2] = FLOAT_TYPE_VEC2(vals[1].xy);
|
||||
buf_b[buf_idx + 3] = FLOAT_TYPE_VEC2(vals[1].zw);
|
||||
} else
|
||||
#endif
|
||||
if (LOAD_VEC_B_SHIFT == 2) {
|
||||
const u16vec2 row_idx = row_ids[ic * BN + col];
|
||||
const uint idx = pos_b + row_idx.y * (p.batch_stride_b >> LOAD_VEC_B_SHIFT) + (row_idx.x % p.ne11) * (p.stride_b >> LOAD_VEC_B_SHIFT) + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT);
|
||||
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b_vec4[idx].x);
|
||||
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b_vec4[idx].y);
|
||||
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b_vec4[idx].z);
|
||||
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b_vec4[idx].w);
|
||||
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT) / 2;
|
||||
const FLOAT_TYPE_VEC4 vals = FLOAT_TYPE_VEC4(data_b_vec4[idx]);
|
||||
buf_b[buf_idx + 0] = FLOAT_TYPE_VEC2(vals.xy);
|
||||
buf_b[buf_idx + 1] = FLOAT_TYPE_VEC2(vals.zw);
|
||||
} else {
|
||||
const uint row_i = ic * BN + col;
|
||||
if (row_i < _ne1) {
|
||||
const u16vec2 row_idx = row_ids[row_i];
|
||||
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row]);
|
||||
const uint row_i_1 = ic * BN + col;
|
||||
const uint row_i_2 = ic * BN + col + 1;
|
||||
if (row_i_1 < _ne1 && row_i_2 < _ne1) {
|
||||
const u16vec2 row_idx_1 = row_ids[row_i_1];
|
||||
const u16vec2 row_idx_2 = row_ids[row_i_2];
|
||||
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(data_b[pos_b + row_idx_1.y * p.batch_stride_b + (row_idx_1.x % p.ne11) * p.stride_b + row],
|
||||
data_b[pos_b + row_idx_2.y * p.batch_stride_b + (row_idx_2.x % p.ne11) * p.stride_b + row]);
|
||||
} else if (row_i_1 < _ne1) {
|
||||
const u16vec2 row_idx = row_ids[row_i_1];
|
||||
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row], 0.0f);
|
||||
} else {
|
||||
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE(0.0f);
|
||||
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -293,10 +293,10 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
|
||||
void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
|
||||
std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
|
||||
|
||||
std::map<std::string, std::string> base_dict = {
|
||||
{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"},
|
||||
{"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
|
||||
};
|
||||
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float" },
|
||||
{"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2" },
|
||||
{"FLOAT_TYPE_VEC4", (coopmat2 || fp16) ? "f16vec4" : "vec4" },
|
||||
{"FLOAT_TYPE_VEC8", (coopmat2 || fp16) ? "f16mat2x4" : "mat2x4"}};
|
||||
std::string shader_name = "matmul";
|
||||
|
||||
if (matmul_id) {
|
||||
@ -308,7 +308,8 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
||||
base_dict["FLOAT16"] = "1";
|
||||
}
|
||||
|
||||
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
|
||||
base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float";
|
||||
base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2";
|
||||
|
||||
if (coopmat) {
|
||||
base_dict["COOPMAT"] = "1";
|
||||
@ -386,6 +387,7 @@ void process_shaders() {
|
||||
// flash attention
|
||||
for (const auto& f16acc : {false, true}) {
|
||||
std::string acctype = f16acc ? "float16_t" : "float";
|
||||
std::string acctype_vec2 = f16acc ? "f16vec2" : "vec2";
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
if (tname == "f32") {
|
||||
|
Loading…
x
Reference in New Issue
Block a user