diff --git a/ci/README.md b/ci/README.md
index 2d0004b18..ec3f44350 100644
--- a/ci/README.md
+++ b/ci/README.md
@@ -60,7 +60,7 @@ docker run --privileged -it \
 Inside the container, execute the following commands:
 
 ```bash
-apt update -y && apt install -y bc cmake git python3.10-venv time unzip wget
+apt update -y && apt install -y bc cmake ccache git python3.10-venv time unzip wget
 git config --global --add safe.directory /ws
 GG_BUILD_MUSA=1 bash ./ci/run.sh /ci-results /ci-cache
 ```
diff --git a/ci/run.sh b/ci/run.sh
index efc24391d..4d9c2dc48 100755
--- a/ci/run.sh
+++ b/ci/run.sh
@@ -69,7 +69,7 @@ fi
 if [ ! -z ${GG_BUILD_MUSA} ]; then
     # Use qy1 by default (MTT S80)
     MUSA_ARCH=${MUSA_ARCH:-21}
-    CMAKE_EXTRA="-DGGML_MUSA=ON -DMUSA_ARCHITECTURES=${MUSA_ARCH}"
+    CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_MUSA=ON -DMUSA_ARCHITECTURES=${MUSA_ARCH}"
 fi
 ## helpers
 
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index 6c02b69ea..086c822d7 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -158,6 +158,12 @@ typedef sycl::half2 ggml_half2;
 
 #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
 
