mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-27 02:36:05 +00:00
[mlir][linalg][python] Add max operation in OpDSL
Add the max operation to the OpDSL and introduce a max pooling operation to test the implementation. As MLIR has no builtin max operation, the max function is lowered to a compare and select pair. Differential Revision: https://reviews.llvm.org/D105203
This commit is contained in:
parent
0c53f602d5
commit
3b95400f78
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: matmul
|
name: matmul
|
||||||
@ -594,6 +593,77 @@ structured_op: !LinalgStructuredOpConfig
|
|||||||
- !ScalarExpression
|
- !ScalarExpression
|
||||||
scalar_arg: I
|
scalar_arg: I
|
||||||
--- !LinalgOpConfig
|
--- !LinalgOpConfig
|
||||||
|
metadata: !LinalgOpMetadata
|
||||||
|
name: pooling_nhwc_max_poly
|
||||||
|
cpp_class_name: PoolingNhwcMaxPolyOp
|
||||||
|
doc: |-
|
||||||
|
Performs max pooling.
|
||||||
|
|
||||||
|
Numeric casting is performed on the input operand, promoting it to the same
|
||||||
|
data type as the accumulator/output.
|
||||||
|
structured_op: !LinalgStructuredOpConfig
|
||||||
|
args:
|
||||||
|
- !LinalgOperandDefConfig
|
||||||
|
name: I
|
||||||
|
usage: InputOperand
|
||||||
|
type_var: T1
|
||||||
|
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
|
||||||
|
(s0, s1, s2, s3)>
|
||||||
|
- !LinalgOperandDefConfig
|
||||||
|
name: K
|
||||||
|
usage: InputOperand
|
||||||
|
type_var: T2
|
||||||
|
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
|
||||||
|
(s4, s5)>
|
||||||
|
- !LinalgOperandDefConfig
|
||||||
|
name: O
|
||||||
|
usage: OutputOperand
|
||||||
|
type_var: U
|
||||||
|
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
|
||||||
|
(s0, s6, s7, s3)>
|
||||||
|
- !LinalgOperandDefConfig
|
||||||
|
name: strides
|
||||||
|
usage: IndexAttribute
|
||||||
|
type_var: I64
|
||||||
|
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
|
||||||
|
-> (s8, s9)>
|
||||||
|
- !LinalgOperandDefConfig
|
||||||
|
name: dilations
|
||||||
|
usage: IndexAttribute
|
||||||
|
type_var: I64
|
||||||
|
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
|
||||||
|
-> (s10, s11)>
|
||||||
|
indexing_maps: !LinalgIndexingMapsConfig
|
||||||
|
static_indexing_maps:
|
||||||
|
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
|
||||||
|
s10, s11] -> (d0, d1 * s8 + d3 * s10, d2 * s9 + d4 * s11, d5)>
|
||||||
|
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
|
||||||
|
s10, s11] -> (d3, d4)>
|
||||||
|
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
|
||||||
|
s10, s11] -> (d0, d1, d2, d5)>
|
||||||
|
iterator_types:
|
||||||
|
- parallel
|
||||||
|
- parallel
|
||||||
|
- parallel
|
||||||
|
- reduction
|
||||||
|
- reduction
|
||||||
|
- parallel
|
||||||
|
assignments:
|
||||||
|
- !ScalarAssign
|
||||||
|
arg: O
|
||||||
|
value: !ScalarExpression
|
||||||
|
scalar_apply:
|
||||||
|
fn_name: max
|
||||||
|
operands:
|
||||||
|
- !ScalarExpression
|
||||||
|
scalar_arg: O
|
||||||
|
- !ScalarExpression
|
||||||
|
symbolic_cast:
|
||||||
|
type_var: U
|
||||||
|
operands:
|
||||||
|
- !ScalarExpression
|
||||||
|
scalar_arg: I
|
||||||
|
--- !LinalgOpConfig
|
||||||
metadata: !LinalgOpMetadata
|
metadata: !LinalgOpMetadata
|
||||||
name: fill_rng_2d
|
name: fill_rng_2d
|
||||||
cpp_class_name: FillRng2DOp
|
cpp_class_name: FillRng2DOp
|
||||||
|
@ -274,6 +274,21 @@ public:
|
|||||||
llvm_unreachable("unsupported non numeric type");
|
llvm_unreachable("unsupported non numeric type");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value applyfn__max(Value lhs, Value rhs) {
|
||||||
|
OpBuilder builder = getBuilder();
|
||||||
|
if (isFloatingPoint(lhs)) {
|
||||||
|
Value condition =
|
||||||
|
builder.create<CmpFOp>(lhs.getLoc(), CmpFPredicate::OGT, lhs, rhs);
|
||||||
|
return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
|
||||||
|
}
|
||||||
|
if (isInteger(lhs)) {
|
||||||
|
Value condition =
|
||||||
|
builder.create<CmpIOp>(lhs.getLoc(), CmpIPredicate::sgt, lhs, rhs);
|
||||||
|
return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
|
||||||
|
}
|
||||||
|
llvm_unreachable("unsupported non numeric type");
|
||||||
|
}
|
||||||
|
|
||||||
void yieldOutputs(ValueRange values) {
|
void yieldOutputs(ValueRange values) {
|
||||||
assert(!values.empty() && "linalg ops must yield outputs");
|
assert(!values.empty() && "linalg ops must yield outputs");
|
||||||
if (values.empty())
|
if (values.empty())
|
||||||
|
@ -307,6 +307,18 @@ class _BodyBuilder:
|
|||||||
return std.MulIOp(lhs.type, lhs, rhs).result
|
return std.MulIOp(lhs.type, lhs, rhs).result
|
||||||
raise NotImplementedError("Unsupported 'mul' operand: {lhs}")
|
raise NotImplementedError("Unsupported 'mul' operand: {lhs}")
|
||||||
|
|
||||||
|
def _eval_max(self, lhs: Value, rhs: Value) -> Value:
|
||||||
|
i1 = IntegerType.get_signless(1)
|
||||||
|
if _is_floating_point_type(lhs.type):
|
||||||
|
ogt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
|
||||||
|
cond = std.CmpFOp(i1, ogt_attr, lhs, rhs).result
|
||||||
|
return std.SelectOp(lhs.type, cond, lhs, rhs).result
|
||||||
|
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
|
||||||
|
sgt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
|
||||||
|
cond = std.CmpIOp(i1, sgt_attr, lhs, rhs).result
|
||||||
|
return std.SelectOp(lhs.type, cond, lhs, rhs).result
|
||||||
|
raise NotImplementedError("Unsupported 'max' operand: {lhs}")
|
||||||
|
|
||||||
|
|
||||||
def _infer_structured_outs(op_config: LinalgStructuredOpConfig,
|
def _infer_structured_outs(op_config: LinalgStructuredOpConfig,
|
||||||
in_arg_defs: Sequence[OperandDefConfig],
|
in_arg_defs: Sequence[OperandDefConfig],
|
||||||
|
@ -148,6 +148,24 @@ def pooling_nhwc_sum_poly(
|
|||||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
|
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
|
||||||
|
|
||||||
|
|
||||||
|
@linalg_structured_op
|
||||||
|
def pooling_nhwc_max_poly(
|
||||||
|
I=TensorDef(T1, S.N, S.H, S.W, S.C),
|
||||||
|
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
|
||||||
|
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
|
||||||
|
strides=AttributeDef(S.SH, S.SW),
|
||||||
|
dilations=AttributeDef(S.DH, S.DW)):
|
||||||
|
"""Performs max pooling.
|
||||||
|
|
||||||
|
Numeric casting is performed on the input operand, promoting it to the same
|
||||||
|
data type as the accumulator/output.
|
||||||
|
"""
|
||||||
|
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||||
|
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)(
|
||||||
|
cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
||||||
|
D.c]))
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
def fill_rng_2d(
|
def fill_rng_2d(
|
||||||
min=ScalarDef(F64),
|
min=ScalarDef(F64),
|
||||||
|
@ -60,6 +60,36 @@ func @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_i32(%input : tenso
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @generalize_pooling_nhwc_max_poly_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
|
||||||
|
%0 = linalg.pooling_nhwc_max_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
|
||||||
|
ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
|
||||||
|
return %0: tensor<1x2x4x1xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @generalize_pooling_nhwc_max_poly_f32
|
||||||
|
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
|
||||||
|
// CHECK-NEXT: %[[COND:.+]] = cmpf ogt, %[[OUT_ARG]], %[[IN_ARG]] : f32
|
||||||
|
// CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[MAX]] : f32
|
||||||
|
// CHECK-NEXT: -> tensor<1x2x4x1xf32>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @generalize_pooling_nhwc_max_poly_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
|
||||||
|
%0 = linalg.pooling_nhwc_max_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
|
||||||
|
ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
|
||||||
|
return %0: tensor<1x2x4x1xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @generalize_pooling_nhwc_max_poly_i32
|
||||||
|
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
|
||||||
|
// CHECK-NEXT: %[[COND:.+]] = cmpi sgt, %[[OUT_ARG]], %[[IN_ARG]] : i32
|
||||||
|
// CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[MAX]] : i32
|
||||||
|
// CHECK-NEXT: -> tensor<1x2x4x1xi32>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @generalize_pooling_nhwc_sum_poly_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
|
func @generalize_pooling_nhwc_sum_poly_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
|
||||||
%0 = linalg.pooling_nhwc_sum_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
|
%0 = linalg.pooling_nhwc_sum_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
|
||||||
ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
|
ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
|
||||||
|
@ -50,8 +50,9 @@ def pooling_poly(
|
|||||||
strides=AttributeDef(S.SH, S.SW),
|
strides=AttributeDef(S.SH, S.SW),
|
||||||
dilations=AttributeDef(S.DH, S.DW)):
|
dilations=AttributeDef(S.DH, S.DW)):
|
||||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||||
O[D.n, D.oh, D.ow, D.c] += cast(
|
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)(
|
||||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
|
cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
||||||
|
D.c]))
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
@ -221,8 +222,9 @@ with Context() as ctx, Location.unknown():
|
|||||||
# CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
|
# CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
|
||||||
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
|
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
|
||||||
# CHECK-NEXT: %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
|
# CHECK-NEXT: %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
|
||||||
# CHECK-NEXT: %[[SUM:.+]] = addi %[[OUT]], %[[IN_CAST]] : i32
|
# CHECK-NEXT: %[[COND:.+]] = cmpi sgt, %[[OUT]], %[[IN_CAST:.+]] : i32
|
||||||
# CHECK-NEXT: linalg.yield %[[SUM]] : i32
|
# CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT]], %[[IN_CAST:.+]] : i32
|
||||||
|
# CHECK-NEXT: linalg.yield %[[MAX]] : i32
|
||||||
# CHECK-NEXT: -> tensor<2x4xi32>
|
# CHECK-NEXT: -> tensor<2x4xi32>
|
||||||
@builtin.FuncOp.from_py_func(
|
@builtin.FuncOp.from_py_func(
|
||||||
RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
|
RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
|
||||||
@ -231,6 +233,22 @@ with Context() as ctx, Location.unknown():
|
|||||||
return pooling_poly(
|
return pooling_poly(
|
||||||
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
|
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
|
||||||
|
|
||||||
|
# CHECK-LABEL: @test_f32f32_pooling
|
||||||
|
# CHECK: linalg.generic
|
||||||
|
# CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
|
||||||
|
# CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
|
||||||
|
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32)
|
||||||
|
# CHECK-NEXT: %[[COND:.+]] = cmpf ogt, %[[OUT]], %[[IN:.+]] : f32
|
||||||
|
# CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT]], %[[IN:.+]] : f32
|
||||||
|
# CHECK-NEXT: linalg.yield %[[MAX]] : f32
|
||||||
|
# CHECK-NEXT: -> tensor<2x4xf32>
|
||||||
|
@builtin.FuncOp.from_py_func(
|
||||||
|
RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
|
||||||
|
RankedTensorType.get((2, 4), f32))
|
||||||
|
def test_f32f32_pooling(input, shape, init_result):
|
||||||
|
return pooling_poly(
|
||||||
|
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
|
||||||
|
|
||||||
# CHECK-LABEL: @test_i32_fill_rng
|
# CHECK-LABEL: @test_i32_fill_rng
|
||||||
# CHECK: ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %{{.*}}
|
# CHECK: ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %{{.*}}
|
||||||
# CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
|
# CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
|
||||||
|
@ -85,6 +85,7 @@ func @main() -> i32 attributes {llvm.emit_c_interface} {
|
|||||||
pooling_boiler = """
|
pooling_boiler = """
|
||||||
func @main() -> i32 attributes {llvm.emit_c_interface} {
|
func @main() -> i32 attributes {llvm.emit_c_interface} {
|
||||||
%v0 = constant 0 : i32
|
%v0 = constant 0 : i32
|
||||||
|
%v42 = constant 42.0 : f64
|
||||||
%v1 = constant 1.0 : f64
|
%v1 = constant 1.0 : f64
|
||||||
|
|
||||||
%input = memref.alloc() : memref<1x4x16x1xf64>
|
%input = memref.alloc() : memref<1x4x16x1xf64>
|
||||||
@ -94,10 +95,12 @@ func @main() -> i32 attributes {llvm.emit_c_interface} {
|
|||||||
linalg.fill(%v1, %shape) : f64, memref<2x2xf64>
|
linalg.fill(%v1, %shape) : f64, memref<2x2xf64>
|
||||||
linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>
|
linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>
|
||||||
|
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
|
||||||
|
|
||||||
call @pooling_on_buffers(%input, %shape, %output) :
|
call @pooling_on_buffers(%input, %shape, %output) :
|
||||||
(memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()
|
(memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()
|
||||||
|
|
||||||
%c0 = constant 0 : index
|
|
||||||
%0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>
|
%0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>
|
||||||
|
|
||||||
// TODO: FFI-based solution to allow testing and printing with python code.
|
// TODO: FFI-based solution to allow testing and printing with python code.
|
||||||
@ -105,6 +108,7 @@ func @main() -> i32 attributes {llvm.emit_c_interface} {
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def transform(module, boilerplate):
|
def transform(module, boilerplate):
|
||||||
import mlir.conversions
|
import mlir.conversions
|
||||||
import mlir.dialects.linalg.passes
|
import mlir.dialects.linalg.passes
|
||||||
@ -308,12 +312,8 @@ def test_pooling_builtin():
|
|||||||
MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
|
MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
|
||||||
MemRefType.get((1, 2, 4, 1), i32))
|
MemRefType.get((1, 2, 4, 1), i32))
|
||||||
def pooling_on_buffers(input, shape, output):
|
def pooling_on_buffers(input, shape, output):
|
||||||
linalg.pooling_nhwc_sum_poly(
|
linalg.pooling_nhwc_max_poly(
|
||||||
input,
|
input, shape, outs=[output], strides=[2, 4], dilations=[1, 2])
|
||||||
shape,
|
|
||||||
outs=[output],
|
|
||||||
strides=[2, 4],
|
|
||||||
dilations=[1, 2])
|
|
||||||
|
|
||||||
execution_engine = ExecutionEngine(transform(module, pooling_boiler))
|
execution_engine = ExecutionEngine(transform(module, pooling_boiler))
|
||||||
|
|
||||||
@ -325,7 +325,7 @@ def test_pooling_builtin():
|
|||||||
execution_engine.invoke("main", res)
|
execution_engine.invoke("main", res)
|
||||||
|
|
||||||
log("RESULT: ", res[0])
|
log("RESULT: ", res[0])
|
||||||
# CHECK: RESULT: 4
|
# CHECK: RESULT: 42
|
||||||
|
|
||||||
|
|
||||||
test_pooling_builtin()
|
test_pooling_builtin()
|
||||||
@ -342,7 +342,7 @@ def test_pooling_generic():
|
|||||||
MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
|
MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
|
||||||
MemRefType.get((1, 2, 4, 1), i32))
|
MemRefType.get((1, 2, 4, 1), i32))
|
||||||
def pooling_on_buffers(input, shape, output):
|
def pooling_on_buffers(input, shape, output):
|
||||||
linalg.pooling_nhwc_sum_poly(
|
linalg.pooling_nhwc_max_poly(
|
||||||
input,
|
input,
|
||||||
shape,
|
shape,
|
||||||
outs=[output],
|
outs=[output],
|
||||||
@ -360,7 +360,7 @@ def test_pooling_generic():
|
|||||||
execution_engine.invoke("main", res)
|
execution_engine.invoke("main", res)
|
||||||
|
|
||||||
log("RESULT: ", res[0])
|
log("RESULT: ", res[0])
|
||||||
# CHECK: RESULT: 4
|
# CHECK: RESULT: 42
|
||||||
|
|
||||||
|
|
||||||
test_pooling_generic()
|
test_pooling_generic()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user