mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-04-16 11:36:08 +00:00
Support pure float16 add/sub/mul/div operations in the CUDA (and CPU) backend (ggml/1121)
* Support float16-to-float16 add/sub/mul/div operations in the CUDA backend * Add fp16 support for add/sub/mul/div on the CPU backend * Add test cases for fp16 add/sub/mul/div
This commit is contained in:
parent
aede2074f6
commit
f54a4ba11e
@ -1415,15 +1415,35 @@ inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x)
|
||||
inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
||||
inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
||||
inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
|
||||
inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) + GGML_FP16_TO_FP32(y[i]));
|
||||
}
|
||||
}
|
||||
inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
|
||||
inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
|
||||
inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
|
||||
inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; }
|
||||
inline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) - GGML_FP16_TO_FP32(y[i]));
|
||||
}
|
||||
}
|
||||
inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
||||
inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
|
||||
inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
|
||||
inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
|
||||
inline static void ggml_vec_mul_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) * GGML_FP16_TO_FP32(y[i]));
|
||||
}
|
||||
}
|
||||
inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
|
||||
inline static void ggml_vec_div_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) / GGML_FP16_TO_FP32(y[i]));
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
@ -4379,7 +4399,7 @@ static void ggml_compute_forward_add_f16_f16(
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
|
||||
GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
@ -4404,17 +4424,22 @@ static void ggml_compute_forward_add_f16_f16(
|
||||
|
||||
if (nb10 == sizeof(ggml_fp16_t)) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// src0, src1 and dst are same shape => same indices
|
||||
const int i3 = ir/(ne2*ne1);
|
||||
const int i2 = (ir - i3*ne2*ne1)/ne1;
|
||||
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
|
||||
// src1 is broadcastable across src0 and dst in i1, i2, i3
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
|
||||
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
|
||||
ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
|
||||
const int64_t i13 = i03 % ne13;
|
||||
const int64_t i12 = i02 % ne12;
|
||||
const int64_t i11 = i01 % ne11;
|
||||
const int64_t nr0 = ne00 / ne10;
|
||||
|
||||
for (int i = 0; i < ne0; i++) {
|
||||
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(src1_ptr[i]));
|
||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
||||
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
|
||||
|
||||
for (int64_t r = 0; r < nr0; ++r) {
|
||||
ggml_vec_add_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -5202,6 +5227,62 @@ static void ggml_compute_forward_sub_f32(
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_sub_f16(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int nr = ggml_nrows(src0);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F16);
|
||||
|
||||
GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
if (nb10 == sizeof(ggml_fp16_t)) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// src1 is broadcastable across src0 and dst in i1, i2, i3
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
const int64_t i13 = i03 % ne13;
|
||||
const int64_t i12 = i02 % ne12;
|
||||
const int64_t i11 = i01 % ne11;
|
||||
const int64_t nr0 = ne00 / ne10;
|
||||
|
||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
||||
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
|
||||
|
||||
for (int64_t r = 0; r < nr0; ++r) {
|
||||
ggml_vec_sub_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// src1 is not contiguous
|
||||
GGML_ABORT("unimplemented error");
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_sub(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
@ -5213,6 +5294,10 @@ static void ggml_compute_forward_sub(
|
||||
{
|
||||
ggml_compute_forward_sub_f32(params, dst);
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
ggml_compute_forward_sub_f16(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
@ -5293,6 +5378,55 @@ static void ggml_compute_forward_mul_f32(
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_mul_f16(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int64_t nr = ggml_nrows(src0);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F16);
|
||||
|
||||
GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
||||
|
||||
if (nb10 == sizeof(ggml_fp16_t)) {
|
||||
for (int64_t ir = ith; ir < nr; ir += nth) {
|
||||
// src0 and dst are same shape => same indices
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
const int64_t i13 = i03 % ne13;
|
||||
const int64_t i12 = i02 % ne12;
|
||||
const int64_t i11 = i01 % ne11;
|
||||
const int64_t nr0 = ne00 / ne10;
|
||||
|
||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
||||
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
|
||||
|
||||
for (int64_t r = 0 ; r < nr0; ++r) {
|
||||
ggml_vec_mul_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// src1 is not contiguous
|
||||
GGML_ABORT("unimplemented error");
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_mul(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
@ -5300,13 +5434,17 @@ static void ggml_compute_forward_mul(
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32 && "only f32 src1 supported for now");
|
||||
GGML_ASSERT((src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16) && "only f32/f16 src1 supported for now");
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_mul_f32(params, dst);
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
ggml_compute_forward_mul_f16(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
@ -5387,6 +5525,55 @@ static void ggml_compute_forward_div_f32(
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_div_f16(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int64_t nr = ggml_nrows(src0);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F16);
|
||||
|
||||
GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
||||
|
||||
if (nb10 == sizeof(ggml_fp16_t)) {
|
||||
for (int64_t ir = ith; ir < nr; ir += nth) {
|
||||
// src0 and dst are same shape => same indices
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
const int64_t i13 = i03 % ne13;
|
||||
const int64_t i12 = i02 % ne12;
|
||||
const int64_t i11 = i01 % ne11;
|
||||
const int64_t nr0 = ne00 / ne10;
|
||||
|
||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
||||
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
|
||||
|
||||
for (int64_t r = 0; r < nr0; ++r) {
|
||||
ggml_vec_div_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// src1 is not contiguous
|
||||
GGML_ABORT("unimplemented error");
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_div(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
@ -5398,6 +5585,10 @@ static void ggml_compute_forward_div(
|
||||
{
|
||||
ggml_compute_forward_div_f32(params, dst);
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
ggml_compute_forward_div_f16(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
|
@ -294,11 +294,13 @@ static void ggml_cuda_op_bin_bcast(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
|
||||
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||
op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (half *) dst_dd, stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
|
||||
op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
||||
op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
|
||||
|
@ -3961,37 +3961,38 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
|
||||
}
|
||||
};
|
||||
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1, 1}, {32, 1, 1, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 1, 1}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 1}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 1, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 2});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 2});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 2, 2});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 2, 2, 2});
|
||||
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
|
||||
add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1});
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1});
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2});
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2});
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2});
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2});
|
||||
|
||||
// stable diffusion
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 16, 16, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {1280, 16, 16, 1}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 256, 1, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1280, 1}, {16, 16, 1, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {16, 16, 1280, 1}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1920, 1}, {16, 16, 1, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 2560, 1}, {16, 16, 1, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1280, 1}, {32, 32, 1, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1920, 1}, {32, 32, 1, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 640, 1}, {32, 32, 1, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {5120, 1, 1, 1}, {1, 256, 1, 1});
|
||||
add_test_bin_bcast(GGML_TYPE_F32, {640, 1, 1, 1}, {1, 1, 1, 1});
|
||||
//add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1});
|
||||
//add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 16, 16, 1});
|
||||
add_test_bin_bcast(type, {1280, 16, 16, 1}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 256, 1, 1});
|
||||
add_test_bin_bcast(type, {1, 1, 1280, 1}, {16, 16, 1, 1});
|
||||
add_test_bin_bcast(type, {16, 16, 1280, 1}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {1, 1, 1920, 1}, {16, 16, 1, 1});
|
||||
add_test_bin_bcast(type, {1, 1, 2560, 1}, {16, 16, 1, 1});
|
||||
add_test_bin_bcast(type, {1, 1, 1280, 1}, {32, 32, 1, 1});
|
||||
add_test_bin_bcast(type, {1, 1, 1920, 1}, {32, 32, 1, 1});
|
||||
add_test_bin_bcast(type, {1, 1, 640, 1}, {32, 32, 1, 1});
|
||||
add_test_bin_bcast(type, {5120, 1, 1, 1}, {1, 256, 1, 1});
|
||||
add_test_bin_bcast(type, {640, 1, 1, 1}, {1, 1, 1, 1});
|
||||
//add_test_bin_bcast(type, {3, 3, 2560, 1280}, {1, 1, 1, 1});
|
||||
//add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_add1());
|
||||
test_cases.emplace_back(new test_scale());
|
||||
|
Loading…
x
Reference in New Issue
Block a user