diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index c4ef7d0bb9ff..84f89bfd7f2d 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -731,9 +731,62 @@ struct ConcatSliceOptimization : public OpRewritePattern { } }; +// Update size operand of tosa.slice if size has dynamic dims but corresponding +// output dim is static +struct SliceDynamicSizeCanonicalization + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, + PatternRewriter &rewriter) const override { + ShapedType resultType = cast(sliceOp.getType()); + + ElementsAttr sizeElems; + if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) { + return rewriter.notifyMatchFailure( + sliceOp, "size of slice must be a static ranked shape"); + } + + llvm::SmallVector sliceSizes = + llvm::to_vector(sizeElems.getValues()); + + bool replaceSliceSize{false}; + // if size op has -1 indicating dynamic shape but corresponding dim on the + // output is statically known, update size to match with known output dim + // shape + for (const auto &[index, size] : llvm::enumerate(sliceSizes)) { + if (size == -1 && !resultType.isDynamicDim(index)) { + sliceSizes[index] = resultType.getDimSize(index); + replaceSliceSize = true; + } + } + + if (!replaceSliceSize) { + return rewriter.notifyMatchFailure( + sliceOp, "no dimension of size of slice is dynamic that resolves " + "to static output shape"); + } + + auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes); + auto newSliceOp = rewriter.create( + sliceOp.getLoc(), sliceOp.getType(), sliceOp.getInput1(), + sliceOp.getStart(), size_op); + + rewriter.replaceOp(sliceOp, newSliceOp.getResult()); + + // Remove const_shape size op when it no longer has use point. + Operation *sizeConstShape = sliceOp.getSize().getDefiningOp(); + if (sizeConstShape->getResult(0).hasOneUse()) + rewriter.eraseOp(sizeConstShape); + + return success(); + } +}; + void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index b366b4f1e4fd..a754a46be603 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -1212,3 +1212,18 @@ func.func @do_not_fold_intdiv_division_by_0() -> tensor<1x24x2xi32> { %16 = tosa.intdiv %4, %1 : (tensor<1x24x2xi32>, tensor<1x24x2xi32>) -> tensor<1x24x2xi32> return %16 : tensor<1x24x2xi32> } + + +// ---- +// CHECK-LABEL: func.func @slice_dynamic_size_static_output_canonicalize( +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x60x59x?xf32>) -> tensor<2x60x58x?xf32> { +// CHECK: %[[START:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[SIZE:.*]] = tosa.const_shape {values = dense<[2, 60, 58, -1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[SLICE:.*]] = tosa.slice %[[ARG0]], %[[START]], %[[SIZE]] : (tensor<2x60x59x?xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x60x58x?xf32> +// CHECK: return %[[SLICE]] +func.func @slice_dynamic_size_static_output_canonicalize(%arg0: tensor<2x60x59x?xf32>) -> tensor<2x60x58x?xf32> { + %0 = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> + %1 = tosa.const_shape {values = dense<[-1, 60, 58, -1]> : tensor<4xindex>} : () -> !tosa.shape<4> + %2 = tosa.slice %arg0, %0, %1 : (tensor<2x60x59x?xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x60x58x?xf32> + return %2 : tensor<2x60x58x?xf32> + }