mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-14 17:06:38 +00:00
[tosa]: canonicalize dynamic size of tosa.slice to static output shape (#135429)
Addresses https://github.com/llvm/llvm-project/issues/135389
This commit is contained in:
parent
b0fede358f
commit
60b1d44d70
@ -731,9 +731,62 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Update size operand of tosa.slice if size has dynamic dims but corresponding
|
||||
// output dim is static
|
||||
struct SliceDynamicSizeCanonicalization
|
||||
: public OpRewritePattern<tosa::SliceOp> {
|
||||
using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
ShapedType resultType = cast<ShapedType>(sliceOp.getType());
|
||||
|
||||
ElementsAttr sizeElems;
|
||||
if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
sliceOp, "size of slice must be a static ranked shape");
|
||||
}
|
||||
|
||||
llvm::SmallVector<int64_t> sliceSizes =
|
||||
llvm::to_vector(sizeElems.getValues<int64_t>());
|
||||
|
||||
bool replaceSliceSize{false};
|
||||
// if size op has -1 indicating dynamic shape but corresponding dim on the
|
||||
// output is statically known, update size to match with known output dim
|
||||
// shape
|
||||
for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
|
||||
if (size == -1 && !resultType.isDynamicDim(index)) {
|
||||
sliceSizes[index] = resultType.getDimSize(index);
|
||||
replaceSliceSize = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!replaceSliceSize) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
sliceOp, "no dimension of size of slice is dynamic that resolves "
|
||||
"to static output shape");
|
||||
}
|
||||
|
||||
auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
|
||||
auto newSliceOp = rewriter.create<tosa::SliceOp>(
|
||||
sliceOp.getLoc(), sliceOp.getType(), sliceOp.getInput1(),
|
||||
sliceOp.getStart(), size_op);
|
||||
|
||||
rewriter.replaceOp(sliceOp, newSliceOp.getResult());
|
||||
|
||||
// Remove const_shape size op when it no longer has use point.
|
||||
Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
|
||||
if (sizeConstShape->getResult(0).hasOneUse())
|
||||
rewriter.eraseOp(sizeConstShape);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<ConcatSliceOptimization>(context);
|
||||
results.add<ConcatSliceOptimization, SliceDynamicSizeCanonicalization>(
|
||||
context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1212,3 +1212,18 @@ func.func @do_not_fold_intdiv_division_by_0() -> tensor<1x24x2xi32> {
|
||||
%16 = tosa.intdiv %4, %1 : (tensor<1x24x2xi32>, tensor<1x24x2xi32>) -> tensor<1x24x2xi32>
|
||||
return %16 : tensor<1x24x2xi32>
|
||||
}
|
||||
|
||||
|
||||
// ----
|
||||
// CHECK-LABEL: func.func @slice_dynamic_size_static_output_canonicalize(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x60x59x?xf32>) -> tensor<2x60x58x?xf32> {
|
||||
// CHECK: %[[START:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
|
||||
// CHECK: %[[SIZE:.*]] = tosa.const_shape {values = dense<[2, 60, 58, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
|
||||
// CHECK: %[[SLICE:.*]] = tosa.slice %[[ARG0]], %[[START]], %[[SIZE]] : (tensor<2x60x59x?xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x60x58x?xf32>
|
||||
// CHECK: return %[[SLICE]]
|
||||
func.func @slice_dynamic_size_static_output_canonicalize(%arg0: tensor<2x60x59x?xf32>) -> tensor<2x60x58x?xf32> {
|
||||
%0 = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
|
||||
%1 = tosa.const_shape {values = dense<[-1, 60, 58, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
|
||||
%2 = tosa.slice %arg0, %0, %1 : (tensor<2x60x59x?xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x60x58x?xf32>
|
||||
return %2 : tensor<2x60x58x?xf32>
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user