SYCL: Add fp16 type support to unary op kernels (#12788)

* SYCL: Add fp16 support to some elementwise OP kernels

* Remove comment

ggml-ci

* Use static_cast directly

* Remove unneeded cast from tanh

* Use static_cast and remove unneeded casts

* Adjust device_support_op for unary OPs

* Use cast_data and typed_data struct to deduplicate casting code
Akarshan Biswas, 2025-04-11 13:33:50 +05:30, committed by GitHub
parent ec6c09d0fa
commit fccf9cae83
3 changed files with 778 additions and 297 deletions

ggml/src/ggml-sycl/elementwise.cpp: diff suppressed because it is too large.

ggml/src/ggml-sycl/elementwise.hpp

@@ -2,6 +2,13 @@
 #define GGML_SYCL_ELEMENTWISE_HPP
 
 #include "common.hpp"
 #include "ggml.h"
+#include <limits.h>
+
+template <typename T>
+T neg_infinity() {
+    return -std::numeric_limits<T>::infinity();
+}
 
 static __dpct_inline__ float op_repeat(const float a, const float b) {
     return b;
@@ -24,6 +31,19 @@ static __dpct_inline__ float op_div(const float a, const float b) {
     return a / b;
 }
 
+template<typename T>
+struct typed_data {
+    const T * src;
+    T * dst;
+};
+
+template<typename T>
+typed_data<T> cast_data(ggml_tensor * dst) {
+    return {
+        /* .src = */ static_cast<const T *>(dst->src[0]->data),
+        /* .dst = */ static_cast<T *>(dst->data)
+    };
+}
+
 void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
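Taken together, neg_infinity, typed_data, and cast_data let each elementwise kernel be written once as a template and instantiated for both float and sycl::half: neg_infinity supplies a typed -inf for ops that need a padding value, and cast_data presumably replaces the per-op pair of static_casts in the dispatchers. A minimal self-contained sketch of the pattern (illustrative only; the kernel name and launch shape below are not the commit's actual code):

#include <sycl/sycl.hpp>

// One templated kernel body serving float and sycl::half alike; the SYCL
// built-ins (sycl::sqrt, sycl::exp, ...) are overloaded for both types.
template <typename T>
void sqrt_kernel_example(sycl::queue & q, const T * x, T * dst, int k) {
    q.parallel_for(sycl::range<1>(k), [=](sycl::id<1> i) {
        dst[i] = sycl::sqrt(x[i]);
    }).wait();
}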
@@ -65,6 +85,10 @@ void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
+// ---------
+void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+

ggml/src/ggml-sycl/ggml-sycl.cpp

@@ -1617,17 +1617,6 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
     dst[i] = scale * x[i];
 }
 
-static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k,
-                      const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-
-    dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
-}
-
 template <typename Ti, typename To>
 static void pool2d_nchw_kernel(
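The removed f32-only clamp kernel reappears in templated form in elementwise.cpp, whose diff is suppressed above. A hedged sketch of what that templated body looks like, derived directly from the removed code:

template <typename T>
static void clamp(const T * x, T * dst, const float min, const float max,
                  const int k, const sycl::nd_item<3> & item_ct1) {
    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                  item_ct1.get_local_id(2);
    if (i >= k) {
        return;
    }
    // Same logic as the removed kernel; the bounds are cast to T so
    // sycl::half data is compared and stored without a float round-trip.
    dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min)
                                        : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
}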
@@ -1768,18 +1757,6 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale,
         });
 }
 
-static void clamp_f32_sycl(const float *x, float *dst, const float min,
-                           const float max, const int k,
-                           queue_ptr stream) {
-    const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            clamp_f32(x, dst, min, max, k, item_ct1);
-        });
-}
-
 static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
                               const int nrows, queue_ptr stream) {
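The host-side launcher generalizes the same way; the grid arithmetic is unchanged. A sketch under the same assumption (queue_ptr and SYCL_CLAMP_BLOCK_SIZE are the existing ggml-sycl definitions; the templated version lives in the suppressed elementwise.cpp diff):

template <typename T>
static void clamp_sycl(const T * x, T * dst, const float min, const float max,
                       const int k, queue_ptr stream) {
    // Round up so every element is covered by exactly one work-item.
    const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
    stream->parallel_for(
        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
                              sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
                          sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
        [=](sycl::nd_item<3> item_ct1) {
            clamp(x, dst, min, max, k, item_ct1);
        });
}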
@@ -2258,26 +2235,6 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor *dst
     SYCL_CHECK(0);
 }
 
-inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
-    float * dst_dd = static_cast<float *>(dst->data);
-
-    float min;
-    float max;
-    memcpy(&min, dst->op_params, sizeof(float));
-    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
-
-    clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(dst->src[0]), ctx.stream());
-    /*
-    DPCT1010:88: SYCL uses exceptions to report errors and does not use the
-    error codes. The call was replaced with 0. You need to rewrite this code.
-    */
-    SYCL_CHECK(0);
-}
-
 static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
     static bool peer_access_enabled = false;
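The removed wrapper also records the op_params convention any replacement must keep: clamp's bounds arrive as two consecutive floats in dst->op_params, read with memcpy to sidestep strict-aliasing and alignment concerns. The unpacking, isolated:

// min and max are stored as two consecutive floats in dst->op_params.
float min_val;
float max_val;
memcpy(&min_val, dst->op_params, sizeof(float));
memcpy(&max_val, (float *) dst->op_params + 1, sizeof(float));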
@@ -3218,10 +3175,6 @@ static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
     ggml_sycl_op_scale(ctx, dst);
 }
 
-static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_op_clamp(ctx, dst);
-}
-
 static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     ggml_sycl_op_diag_mask_inf(ctx, dst);
 }
@@ -3900,7 +3853,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_TANH:
        case GGML_UNARY_OP_EXP:
-            return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32);
+#if defined (GGML_SYCL_F16)
+            return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
+#else
+            return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
+#endif
         default:
             return false;
     }
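Spelled out, the new unary rule is: contiguous input, matching src/dst types, and, without GGML_SYCL_F16, F32 only. An equivalent standalone predicate, with the compile-time switch turned into a parameter for illustration (not the commit's code):

static bool unary_op_supported(const ggml_tensor * op, bool sycl_f16_enabled) {
    // Mirrors the #if/#else above: the F32-only branch is the fp16-disabled case.
    if (!ggml_is_contiguous(op->src[0]) || op->type != op->src[0]->type) {
        return false;
    }
    return sycl_f16_enabled || op->type == GGML_TYPE_F32;
}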
@@ -4022,13 +3979,18 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+            return (op->src[0]->type == GGML_TYPE_F32);
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
         case GGML_OP_LOG:
-            return (op->src[0]->type == GGML_TYPE_F32);
+#if defined (GGML_SYCL_F16)
+            return ((op->type == GGML_TYPE_F32 || op->type == GGML_SYCL_F16) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_SYCL_F16) && (op->type == op->src[0]->type));
+#else
+            return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
+#endif
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
         case GGML_OP_L2_NORM:
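The elementwise rule has the same shape: with fp16 enabled, F32/F32 or F16/F16 pairs pass; otherwise only F32/F32 (the F16 branch above compares against GGML_SYCL_F16, where GGML_TYPE_F16 appears to be the intent). An equivalent standalone predicate, written with that reading and offered only as illustration:

static bool elementwise_op_supported(const ggml_tensor * op, bool sycl_f16_enabled) {
    // src and dst must share a type; F16 pairs pass only when fp16 is enabled.
    if (op->type != op->src[0]->type) {
        return false;
    }
    return op->type == GGML_TYPE_F32 ||
           (sycl_f16_enabled && op->type == GGML_TYPE_F16);
}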