diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 1e954b4ce..ca5a00b03 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -1,6 +1,70 @@ #ifndef GGML_METAL_IMPL #define GGML_METAL_IMPL +// kernel parameters for mat-vec threadgroups +// +// N_R0: number of src0 rows to process per simdgroup +// N_SG: number of simdgroups per threadgroup +// +// TODO: for optimal performance, become function of the device and work size + +#define N_R0_Q4_0 4 +#define N_SG_Q4_0 2 + +#define N_R0_Q4_1 4 +#define N_SG_Q4_1 2 + +#define N_R0_Q5_0 4 +#define N_SG_Q5_0 2 + +#define N_R0_Q5_1 4 +#define N_SG_Q5_1 2 + +#define N_R0_Q8_0 4 +#define N_SG_Q8_0 2 + +#define N_R0_Q2_K 4 +#define N_SG_Q2_K 2 + +#define N_R0_Q3_K 2 +#define N_SG_Q3_K 2 + +#define N_R0_Q4_K 4 +#define N_SG_Q4_K 2 + +#define N_R0_Q5_K 2 +#define N_SG_Q5_K 2 + +#define N_R0_Q6_K 1 +#define N_SG_Q6_K 2 + +#define N_R0_IQ1_S 4 +#define N_SG_IQ1_S 2 + +#define N_R0_IQ1_M 4 +#define N_SG_IQ1_M 2 + +#define N_R0_IQ2_XXS 4 +#define N_SG_IQ2_XXS 2 + +#define N_R0_IQ2_XS 4 +#define N_SG_IQ2_XS 2 + +#define N_R0_IQ2_S 4 +#define N_SG_IQ2_S 2 + +#define N_R0_IQ3_XXS 4 +#define N_SG_IQ3_XXS 2 + +#define N_R0_IQ3_S 4 +#define N_SG_IQ3_S 2 + +#define N_R0_IQ4_NL 2 +#define N_SG_IQ4_NL 2 + +#define N_R0_IQ4_XS 2 +#define N_SG_IQ4_XS 2 + // kernel argument structs // // - element counters (e.g. ne00) typically use int32_t to reduce register usage diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index af65e7d9f..195d96782 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2561,171 +2561,180 @@ static void ggml_metal_encode_node( [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { - int nth0 = 32; - int nth1 = 1; - int nrows = 1; - //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - id pipeline = nil; + int nsg = 0; // number of simdgroups + int nr0 = 0; // number of src0 rows per simdgroup + int nr1 = 1; // number of src1 rows per threadgroup + + size_t smem = 0; // shared memory + // use custom matrix x vector kernel switch (src0t) { case GGML_TYPE_F32: { GGML_ASSERT(src1t == GGML_TYPE_F32); + nsg = 1; + nr0 = 1; + nr1 = 4; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; - nrows = 4; } break; case GGML_TYPE_F16: { - nth0 = 32; - nth1 = 1; + nsg = 1; + nr0 = 1; if (src1t == GGML_TYPE_F32) { if (ne11 * ne12 < 4) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; - nrows = ne11; + nr1 = ne11; } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; - nrows = 4; + nr1 = 4; } } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline; - nrows = 4; + nr1 = 4; } } break; case GGML_TYPE_BF16: { - nth0 = 32; - nth1 = 1; + nsg = 1; + nr0 = 1; if (src1t == GGML_TYPE_F32) { if (ne11 * ne12 < 4) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline; } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline; - nrows = ne11; + nr1 = ne11; } else { pipeline = 
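// [illustration] The N_R0/N_SG pairs defined in ggml-metal-impl.h above fix the
// threadgroup shape on the host side: each threadgroup runs N_SG simdgroups of
// 32 threads, and each simdgroup accumulates N_R0 rows of src0, so one
// threadgroup covers N_R0*N_SG rows. A minimal C sketch (the helper names are
// illustrative, not part of the patch):
static inline int rows_per_threadgroup(int nr0, int nsg) {
    return nr0*nsg;   // e.g. Q4_0: 4*2 = 8 rows per threadgroup
}
static inline int threads_per_threadgroup(int nsg) {
    return 32*nsg;    // e.g. Q4_0: 2 simdgroups -> 64 threads
}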
ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline; - nrows = 4; + nr1 = 4; } } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline; - nrows = 4; + nr1 = 4; } } break; case GGML_TYPE_Q4_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q4_0; + nr0 = N_R0_Q4_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline; } break; case GGML_TYPE_Q4_1: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q4_1; + nr0 = N_R0_Q4_1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline; } break; case GGML_TYPE_Q5_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q5_0; + nr0 = N_R0_Q5_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline; } break; case GGML_TYPE_Q5_1: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q5_1; + nr0 = N_R0_Q5_1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; } break; case GGML_TYPE_Q8_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q8_0; + nr0 = N_R0_Q8_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; } break; case GGML_TYPE_Q2_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q2_K; + nr0 = N_R0_Q2_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline; } break; case GGML_TYPE_Q3_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q3_K; + nr0 = N_R0_Q3_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline; } break; case GGML_TYPE_Q4_K: { - nth0 = 4; //1; - nth1 = 8; //32; + nsg = N_SG_Q4_K; + nr0 = N_R0_Q4_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline; } break; case GGML_TYPE_Q5_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q5_K; + nr0 = N_R0_Q5_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline; } break; case GGML_TYPE_Q6_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q6_K; + nr0 = N_R0_Q6_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline; } break; case GGML_TYPE_IQ2_XXS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_XXS; + nr0 = N_R0_IQ2_XXS; + smem = 256*8+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline; } break; case GGML_TYPE_IQ2_XS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_XS; + nr0 = N_R0_IQ2_XS; + smem = 512*8+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline; } break; case GGML_TYPE_IQ3_XXS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ3_XXS; + nr0 = N_R0_IQ3_XXS; + smem = 256*4+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline; } break; case GGML_TYPE_IQ3_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ3_S; + nr0 = N_R0_IQ3_S; + smem = 512*4; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline; } break; case GGML_TYPE_IQ2_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_S; + nr0 = N_R0_IQ2_S; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline; } break; case GGML_TYPE_IQ1_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ1_S; + nr0 = N_R0_IQ1_S; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline; } break; case GGML_TYPE_IQ1_M: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ1_M; + nr0 = N_R0_IQ1_M; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline; } break; case GGML_TYPE_IQ4_NL: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ4_NL; + nr0 = N_R0_IQ4_NL; + smem = 32*sizeof(float); pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline; } break; case GGML_TYPE_IQ4_XS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ4_XS; + nr0 = N_R0_IQ4_XS; + smem = 32*sizeof(float); pipeline = 
ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; } break; default: @@ -2762,41 +2771,10 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { - const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { - const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { - const int mem_size = 32*sizeof(float); - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q3_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - const int64_t ny = (ne11 + nrows - 1)/nrows; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + if (smem > 0) { + [encoder setThreadgroupMemoryLength:smem atIndex:0]; } + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } } break; case GGML_OP_MUL_MAT_ID: @@ -2902,146 +2880,155 @@ static void ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { - int nth0 = 32; - int nth1 = 1; - int nrows = 1; - //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - id pipeline = nil; + int nsg = 0; // number of simdgroups + int nr0 = 0; // number of src0 rows per simdgroup + int nr1 = 1; // number of src1 rows per threadgroup + + size_t smem = 0; // shared memory + // use custom matrix x vector kernel switch (src0t) { case GGML_TYPE_F32: { GGML_ASSERT(src1t == GGML_TYPE_F32); + nsg = 1; + nr0 = 1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; } break; case GGML_TYPE_F16: { GGML_ASSERT(src1t == 
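// [illustration] The unified dispatch above replaces the per-type branches with
// one ceil-divided grid. Assuming a hypothetical ceil_div helper, the geometry is:
static inline int ceil_div(int a, int b) { return (a + b - 1)/b; }
// threadgroups.x = ceil_div(ne01, nr0*nsg)  // nr0*nsg src0 rows per threadgroup
// threadgroups.y = ceil_div(ne11, nr1)      // nr1 src1 rows per threadgroup
// threadgroups.z = ne12*ne13                // one per batch slice (i12, i13)
// threadsPerThreadgroup = (32, nsg, 1)      // SIMD width 32, nsg simdgroups
// Threadgroup memory is likewise set once from the per-type smem value instead
// of per-type setThreadgroupMemoryLength branches.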
GGML_TYPE_F32); - nth0 = 32; - nth1 = 1; + nsg = 1; + nr0 = 1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; } break; case GGML_TYPE_BF16: { GGML_ASSERT(src1t == GGML_TYPE_F32); - nth0 = 32; - nth1 = 1; + nsg = 1; + nr0 = 1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline; } break; case GGML_TYPE_Q4_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q4_0; + nr0 = N_R0_Q4_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; } break; case GGML_TYPE_Q4_1: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q4_1; + nr0 = N_R0_Q4_1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; } break; case GGML_TYPE_Q5_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q5_0; + nr0 = N_R0_Q5_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; } break; case GGML_TYPE_Q5_1: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q5_1; + nr0 = N_R0_Q5_1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; } break; case GGML_TYPE_Q8_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q8_0; + nr0 = N_R0_Q8_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; } break; case GGML_TYPE_Q2_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q2_K; + nr0 = N_R0_Q2_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; } break; case GGML_TYPE_Q3_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q3_K; + nr0 = N_R0_Q3_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; } break; case GGML_TYPE_Q4_K: { - nth0 = 4; //1; - nth1 = 8; //32; + nsg = N_SG_Q4_K; + nr0 = N_R0_Q4_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; } break; case GGML_TYPE_Q5_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q5_K; + nr0 = N_R0_Q5_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; } break; case GGML_TYPE_Q6_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q6_K; + nr0 = N_R0_Q6_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; } break; case GGML_TYPE_IQ2_XXS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_XXS; + nr0 = N_R0_IQ2_XXS; + smem = 256*8+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; } break; case GGML_TYPE_IQ2_XS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_XS; + nr0 = N_R0_IQ2_XS; + smem = 512*8+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; } break; case GGML_TYPE_IQ3_XXS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ3_XXS; + nr0 = N_R0_IQ3_XXS; + smem = 256*4+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; } break; case GGML_TYPE_IQ3_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ3_S; + nr0 = N_R0_IQ3_S; + smem = 512*4; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; } break; case GGML_TYPE_IQ2_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_S; + nr0 = N_R0_IQ2_S; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; } break; case GGML_TYPE_IQ1_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ1_S; + nr0 = N_R0_IQ1_S; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; } break; case GGML_TYPE_IQ1_M: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ1_M; + nr0 = N_R0_IQ1_M; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; } break; case GGML_TYPE_IQ4_NL: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ4_NL; + nr0 = N_R0_IQ4_NL; + smem = 32*sizeof(float); pipeline = 
ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; } break; case GGML_TYPE_IQ4_XS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ4_XS; + nr0 = N_R0_IQ4_XS; + smem = 32*sizeof(float); pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; } break; default: @@ -3052,7 +3039,7 @@ static void ggml_metal_encode_node( }; if (ggml_is_quantized(src0t)) { - GGML_ASSERT(ne00 >= nth0*nth1); + GGML_ASSERT(ne00 >= nsg*nr0); } ggml_metal_kargs_mul_mv_id args = { @@ -3085,43 +3072,12 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; const int64_t _ne1 = 1; - const int tgz = dst_rows; + const int64_t ne123 = dst_rows; - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { - const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { - const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { - const int mem_size = 32*sizeof(float); - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q3_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + if (smem > 0) { + [encoder setThreadgroupMemoryLength:smem atIndex:0]; } + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } } break; case GGML_OP_GET_ROWS: diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 3cef81b79..38f03efba 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1439,7 +1439,7 @@ kernel void kernel_rwkv_wkv7_f32( float4 sa_vec(0.0); - for (int j = 0; j < head_size; j += 4) { + for (uint j = 0; j < head_size; j += 4) { float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]); float4 s_vec = 
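// [illustration] The smem values assigned in the switches above size the
// threadgroup-memory lookup tables that the i-quant kernels stage into shared
// memory (see the kernel bodies further below). A sketch of where the literals
// appear to come from; the enum names are illustrative, the patch uses the
// literals directly:
enum {
    SMEM_IQ2_XXS = 256*8 + 128, // 256 grid entries of 8 bytes + 128 sign bytes
    SMEM_IQ2_XS  = 512*8 + 128, // 512 grid entries of 8 bytes + 128 sign bytes
    SMEM_IQ3_XXS = 256*4 + 128, // 256 grid entries of 4 bytes + 128 sign bytes
    SMEM_IQ3_S   = 512*4,       // 512 grid entries of 4 bytes; signs stored per block
    SMEM_IQ4     = 32*4,        // 32 floats of the non-linear codebook (32*sizeof(float))
};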
float4(state[j], state[j+1], state[j+2], state[j+3]); sa_vec += a_vec * s_vec; @@ -1853,14 +1853,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; } -// putting them in the kernel cause a significant performance penalty -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group -//Note: This is a template, but strictly speaking it only applies to -// quantizations where the block size is 32. It also does not -// guard against the number of rows not being divisible by -// N_DST, so this is another explicit assumption of the implementation. -template +template void mul_vec_q_n_f32_impl( args_t args, device const char * src0, @@ -1876,7 +1869,7 @@ void mul_vec_q_n_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -1888,15 +1881,15 @@ void mul_vec_q_n_f32_impl( device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows - device const block_q_type * ax[nr]; - for (int row = 0; row < nr; ++row) { + device const block_q_type * ax[nr0]; + for (int row = 0; row < nr0; ++row) { const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } float yl[16]; // src1 vector cache - float sumf[nr] = {0.f}; + float sumf[nr0] = {0.f}; const short ix = (tiisg/2); const short il = (tiisg%2)*8; @@ -1908,7 +1901,7 @@ void mul_vec_q_n_f32_impl( float sumy[2] = { 0.f, 0.f }; #pragma unroll - for (int i = 0; i < 8; i += 2) { + for (short i = 0; i < 8; i += 2) { sumy[0] += yb[i + 0] + yb[i + 1]; yl[i + 0] = yb[i + 0]; yl[i + 1] = yb[i + 1]/256.f; @@ -1919,7 +1912,7 @@ void mul_vec_q_n_f32_impl( } #pragma unroll - for (int row = 0; row < nr; row++) { + for (short row = 0; row < nr0; row++) { sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); } @@ -1928,7 +1921,7 @@ void mul_vec_q_n_f32_impl( device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; - for (int row = 0; row < nr; ++row) { + for (int row = 0; row < nr0; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0 && first_row + row < args.ne01) { @@ -1945,7 +1938,7 @@ kernel void kernel_mul_mv_q4_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -1956,7 +1949,7 @@ kernel void kernel_mul_mv_q4_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -1967,7 +1960,7 @@ kernel void kernel_mul_mv_q5_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, 
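// [illustration] Row assignment in the refactored impls above: threadgroup
// tgpig.x owns a contiguous band of nsg*nr0 rows, and simdgroup sgitg within it
// takes the nr0-row slice starting at first_row = (r0*nsg + sgitg)*nr0. A
// self-contained sketch of the same mapping:
static inline int first_row_of(int r0, int sgitg, int nsg, int nr0) {
    return (r0*nsg + sgitg)*nr0;
}
// e.g. Q4_0 (nr0 = 4, nsg = 2): threadgroup 3, simdgroup 1 -> rows 28..31.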
tiisg, sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -1978,12 +1971,12 @@ kernel void kernel_mul_mv_q5_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } #define NB_Q8_0 8 -template +template void kernel_mul_mv_q8_0_f32_impl( args_t args, device const char * src0, @@ -1993,16 +1986,13 @@ void kernel_mul_mv_q8_0_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { - const int nr = N_DST; - const int nsg = N_SIMDGROUP; - const int nw = N_SIMDWIDTH; - const int nb = args.ne00/QK8_0; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0*nsg + sgitg)*nr; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -2014,15 +2004,15 @@ void kernel_mul_mv_q8_0_f32_impl( device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows - device const block_q8_0 * ax[nr]; - for (int row = 0; row < nr; ++row) { + device const block_q8_0 * ax[nr0]; + for (int row = 0; row < nr0; ++row) { const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } float yl[NB_Q8_0]; - float sumf[nr] = { 0.f }; + float sumf[nr0] = { 0.f }; const short ix = tiisg/4; const short il = tiisg%4; @@ -2035,7 +2025,7 @@ void kernel_mul_mv_q8_0_f32_impl( yl[i] = yb[i]; } - for (int row = 0; row < nr; row++) { + for (short row = 0; row < nr0; row++) { device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; float sumq = 0.f; for (short iq = 0; iq < NB_Q8_0; ++iq) { @@ -2049,7 +2039,7 @@ void kernel_mul_mv_q8_0_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr; ++row) { + for (int row = 0; row < nr0; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0 && first_row + row < args.ne01) { @@ -2067,7 +2057,7 @@ kernel void kernel_mul_mv_q8_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } // mat-vec kernel processing in chunks of float4 @@ -2404,9 +2394,9 @@ void kernel_mul_mv_impl( sumf += (T0) x[i] * (T1) y[i]; } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; } } } else { @@ -2427,10 +2417,10 @@ void kernel_mul_mv_impl( sumf += dot((float4) x4[i], (float4) y4[i]); } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); + dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; } } } @@ -2492,9 +2482,9 @@ kernel void kernel_mul_mv_1row( for (int i = tiisg; i < args.ne00; i += 32) { sumf += (float) x[i] * (float) y[i]; } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - 
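// [illustration] The F16/F32 paths above accumulate the dot product in float4
// chunks and let lane 0 fold in the ne00 % 4 tail after simd_sum. A scalar C
// analogue of the same split (function name is illustrative):
static float dot_with_tail(const float *x, const float *y, int n) {
    float sum = 0.0f;
    int i = 0;
    for (; i + 4 <= n; i += 4) { // chunked body, 4 elements per step
        sum += x[i+0]*y[i+0] + x[i+1]*y[i+1] + x[i+2]*y[i+2] + x[i+3]*y[i+3];
    }
    for (; i < n; ++i) {         // remainder, i = 4*(n/4) .. n-1
        sum += x[i]*y[i];
    }
    return sum;
}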
dst_f32[r0] = all_sum; + dst_f32[r0] = sum_all; } } else { device const T4 * x4 = (device const T4 *) x; @@ -2504,11 +2494,11 @@ kernel void kernel_mul_mv_1row( sumf += dot((float4) x4[i], y4[i]); } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst_f32[r0] = all_sum; + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); + dst_f32[r0] = sum_all; } } } @@ -2553,9 +2543,9 @@ kernel void kernel_mul_mv_l4( sumf += dot((float4) x4[i], y4[i]); } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; } } } @@ -4321,7 +4311,7 @@ kernel void kernel_cpy_f32_iq4_nl( float amax = 0.0f; // absolute max float max = 0.0f; - for (int j = 0; j < QK4_0; j++) { + for (int j = 0; j < QK4_NL; j++) { const float v = src[j]; if (amax < fabs(v)) { amax = fabs(v); @@ -4429,7 +4419,7 @@ kernel void kernel_concat( } } -template +template void kernel_mul_mv_q2_K_f32_impl( args_t args, device const char * src0, @@ -4445,7 +4435,7 @@ void kernel_mul_mv_q2_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -4457,20 +4447,19 @@ void kernel_mul_mv_q2_K_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int iq = it/4; // 0 or 1 - const int ir = it%4; // 0...3 - const int is = (8*ir)/16;// 0 or 1 + const short ix = tiisg/8; // 0...3 + const short it = tiisg%8; // 0...7 + const short iq = it/4; // 0 or 1 + const short ir = it%4; // 0...3 + const short is = (8*ir)/16;// 0 or 1 device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; for (int ib = ix; ib < nb; ib += 4) { - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { + for (short i = 0; i < 8; ++i) { yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; @@ -4481,7 +4470,7 @@ void kernel_mul_mv_q2_K_f32_impl( device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; device const half * dh = &x[ib].d; - for (int row = 0; row < N_DST; row++) { + for (short row = 0; row < nr0; row++) { float4 acc1 = {0.f, 0.f, 0.f, 0.f}; float4 acc2 = {0.f, 0.f, 0.f, 0.f}; for (int i = 0; i < 8; i += 2) { @@ -4512,10 +4501,10 @@ void kernel_mul_mv_q2_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } @@ -4530,10 +4519,10 @@ kernel void kernel_mul_mv_q2_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void 
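// [illustration] Lane bookkeeping in the q2_K impl above, now computed in
// 16-bit "short" to reduce register pressure. A sketch of how the 32 simd
// lanes decompose (names mirror the kernel, helper is illustrative):
static void q2k_lane(int tiisg, int *ix, int *it, int *iq, int *ir) {
    *ix = tiisg/8; // 0..3: starting super-block of the ib loop (ib = ix; ib += 4)
    *it = tiisg%8; // 0..7: lane position within its group of 8
    *iq = *it/4;   // 0 or 1: which 128-element half of the super-block
    *ir = *it%4;   // 0..3: which 8-element slice within that half
}
// giving the per-lane read offset y4 = y + ix*QK_K + 128*iq + 8*ir.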
kernel_mul_mv_q3_K_f32_impl( args_t args, device const char * src0, @@ -4550,7 +4539,7 @@ void kernel_mul_mv_q3_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -4566,13 +4555,12 @@ void kernel_mul_mv_q3_K_f32_impl( //const uint16_t kmask1 = 0x3030; //const uint16_t kmask2 = 0x0f0f; - const int tid = tiisg/4; - const int ix = tiisg%4; - const int ip = tid/4; // 0 or 1 - const int il = 2*((tid%4)/2); // 0 or 2 - const int ir = tid%2; - const int n = 8; - const int l0 = n*ir; + const short tid = tiisg/4; + const short ix = tiisg%4; + const short ip = tid/4; // 0 or 1 + const short il = 2*((tid%4)/2); // 0 or 2 + const short ir = tid%2; + const short l0 = 8*ir; // One would think that the Metal compiler would figure out that ip and il can only have // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it @@ -4597,8 +4585,8 @@ void kernel_mul_mv_q3_K_f32_impl( const uint16_t s_shift1 = 4*ip; const uint16_t s_shift2 = s_shift1 + il; - const int q_offset = 32*ip + l0; - const int y_offset = 128*ip + 32*il + l0; + const short q_offset = 32*ip + l0; + const short y_offset = 128*ip + 32*il + l0; device const float * y1 = yy + ix*QK_K + y_offset; @@ -4606,10 +4594,11 @@ void kernel_mul_mv_q3_K_f32_impl( thread uint16_t * scales16 = (thread uint16_t *)&scales32; thread const int8_t * scales = (thread const int8_t *)&scales32; - float sumf1[2] = {0.f}; - float sumf2[2] = {0.f}; + float sumf1[nr0] = {0.f}; + float sumf2[nr0] = {0.f}; + for (int i = ix; i < nb; i += 4) { - for (int l = 0; l < 8; ++l) { + for (short l = 0; l < 8; ++l) { yl[l+ 0] = y1[l+ 0]; yl[l+ 8] = y1[l+16]; yl[l+16] = y1[l+32]; @@ -4621,7 +4610,7 @@ void kernel_mul_mv_q3_K_f32_impl( device const uint16_t * a = (device const uint16_t *)(x[i].scales); device const half * dh = &x[i].d; - for (int row = 0; row < 2; ++row) { + for (short row = 0; row < nr0; ++row) { const float d_all = (float)dh[0]; scales16[0] = a[4]; @@ -4632,7 +4621,7 @@ void kernel_mul_mv_q3_K_f32_impl( scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; - for (int l = 0; l < n; l += 2) { + for (short l = 0; l < 8; l += 2) { const int32_t qs = q[l/2]; s1 += yl[l+0] * (qs & qm[il/2][0]); s2 += yl[l+1] * (qs & qm[il/2][1]); @@ -4647,7 +4636,7 @@ void kernel_mul_mv_q3_K_f32_impl( sumf2[row] += d2 * (scales[2] - 32); s1 = s2 = s3 = s4 = s5 = s6 = 0; - for (int l = 0; l < n; l += 2) { + for (short l = 0; l < 8; l += 2) { const int32_t qs = q[l/2+8]; s1 += yl[l+8] * (qs & qm[il/2][0]); s2 += yl[l+9] * (qs & qm[il/2][1]); @@ -4670,7 +4659,7 @@ void kernel_mul_mv_q3_K_f32_impl( y1 += 4 * QK_K; } - for (int row = 0; row < 2; ++row) { + for (int row = 0; row < nr0; ++row) { const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); sumf1[row] = simd_sum(sumf); } @@ -4678,7 +4667,7 @@ void kernel_mul_mv_q3_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; if (tiisg == 0) { - for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { dst_f32[first_row + row] = sumf1[row]; } } @@ -4694,10 +4683,10 @@ kernel void kernel_mul_mv_q3_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(args, 
src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q4_K_f32_impl( args_t args, device const char * src0, @@ -4707,22 +4696,22 @@ void kernel_mul_mv_q4_K_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { - const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int iq = it/4; // 0 or 1 - const int ir = it%4; // 0...3 + const short ix = tiisg/8; // 0...3 + const short it = tiisg%8; // 0...7 + const short iq = it/4; // 0 or 1 + const short ir = it%4; // 0...3 const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int first_row = r0 * N_DST; + + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -4735,7 +4724,8 @@ void kernel_mul_mv_q4_K_f32_impl( float yl[16]; float yh[16]; - float sumf[N_DST]={0.f}, all_sum; + + float sumf[nr0]={0.f}; device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; @@ -4744,7 +4734,8 @@ void kernel_mul_mv_q4_K_f32_impl( for (int ib = ix; ib < nb; ib += 4) { float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { + + for (short i = 0; i < 8; ++i) { yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; @@ -4755,7 +4746,7 @@ void kernel_mul_mv_q4_K_f32_impl( device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; device const half * dh = &x[ib].d; - for (int row = 0; row < N_DST; row++) { + for (short row = 0; row < nr0; row++) { sc16[0] = sc[0] & kmask1; sc16[1] = sc[2] & kmask1; sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); @@ -4765,19 +4756,21 @@ void kernel_mul_mv_q4_K_f32_impl( float4 acc1 = {0.f, 0.f, 0.f, 0.f}; float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); - acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); - acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0); - acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); - acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); - acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); - acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); - acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); + + for (short i = 0; i < 4; ++i) { + acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F); + acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00); + acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0); + acc1[3] += yl[2*i + 9] * (q1[i] & 0xF000); + acc2[0] += yh[2*i + 0] * (q2[i] & 0x000F); + acc2[1] += yh[2*i + 1] * (q2[i] & 0x0F00); + acc2[2] += yh[2*i + 8] * (q2[i] & 0x00F0); + acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000); } float dall = dh[0]; float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + @@ -4794,10 +4787,10 @@ void kernel_mul_mv_q4_K_f32_impl( device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } @@ -4812,10 +4805,10 @@ kernel void 
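// [illustration] Why the q4_K accumulation above carries 1/256 and 1/16
// factors: each uint16_t of qs packs four 4-bit weights at bit offsets 0, 4,
// 8 and 12, so masking with 0x000F/0x00F0/0x0F00/0xF000 yields the nibble
// value scaled by 1, 16, 256 or 4096, and the constants fold the scale back
// out after the loop. Scalar sketch for the 0x000F/0x0F00 pair:
#include <stdint.h>
static float q4k_lo_pair(uint16_t q, float y0, float y1) {
    float a0 = y0 * (float)(q & 0x000F); // nibble * 1
    float a1 = y1 * (float)(q & 0x0F00); // nibble * 256
    return a0 + (1.0f/256.0f)*a1;        // matches (acc1[0] + 1/256 * acc1[1])
}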
kernel_mul_mv_q4_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q5_K_f32_impl( args_t args, device const char * src0, @@ -4832,7 +4825,7 @@ void kernel_mul_mv_q5_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -4843,7 +4836,7 @@ void kernel_mul_mv_q5_K_f32_impl( device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); - float sumf[2]={0.f}; + float sumf[nr0]={0.f}; float yl[16], yh[16]; @@ -4851,15 +4844,14 @@ void kernel_mul_mv_q5_K_f32_impl( const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; - const int tid = tiisg/4; - const int ix = tiisg%4; - const int iq = tid/4; - const int ir = tid%4; - const int n = 8; + const short tid = tiisg/4; + const short ix = tiisg%4; + const short iq = tid/4; + const short ir = tid%4; - const int l0 = n*ir; - const int q_offset = 32*iq + l0; - const int y_offset = 64*iq + l0; + const short l0 = 8*ir; + const short q_offset = 32*iq + l0; + const short y_offset = 64*iq + l0; const uint8_t hm1 = 1u << (2*iq); const uint8_t hm2 = hm1 << 1; @@ -4879,14 +4871,14 @@ void kernel_mul_mv_q5_K_f32_impl( device const float * y2 = y1 + 128; float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < 8; ++l) { + for (short l = 0; l < 8; ++l) { yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; } - for (int row = 0; row < 2; ++row) { + for (short row = 0; row < nr0; ++row) { device const uint8_t * q2 = q1 + 64; sc16[0] = a[0] & kmask1; @@ -4896,7 +4888,7 @@ void kernel_mul_mv_q5_K_f32_impl( float4 acc1 = {0.f}; float4 acc2 = {0.f}; - for (int l = 0; l < n; ++l) { + for (short l = 0; l < 8; ++l) { uint8_t h = qh[l]; acc1[0] += yl[l+0] * (q1[l] & 0x0F); acc1[1] += yl[l+8] * (q1[l] & 0xF0); @@ -4926,7 +4918,7 @@ void kernel_mul_mv_q5_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = tot; @@ -4944,10 +4936,10 @@ kernel void kernel_mul_mv_q5_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q6_K_f32_impl( args_t args, device const char * src0, @@ -4969,62 +4961,77 @@ void kernel_mul_mv_q6_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int row = 2*r0 + sgitg; - - if (row >= args.ne0) { - return; - } + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + 
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); - float sumf = 0; + float sumf[nr0] = { 0.f }; - const int tid = tiisg/2; - const int ix = tiisg%2; - const int ip = tid/8; // 0 or 1 - const int il = tid%8; - const int n = 4; - const int l0 = n*il; - const int is = 8*ip + l0/16; + float yl[16]; - const int y_offset = 128*ip + l0; - const int q_offset_l = 64*ip + l0; - const int q_offset_h = 32*ip + l0; + const short tid = tiisg/2; + const short ix = tiisg%2; + const short ip = tid/8; // 0 or 1 + const short il = tid%8; + const short l0 = 4*il; + const short is = 8*ip + l0/16; + + const short y_offset = 128*ip + l0; + const short q_offset_l = 64*ip + l0; + const short q_offset_h = 32*ip + l0; for (int i = ix; i < nb; i += 2) { device const uint8_t * q1 = x[i].ql + q_offset_l; device const uint8_t * q2 = q1 + 32; device const uint8_t * qh = x[i].qh + q_offset_h; device const int8_t * sc = x[i].scales + is; + device const half * dh = &x[i].d; device const float * y = yy + i * QK_K + y_offset; - const float dall = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < n; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); - sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + for (short l = 0; l < 4; ++l) { + yl[4*l + 0] = y[l + 0]; + yl[4*l + 1] = y[l + 32]; + yl[4*l + 2] = y[l + 64]; + yl[4*l + 3] = y[l + 96]; } - sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + for (short row = 0; row < nr0; ++row) { + const float dall = dh[0]; + float4 sums = {0.f, 0.f, 0.f, 0.f}; + + for (short l = 0; l < 4; ++l) { + sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); + sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + } + + sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + + q1 += args.nb01; + q2 += args.nb01; + qh += args.nb01; + sc += args.nb01; + dh += args.nb01/2; + } } device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst_f32[row] = tot; + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all; + } } } @@ -5038,12 +5045,12 @@ kernel void kernel_mul_mv_q6_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit -template +template void kernel_mul_mv_iq2_xxs_f32_impl( args_t args, device const char * src0, @@ -5059,7 +5066,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int 
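// [illustration] Multi-row stepping in the rewritten q6_K loop above: q1, q2,
// qh and sc are byte-element pointers, so advancing them by args.nb01 (the
// src0 row stride in bytes) lands on the same position in the next row, while
// dh points at 2-byte halfs and therefore advances by nb01/2 elements. C
// analogue (half_t is a stand-in for Metal's half):
#include <stdint.h>
typedef uint16_t half_t;
static void step_to_next_row(const uint8_t **q, const half_t **dh, int64_t nb01) {
    *q  += nb01;   // byte pointer: stride given directly in bytes
    *dh += nb01/2; // 2-byte elements: same byte stride, halved element count
}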
first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5071,7 +5078,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5092,8 +5099,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5104,18 +5110,17 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const uint16_t * q2 = xr->qs + 4 * ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; device const uint8_t * aux8 = (device const uint8_t *)q2; const uint32_t aux32 = q2[2] | (q2[3] << 16); const float d = db * (0.5f + (aux32 >> 28)); float sum = 0; - for (int l = 0; l < 4; ++l) { + for (short l = 0; l < 4; ++l) { const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]); const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } @@ -5130,10 +5135,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = sum_all * 0.25f; } } } @@ -5148,10 +5153,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_xs_f32_impl( args_t args, device const char * src0, @@ -5167,7 +5172,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5179,7 +5184,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5200,8 +5205,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5213,8 +5217,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const uint8_t * sc = xr->scales + ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const uint8_t ls1 = sc[0] & 0xf; const uint8_t ls2 = sc[0] >> 4; @@ -5222,17 +5225,17 @@ void kernel_mul_mv_iq2_xs_f32_impl( const float d2 = db * (0.5f + ls2); float sum1 = 0, sum2 = 0; - for (int l = 0; l < 2; ++l) { 
+ for (short l = 0; l < 2; ++l) { const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); const uint8_t signs = ssigns[(q2[l] >> 9)]; - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } - for (int l = 2; l < 4; ++l) { + for (short l = 2; l < 4; ++l) { const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); const uint8_t signs = ssigns[(q2[l] >> 9)]; - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } @@ -5248,10 +5251,10 @@ void kernel_mul_mv_iq2_xs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = sum_all * 0.25f; } } } @@ -5267,10 +5270,10 @@ kernel void kernel_mul_mv_iq2_xs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_xxs_f32_impl( args_t args, device const char * src0, @@ -5286,7 +5289,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5298,7 +5301,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5319,7 +5322,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5331,17 +5334,17 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const uint32_t aux32 = gas[0] | (gas[1] << 16); const float d = db * (0.5f + (aux32 >> 28)); float2 sum = {0}; - for (int l = 0; l < 4; ++l) { + for (short l = 0; l < 4; ++l) { const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]); const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]); const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; - for (int j = 0; j < 4; ++j) { + for (short j = 0; j < 4; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? 
-1.f : 1.f); } @@ -5358,10 +5361,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum * 0.5f; + dst_f32[first_row + row] = sum_all * 0.5f; } } } @@ -5377,10 +5380,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_s_f32_impl( args_t args, device const char * src0, @@ -5396,7 +5399,7 @@ void kernel_mul_mv_iq3_s_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5408,7 +5411,7 @@ void kernel_mul_mv_iq3_s_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5425,8 +5428,7 @@ void kernel_mul_mv_iq3_s_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5440,18 +5442,17 @@ void kernel_mul_mv_iq3_s_f32_impl( device const uint8_t * signs = xr->signs + 4 * ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); float2 sum = {0}; - for (int l = 0; l < 4; ++l) { + for (short l = 0; l < 4; ++l) { const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues; const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? 
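// [illustration] The iq2/iq3 kernels above apply the packed sign bits
// branchlessly, via select(1, -1, signs & mask) in Metal. Equivalent C sketch:
static float apply_sign(float v, unsigned signs, unsigned mask) {
    return (signs & mask) ? -v : v; // lowers to a select, no divergent branch
}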
svalues + 256 : svalues; const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); - for (int j = 0; j < 4; ++j) { + for (short j = 0; j < 4; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); } @@ -5470,10 +5471,10 @@ void kernel_mul_mv_iq3_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } @@ -5489,10 +5490,10 @@ kernel void kernel_mul_mv_iq3_s_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_s_f32_impl( args_t args, device const char * src0, @@ -5508,7 +5509,7 @@ void kernel_mul_mv_iq2_s_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5520,7 +5521,7 @@ void kernel_mul_mv_iq2_s_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5532,13 +5533,12 @@ void kernel_mul_mv_iq2_s_f32_impl( // threadgroup_barrier(mem_flags::mem_threadgroup); //} - const int ix = tiisg; + const short ix = tiisg; device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5552,19 +5552,18 @@ void kernel_mul_mv_iq2_s_f32_impl( device const uint8_t * signs = qs + QK_K/8; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const float d1 = db * (0.5f + (sc[0] & 0xf)); const float d2 = db * (0.5f + (sc[0] >> 4)); float2 sum = {0}; - for (int l = 0; l < 2; ++l) { + for (short l = 0; l < 2; ++l) { //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]); sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]); } @@ -5583,10 +5582,10 @@ void kernel_mul_mv_iq2_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for 
(int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = sum_all * 0.25f; } } } @@ -5602,10 +5601,10 @@ kernel void kernel_mul_mv_iq2_s_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq1_s_f32_impl( args_t args, device const char * src0, @@ -5621,7 +5620,7 @@ void kernel_mul_mv_iq1_s_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5633,18 +5632,17 @@ void kernel_mul_mv_iq1_s_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); - const int ix = tiisg; + const short ix = tiisg; device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - float sumy = 0; - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; sumy += yl[i]; } @@ -5657,15 +5655,14 @@ void kernel_mul_mv_iq1_s_f32_impl( device const uint16_t * qh = xr->qh + ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700))); constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700))); constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700))); float sum = 0; - for (int j = 0; j < 4; ++j) { + for (short j = 0; j < 4; ++j) { sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) @@ -5683,15 +5680,28 @@ void kernel_mul_mv_iq1_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } -template +[[host_name("kernel_mul_mv_iq1_s_f32")]] +kernel void kernel_mul_mv_iq1_s_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +template void kernel_mul_mv_iq1_m_f32_impl( args_t args, device const char * src0, @@ -5703,11 +5713,12 @@ void kernel_mul_mv_iq1_m_f32_impl( ushort sgitg) { const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5719,20 +5730,19 @@ void kernel_mul_mv_iq1_m_f32_impl(
     device const float * y = (device const float *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
-    const int ix = tiisg;
+    const short ix = tiisg;
 
     device const float * y4 = y + 32 * ix;
 
     iq1m_scale_t scale;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
         float4 sumy = {0.f};
-        for (int i = 0; i < 8; ++i) {
+        for (short i = 0; i < 8; ++i) {
             yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
             yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
             yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
@@ -5747,7 +5757,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
         device const uint8_t  * qh = xr->qh + 2 * ib;
         device const uint16_t * sc = (device const uint16_t *)xr->scales;
 
-        for (int row = 0; row < N_DST; row++) {
+        for (short row = 0; row < nr0; row++) {
             scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
 
             constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
@@ -5756,7 +5766,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
             constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
 
             float2 sum = {0.f};
-            for (int j = 0; j < 4; ++j) {
+            for (short j = 0; j < 4; ++j) {
                 sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
                         + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
                 sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
@@ -5778,15 +5788,28 @@ void kernel_mul_mv_iq1_m_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
 
-template<typename args_t>
+[[host_name("kernel_mul_mv_iq1_m_f32")]]
+kernel void kernel_mul_mv_iq1_m_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, N_SG_IQ1_M, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq4_nl_f32_impl(
         args_t args,
         device const char * src0,
@@ -5799,10 +5822,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
 
     threadgroup float * shmem_f32 = (threadgroup float *) shmem;
 
     const int nb = args.ne00/QK4_NL;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
-    const int first_row = (r0 * 2 + sgitg) * 2;
+
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5813,14 +5838,14 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
     device const float        * y = (device const float        *) (src1 + offset1);
 
-    const int ix = tiisg/2;  // 0...15
-    const int it = tiisg%2;  // 0 or 1
+    const short ix = tiisg/2;  // 0...15
+    const short it = tiisg%2;  // 0 or 1
 
     shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     float4 yl[4];
-    float sumf[2]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
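+    // one running partial dot product per src0 row assigned to this simdgroup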
 
     device const float * yb = y + ix * QK4_NL + it * 8;
@@ -5830,12 +5855,13 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     float4 qf1, qf2;
 
     for (int ib = ix; ib < nb; ib += 16) {
-        device const float4 * y4 = (device const float4 *)yb;
-        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
-
-        for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
+        device const float4 * y4 = (device const float4 *)yb;
+
+        yl[0] = y4[0];
+        yl[1] = y4[4];
+        yl[2] = y4[1];
+        yl[3] = y4[5];
+
+        for (short row = 0; row < nr0; row++) {
             device const block_iq4_nl & xb = x[row*nb + ib];
             device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
@@ -5860,7 +5886,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
             acc1 += acc2;
 
             sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
-
         }
 
         yb += 16 * QK4_NL;
@@ -5868,15 +5893,29 @@ void kernel_mul_mv_iq4_nl_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
 
-template<typename args_t>
+[[host_name("kernel_mul_mv_iq4_nl_f32")]]
+kernel void kernel_mul_mv_iq4_nl_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq4_xs_f32_impl(
         args_t args,
         device const char * src0,
@@ -5892,7 +5931,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
-    const int first_row = (r0 * 2 + sgitg) * 2;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5903,16 +5942,16 @@ void kernel_mul_mv_iq4_xs_f32_impl(
     device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
     device const float        * y = (device const float        *) (src1 + offset1);
 
-    const int ix = tiisg/16;  // 0 or 1
-    const int it = tiisg%16;  // 0...15
-    const int ib = it/2;
-    const int il = it%2;
+    const short ix = tiisg/16;  // 0 or 1
+    const short it = tiisg%16;  // 0...15
+    const short ib = it/2;
+    const short il = it%2;
 
     shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     float4 yl[4];
-    float sumf[2]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
@@ -5923,9 +5962,12 @@ void kernel_mul_mv_iq4_xs_f32_impl(
     for (int ibl = ix; ibl < nb; ibl += 2) {
         device const float4 * y4 = (device const float4 *)yb;
-        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+        yl[0] = y4[0];
+        yl[1] = y4[4];
+        yl[2] = y4[1];
+        yl[3] = y4[5];
 
-        for (int row = 0; row < 2; ++row) {
+        for (short row = 0; row < nr0; ++row) {
             device const block_iq4_xs & xb = x[row*nb + ibl];
             device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
@@ -5949,7 +5991,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
             const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
 
             sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
-
         }
 
         yb += 2 * QK_K;
@@ -5957,54 +5998,14 @@ void kernel_mul_mv_iq4_xs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
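+    // simd_sum reduces the per-lane partials for each row across the simdgroup;
+    // lane 0 (tiisg == 0) writes the final dot product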
-    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
 
-[[host_name("kernel_mul_mv_iq1_s_f32")]]
-kernel void kernel_mul_mv_iq1_s_f32(
-        constant ggml_metal_kargs_mul_mv & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
-}
-
-[[host_name("kernel_mul_mv_iq1_m_f32")]]
-kernel void kernel_mul_mv_iq1_m_f32(
-        constant ggml_metal_kargs_mul_mv & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
-}
-
-[[host_name("kernel_mul_mv_iq4_nl_f32")]]
-kernel void kernel_mul_mv_iq4_nl_f32(
-        constant ggml_metal_kargs_mul_mv & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        threadgroup  char * shmem [[threadgroup(0)]],
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
-}
-
 [[host_name("kernel_mul_mv_iq4_xs_f32")]]
 kernel void kernel_mul_mv_iq4_xs_f32(
         constant ggml_metal_kargs_mul_mv & args,
@@ -6016,7 +6017,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
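+    // the per-type N_R0_*/N_SG_* constants come from ggml-metal-impl.h, so the
+    // shader and the host-side dispatch agree on rows-per-simdgroup and
+    // simdgroups-per-threadgroup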
 }
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@@ -6660,25 +6661,27 @@ template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t
 #if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
 #endif
-template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q5_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q6_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH>>>;
+
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH>>>;
+
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, N_SG_Q5_K, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, N_SG_Q6_K, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, N_SG_IQ1_S, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, N_SG_IQ1_M, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, N_SG_IQ2_XS, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, N_SG_IQ3_S, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, N_SG_IQ2_S, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH>>>;
 
 kernel void kernel_pool_2d_max_f32(
         device const float * src0,