+#ifdef _MSC_VER
+#define GGML_EXTENSION
+#else // _MSC_VER
+#define GGML_EXTENSION __extension__
+#endif // _MSC_VER
+
 #define QK4_0 32
 typedef struct {
     ggml_half d;           // delta
@@ -167,7 +173,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_half) + QK4_0 / 2, "wrong q4_0 b
 
 #define QK4_1 32
 typedef struct {
-    union {
+    GGML_EXTENSION union {
         struct {
             ggml_half d; // delta
             ggml_half m; // min
@@ -188,7 +194,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_half) + sizeof(uint32_t) + QK5_0
 
 #define QK5_1 32
 typedef struct {
-    union {
+    GGML_EXTENSION union {
         struct {
             ggml_half d; // delta
             ggml_half m; // min
@@ -209,7 +215,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_half) + QK8_0, "wrong q8_0 block
 
 #define QK8_1 32
 typedef struct {
-    union {
+    GGML_EXTENSION union {
         struct {
             ggml_half d; // delta
             ggml_half s; // d * sum(qs[i])
@@ -250,7 +256,7 @@ static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0
 typedef struct {
     uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
     uint8_t qs[QK_K/4];      // quants
-    union {
+    GGML_EXTENSION union {
         struct {
             ggml_half d;    // super-block scale for quantized scales
             ggml_half dmin; // super-block scale for quantized mins
@@ -277,7 +283,7 @@ static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12
 // weight is represented as x = a * q + b
 // Effectively 4.5 bits per weight
 typedef struct {
-    union {
+    GGML_EXTENSION union {
         struct {
             ggml_half d;    // super-block scale for quantized scales
             ggml_half dmin; // super-block scale for quantized mins
@@ -294,7 +300,7 @@ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2,
 // weight is represented as x = a * q + b
 // Effectively 5.5 bits per weight
 typedef struct {
-    union {
+    GGML_EXTENSION union {
         struct {
             ggml_half d;    // super-block scale for quantized scales
             ggml_half dmin; // super-block scale for quantized mins
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index f8c55a2b8..a718b6a12 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -288,6 +288,10 @@ static __device__ void no_device_code(
     __trap();
 
     GGML_UNUSED(no_device_code); // suppress unused function warning
+
+#if defined(GGML_USE_MUSA)
+    __builtin_unreachable();
+#endif // defined(GGML_USE_MUSA)
 }
 
 #ifdef __CUDA_ARCH__
diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu
index aafbaf803..e9ffd274b 100644
--- a/ggml/src/ggml-cuda/concat.cu
+++ b/ggml/src/ggml-cuda/concat.cu
@@ -38,7 +38,7 @@ static __global__ void concat_f32_dim1(const float * x, const float * y, float *
         blockIdx.y * ne0 +
         blockIdx.z * ne0 * gridDim.y;
 
-    if (blockIdx.y < ne01) { // src0
+    if (blockIdx.y < (unsigned)ne01) { // src0
         int offset_src =
             nidx +
             blockIdx.y * ne0 +
@@ -64,7 +64,7 @@ static __global__ void concat_f32_dim2(const float * x, const float * y, float *
         blockIdx.y * ne0 +
         blockIdx.z * ne0 * gridDim.y;
 
-    if (blockIdx.z < ne02) { // src0
+    if (blockIdx.z < (unsigned)ne02) { // src0
         int offset_src =
             nidx +
             blockIdx.y * ne0 +
diff --git a/ggml/src/ggml-cuda/conv-transpose-1d.cu b/ggml/src/ggml-cuda/conv-transpose-1d.cu
index b1e94d6f7..fe4caf674 100644
--- a/ggml/src/ggml-cuda/conv-transpose-1d.cu
+++ b/ggml/src/ggml-cuda/conv-transpose-1d.cu
@@ -34,6 +34,10 @@ static  __global__ void conv_transpose_1d_kernel(
         }
     }
     dst[global_index] = accumulator;
+    GGML_UNUSED(p0); GGML_UNUSED(d0); GGML_UNUSED(src0_ne3);
+    GGML_UNUSED(src1_ne3); GGML_UNUSED(dst_ne3);
+    GGML_UNUSED(src1_ne1); GGML_UNUSED(dst_ne1);
+    GGML_UNUSED(src1_ne2); GGML_UNUSED(dst_ne2);
 }
 
 static void conv_transpose_1d_f32_f32_cuda(
@@ -75,8 +79,6 @@ void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor
     const int p0 = 0;//opts[3];
     const int d0 = 1;//opts[4];
 
-    const int64_t kernel_size = ggml_nelements(src0);
-    const int64_t input_size = ggml_nelements(src1);
     const int64_t output_size = ggml_nelements(dst);
 
     conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size,
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index 795b720d6..2997e2b4d 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -577,7 +577,7 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
         return;
     }
 
-    const src_t * x = (src_t *) vx;
+    const src_t * x = (const src_t *) vx;
 
     y[i] = x[i];
 }
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 1c2a2a138..3fe22092f 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -315,14 +315,14 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
 
     float vals[sizeof(int)] = {0.0f};
 #pragma unroll
-    for (int l = 0; l < sizeof(int); ++l) {
+    for (int l = 0; l < int(sizeof(int)); ++l) {
         vals[l] = scale * x[4*threadIdx.x + l];
     }
 
     float amax = fabsf(vals[0]);
     float sum  = vals[0];
 #pragma unroll
-    for (int l = 1; l < sizeof(int); ++l) {
+    for (int l = 1; l < int(sizeof(int)); ++l) {
         amax = fmaxf(amax, fabsf(vals[l]));
         sum += vals[l];
     }
@@ -338,7 +338,7 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
 
     if (d != 0.0f) {
 #pragma unroll
-        for (int l = 0; l < sizeof(int); ++l) {
+        for (int l = 0; l < int(sizeof(int)); ++l) {
             q8[l] = roundf(vals[l] / d);
         }
     }
@@ -638,7 +638,7 @@ static __global__ void flash_attn_combine_results(
     float VKQ_denominator = 0.0f;
     for (int l = 0; l < parallel_blocks; ++l) {
         const float diff = meta[l].x - kqmax;
-        const float KQ_max_scale = expf(diff);
+        float KQ_max_scale = expf(diff);
         const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint32_t *) &KQ_max_scale) &= ftz_mask;
 
@@ -649,6 +649,7 @@ static __global__ void flash_attn_combine_results(
     dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
 }
 
+[[noreturn]]
 static void on_no_fattn_vec_case(const int D) {
     if (D == 64) {
         fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 024032f62..04804a15c 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -406,6 +406,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #endif // CP_ASYNC_AVAILABLE
 
 #else
+    GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
+    GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
+    GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_KV);
+    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
+    GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
+    GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
+    GGML_UNUSED(kb0);
     NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }
@@ -797,6 +806,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         __syncthreads();
     }
 #else
+    GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
+    GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
+    GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
+    GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_KV); GGML_UNUSED(stride_mask);
+    GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
     NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }
@@ -931,6 +946,16 @@ static __global__ void flash_attn_ext_f16(
         (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
          ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
 #else
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
+    GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
+    GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
+    GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 }
@@ -985,38 +1010,38 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \
     extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \
 
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,   8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,   8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,   8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,   8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,   8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,   8);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,   8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,   8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,   8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,   8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,   8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,   8)
 
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  16);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  16)
 
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  32);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  32)
 
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  64);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  64)
 
 // Kernels with ncols == 128 are only 4% faster due to register pressure.
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128); // Needs too much shared memory.
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory.
diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu
index 77455d8e4..e0039e175 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu
@@ -282,7 +282,19 @@ static __global__ void flash_attn_tile_ext_f16(
         }
     }
 #else
-   NO_DEVICE_CODE;
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
+    GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
+    GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu
index 85fea4404..81290c901 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f32.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -281,6 +281,18 @@ static __global__ void flash_attn_tile_ext_f32(
         }
     }
 #else
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
+    GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
+    GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index 32c52ebe3..e17d2d0e4 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -292,7 +292,19 @@ static __global__ void flash_attn_vec_ext_f16(
         dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
-   NO_DEVICE_CODE;
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
+    GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
+    GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
index 336c136d1..704874855 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -277,6 +277,16 @@ static __global__ void flash_attn_vec_ext_f32(
         dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
+    GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
+    GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
+    GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
index 5c214ea31..bc21b27a0 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -430,7 +430,17 @@ static __global__ void flash_attn_ext_f16(
         dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
     }
 #else
-   NO_DEVICE_CODE;
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+    GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
+    GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
+    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
 }
 
diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index 9206bfeba..2af63355a 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -26,6 +26,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
     asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
         : "=r"(ret) : "r"(x));
 #else
+    GGML_UNUSED(x);
     NO_DEVICE_CODE;
 #endif // defined(NEW_MMA_AVAILABLE)
     return ret;
@@ -178,6 +179,7 @@ namespace ggml_cuda_mma {
             : "l"(xs));
 #else
         load_generic(xs0, stride);
+        GGML_UNUSED(t);
 #endif // NEW_MMA_AVAILABLE
     }
 
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index f136c4195..532358018 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -945,7 +945,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
         }
     }
 #else
-    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
     NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }
@@ -1024,7 +1024,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
     }
 
 #pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+    for (int k01 = 0; k01 < WARP_SIZE/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -1035,19 +1035,34 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
             for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;
 
-                if (k01 < WARP_SIZE/2) {
-                    constexpr int ns = 2;
-                    sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
-                        &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
-                        &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
-                        &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
-                } else {
-                    constexpr int ns = 1;
-                    sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
-                        &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
-                        &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
-                        &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
-                }
+                constexpr int ns = 2;
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+                    &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                    &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+                    &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+            }
+        }
+    }
+
+    // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.
+    // As a workaround, 2 separate loops are used instead.
+#pragma unroll
+    for (int k01 = WARP_SIZE/2; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                constexpr int ns = 1;
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+                    &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                    &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+                    &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
             }
         }
     }
@@ -1176,7 +1191,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         }
     }
 #else
-    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
     NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }
@@ -1253,7 +1268,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const float d = bxi->d;
 
 #pragma unroll
-        for (int l = 0; l < sizeof(int); ++l) {
+        for (int l = 0; l < int(sizeof(int)); ++l) {
             x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
         }
 #else
@@ -1376,7 +1391,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
 
 #pragma unroll
-        for (int l = 0; l < sizeof(int); ++l) {
+        for (int l = 0; l < int(sizeof(int)); ++l) {
             x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
         }
     }
@@ -1517,7 +1532,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
 
 #pragma unroll
-        for (int l = 0; l < sizeof(int); ++l) {
+        for (int l = 0; l < int(sizeof(int)); ++l) {
             x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
         }
     }
@@ -1810,7 +1825,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         }
     }
 #else
-    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
     NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }
