mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-27 06:06:06 +00:00
[mlir][linalg] Add pooling_nchw_max, conv_2d_nchw as yaml ops.
- Add pooling_nchw_max. - Move conv_2d_nchw to yaml ops and add strides and dilation attributes. Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D106658
This commit is contained in:
parent
ae69f46867
commit
deebf18512
@ -905,6 +905,88 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarExpression
|
||||
scalar_arg: K
|
||||
--- !LinalgOpConfig
|
||||
metadata: !LinalgOpMetadata
|
||||
name: conv_2d_nchw
|
||||
cpp_class_name: Conv2DNchwOp
|
||||
doc: |-
|
||||
Performs 2-D convolution.
|
||||
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
them to the same data type as the accumulator/output.
|
||||
structured_op: !LinalgStructuredOpConfig
|
||||
args:
|
||||
- !LinalgOperandDefConfig
|
||||
name: I
|
||||
usage: InputOperand
|
||||
type_var: T1
|
||||
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
|
||||
-> (s0, s1, s2, s3)>
|
||||
- !LinalgOperandDefConfig
|
||||
name: K
|
||||
usage: InputOperand
|
||||
type_var: T2
|
||||
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
|
||||
-> (s4, s1, s5, s6)>
|
||||
- !LinalgOperandDefConfig
|
||||
name: O
|
||||
usage: OutputOperand
|
||||
type_var: U
|
||||
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
|
||||
-> (s0, s4, s7, s8, s1)>
|
||||
- !LinalgOperandDefConfig
|
||||
name: strides
|
||||
usage: IndexAttribute
|
||||
type_var: I64
|
||||
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
|
||||
s12] -> (s9, s10)>
|
||||
- !LinalgOperandDefConfig
|
||||
name: dilations
|
||||
usage: IndexAttribute
|
||||
type_var: I64
|
||||
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
|
||||
s12] -> (s11, s12)>
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
|
||||
s9, s10, s11, s12] -> (d0, d4, d2 * s9 + d5 * s11, d3 * s10 + d6 * s12)>
|
||||
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
|
||||
s9, s10, s11, s12] -> (d1, d4, d5, d6)>
|
||||
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
|
||||
s9, s10, s11, s12] -> (d0, d1, d2, d3)>
|
||||
iterator_types:
|
||||
- parallel
|
||||
- parallel
|
||||
- parallel
|
||||
- parallel
|
||||
- reduction
|
||||
- reduction
|
||||
- reduction
|
||||
assignments:
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
scalar_apply:
|
||||
fn_name: mul
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
symbolic_cast:
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: I
|
||||
- !ScalarExpression
|
||||
symbolic_cast:
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: K
|
||||
--- !LinalgOpConfig
|
||||
metadata: !LinalgOpMetadata
|
||||
name: pooling_nhwc_sum
|
||||
cpp_class_name: PoolingNhwcSumOp
|
||||
@ -1047,6 +1129,77 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarExpression
|
||||
scalar_arg: I
|
||||
--- !LinalgOpConfig
|
||||
metadata: !LinalgOpMetadata
|
||||
name: pooling_nchw_max
|
||||
cpp_class_name: PoolingNchwMaxOp
|
||||
doc: |-
|
||||
Performs max pooling.
|
||||
|
||||
Numeric casting is performed on the input operand, promoting it to the same
|
||||
data type as the accumulator/output.
|
||||
structured_op: !LinalgStructuredOpConfig
|
||||
args:
|
||||
- !LinalgOperandDefConfig
|
||||
name: I
|
||||
usage: InputOperand
|
||||
type_var: T1
|
||||
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
|
||||
(s0, s1, s2, s3)>
|
||||
- !LinalgOperandDefConfig
|
||||
name: K
|
||||
usage: InputOperand
|
||||
type_var: T2
|
||||
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
|
||||
(s4, s5)>
|
||||
- !LinalgOperandDefConfig
|
||||
name: O
|
||||
usage: OutputOperand
|
||||
type_var: U
|
||||
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
|
||||
(s0, s1, s6, s7)>
|
||||
- !LinalgOperandDefConfig
|
||||
name: strides
|
||||
usage: IndexAttribute
|
||||
type_var: I64
|
||||
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
|
||||
-> (s8, s9)>
|
||||
- !LinalgOperandDefConfig
|
||||
name: dilations
|
||||
usage: IndexAttribute
|
||||
type_var: I64
|
||||
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
|
||||
-> (s10, s11)>
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
|
||||
s10, s11] -> (d0, d1, d2 * s8 + d4 * s10, d3 * s9 + d5 * s11)>
|
||||
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
|
||||
s10, s11] -> (d4, d5)>
|
||||
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
|
||||
s10, s11] -> (d0, d1, d2, d3)>
|
||||
iterator_types:
|
||||
- parallel
|
||||
- parallel
|
||||
- parallel
|
||||
- parallel
|
||||
- reduction
|
||||
- reduction
|
||||
assignments:
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_apply:
|
||||
fn_name: max
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
symbolic_cast:
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: I
|
||||
--- !LinalgOpConfig
|
||||
metadata: !LinalgOpMetadata
|
||||
name: pooling_nhwc_min
|
||||
cpp_class_name: PoolingNhwcMinOp
|
||||
|
@ -125,12 +125,6 @@ def conv_2d_nhwc(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F
|
||||
O(n, h, w, f), MulFOp(I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
|
||||
}
|
||||
|
||||
ods_def<ConvNCHWOp>:
|
||||
def conv_2d_nchw(I: f32(N, C, H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) {
|
||||
O(n, f, h, w) = AddFOp<kh, kw>(
|
||||
O(n, f, h, w), MulFOp(I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
|
||||
}
|
||||
|
||||
ods_def<ConvDHWOp>:
|
||||
def conv_3d(I: f32(D, H, W), K: f32(KD, KH, KW)) -> (O: f32(D, H, W)) {
|
||||
O(d, h, w) = AddFOp<kd, kh, kw>(
|
||||
|
@ -1186,8 +1186,8 @@ void mlir::linalg::populateConvVectorizationPatterns(
|
||||
populateVectorizationPatterns<ConvInputNHWCFilterHWCFOp, 4>(
|
||||
tiling, promotion, vectorization, tileSizes);
|
||||
|
||||
populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
|
||||
tileSizes);
|
||||
populateVectorizationPatterns<Conv2DNchwOp, 4>(tiling, promotion,
|
||||
vectorization, tileSizes);
|
||||
populateVectorizationPatterns<ConvInputNCHWFilterHWCFOp, 4>(
|
||||
tiling, promotion, vectorization, tileSizes);
|
||||
|
||||
|
@ -205,6 +205,23 @@ def depthwise_conv_2d_input_nhwc_filter_hwc_poly(
|
||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
||||
D.c]) * cast(U, K[D.kh, D.kw, D.c])
|
||||
|
||||
@linalg_structured_op
|
||||
def conv_2d_nchw(
|
||||
I=TensorDef(T1, S.N, S.C, S.IH, S.IW),
|
||||
K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
|
||||
O=TensorDef(U, S.N, S.F, S.OH, S.OW, S.C, output=True),
|
||||
strides=AttributeDef(S.SH, S.SW),
|
||||
dilations=AttributeDef(S.DH, S.DW)):
|
||||
"""Performs 2-D convolution.
|
||||
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
them to the same data type as the accumulator/output.
|
||||
"""
|
||||
domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
|
||||
O[D.n, D.f, D.oh, D.ow] += cast(
|
||||
U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
||||
]) * cast(U, K[D.f, D.c, D.kh, D.kw])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
def pooling_nhwc_sum(
|
||||
@ -240,6 +257,22 @@ def pooling_nhwc_max(
|
||||
cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
||||
D.c]))
|
||||
|
||||
@linalg_structured_op
|
||||
def pooling_nchw_max(
|
||||
I=TensorDef(T1, S.N, S.C, S.H, S.W),
|
||||
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
|
||||
O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
|
||||
strides=AttributeDef(S.SH, S.SW),
|
||||
dilations=AttributeDef(S.DH, S.DW)):
|
||||
"""Performs max pooling.
|
||||
|
||||
Numeric casting is performed on the input operand, promoting it to the same
|
||||
data type as the accumulator/output.
|
||||
"""
|
||||
domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
|
||||
O[D.n, D.c, D.oh, D.ow] = ReduceFn.max(D.kh, D.kw)(
|
||||
cast(U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
||||
]))
|
||||
|
||||
@linalg_structured_op
|
||||
def pooling_nhwc_min(
|
||||
|
@ -30,6 +30,24 @@ func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor(%input: tensor<2x4x5x2xf32
|
||||
return %0 : tensor<2x3x4x2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @conv_2d_nchw_tensor
|
||||
func @conv_2d_nchw_tensor(%input: tensor<2x2x4x5xf32>, %filter: tensor<4x2x3x3xf32>) -> tensor<2x4x2x3xf32> {
|
||||
%cst = constant 0.000000e+00 : f32
|
||||
%init = linalg.init_tensor [2, 4, 2, 3] : tensor<2x4x2x3xf32>
|
||||
%fill = linalg.fill(%cst, %init) : f32, tensor<2x4x2x3xf32> -> tensor<2x4x2x3xf32>
|
||||
// CHECK: %{{.+}} = linalg.conv_2d_nchw
|
||||
// CHECK-SAME: {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
|
||||
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<2x2x4x5xf32>, tensor<4x2x3x3xf32>)
|
||||
// CHECK-SAME: outs(%{{.+}} : tensor<2x4x2x3xf32>) -> tensor<2x4x2x3xf32>
|
||||
// CHECK: return %{{.+}} : tensor<2x4x2x3xf32>
|
||||
// CHECK: }
|
||||
%0 = linalg.conv_2d_nchw
|
||||
{dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
|
||||
ins(%input, %filter: tensor<2x2x4x5xf32>, tensor<4x2x3x3xf32>)
|
||||
outs(%fill : tensor<2x4x2x3xf32>) -> tensor<2x4x2x3xf32>
|
||||
return %0 : tensor<2x4x2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref
|
||||
func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
|
||||
// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
|
||||
@ -381,6 +399,25 @@ func @pooling_nhwc_max_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32
|
||||
return %res : tensor<1x2x2x1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @pooling_nchw_max_tensor
|
||||
// CHECK: %{{.+}} = linalg.pooling_nchw_max
|
||||
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
|
||||
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
|
||||
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x1x4x4xf32>, tensor<3x3xf32>)
|
||||
// CHECK-SAME: outs(%{{.+}} : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32>
|
||||
|
||||
func @pooling_nchw_max_tensor(%input: tensor<1x1x4x4xf32>) -> tensor<1x1x2x2xf32> {
|
||||
%fake = linalg.init_tensor [3, 3] : tensor<3x3xf32>
|
||||
%init = linalg.init_tensor [1, 1, 2, 2] : tensor<1x1x2x2xf32>
|
||||
%cst = constant 0.000000e+00 : f32
|
||||
%fill = linalg.fill(%cst, %init) : f32, tensor<1x1x2x2xf32> -> tensor<1x1x2x2xf32>
|
||||
%res = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
|
||||
ins(%input, %fake: tensor<1x1x4x4xf32>, tensor<3x3xf32>)
|
||||
outs(%fill: tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32>
|
||||
return %res : tensor<1x1x2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @pooling_nhwc_max
|
||||
|
@ -30,8 +30,10 @@ func @alloc_4d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %f
|
||||
}
|
||||
|
||||
func @conv_2d_nchw(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
|
||||
linalg.conv_2d_nchw ins (%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
|
||||
outs (%arg2: memref<?x?x?x?xf32>)
|
||||
linalg.conv_2d_nchw
|
||||
{dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
|
||||
ins (%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
|
||||
outs (%arg2: memref<?x?x?x?xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user