@@ -2570,6 +2585,8 @@ static __device__ void mul_mat_q_process_tile(
     } else {
         write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
     }
+
+    GGML_UNUSED(ne00); GGML_UNUSED(ne10);
 }
 
 
@@ -2695,7 +2712,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
         const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
 
         // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
-        if (it != blockIdx.x || jt != blockIdx.y) {
+        if ((unsigned)it != blockIdx.x || (unsigned)jt != blockIdx.y) {
             continue;
         }
 
@@ -2825,7 +2842,6 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 template <ggml_type type>
 void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     const int id    = ggml_cuda_get_device();
-    const int nsm   = ggml_cuda_info().devices[id].nsm;
     const int cc    = ggml_cuda_info().devices[id].cc;
     const int smpbo = ggml_cuda_info().devices[id].smpbo;
 
diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu
index f89ed03b5..b39961cd1 100644
--- a/ggml/src/ggml-cuda/mmv.cu
+++ b/ggml/src/ggml-cuda/mmv.cu
@@ -29,7 +29,7 @@ static __global__ void mul_mat_vec(
         __syncthreads();
     }
 
-    float sumf;
+    float sumf = 0.0f;
 
     if constexpr (std::is_same<T, half>::value) {
         const half2 * x2 = (const half2 *) x;
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 45ea30f62..eef8585a7 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -151,7 +151,7 @@ static __global__ void mul_mat_vec_q(
     constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
 
     // partial sum for each thread
-    float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
+    float tmp[ncols_y][rows_per_cuda_block] = {{0.0f}};
 
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
@@ -197,10 +197,12 @@ static __global__ void mul_mat_vec_q(
             tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
         }
 
-        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
+        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < (unsigned)nrows_dst)) {
             dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
         }
     }
+
+    GGML_UNUSED(nrows_x);
 }
 
 static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu
index aba539e8d..77432b046 100644
--- a/ggml/src/ggml-cuda/pad.cu
+++ b/ggml/src/ggml-cuda/pad.cu
@@ -14,7 +14,7 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
         nidx +
         blockIdx.y * ne0 +
         blockIdx.z * ne0 * gridDim.y;
-    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
+    if (nidx < ne00 && blockIdx.y < (unsigned)ne01 && blockIdx.z < (unsigned)(ne02*ne03)) {
         int offset_src =
             nidx +
             blockIdx.y * ne00 +
diff --git a/ggml/src/ggml-cuda/upscale.cu b/ggml/src/ggml-cuda/upscale.cu
index cf513c3ad..524e97957 100644
--- a/ggml/src/ggml-cuda/upscale.cu
+++ b/ggml/src/ggml-cuda/upscale.cu
@@ -19,7 +19,7 @@ static __global__ void upscale_f32(const float * x, float * dst,
     int i02 = i12 / sf2;
     int i03 = i13 / sf3;
 
-    dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
+    dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) );
 }
 
 static void upscale_f32_cuda(const float * x, float * dst,