[MLIR][Shape] Lower shape_of for unranked tensors

Lower `shape_of` for unranked tensors.
Materializes the shape in stack-allocated memory.
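For illustration, `shape.shape_of %arg : tensor<*xf32>` lowers to roughly the
IR below; this is a simplified rendering of the new test case at the end of
this diff, with illustrative SSA names:

    %rank = rank %arg : tensor<*xf32>
    %shape_mem = alloca(%rank) : memref<?xi64>
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    scf.for %i = %c0 to %rank step %c1 {
      %dim = dim %arg, %i : tensor<*xf32>
      %dim_i64 = index_cast %dim : index to i64
      store %dim_i64, %shape_mem[%i] : memref<?xi64>
    }
    %shape_i64 = tensor_load %shape_mem : memref<?xi64>
    %shape = index_cast %shape_i64 : tensor<?xi64> to tensor<?xindex>

The extents are staged through i64 memory and cast back to index only at the
tensor level, presumably because memrefs with index element type are not
available at this point.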

Differential Revision: https://reviews.llvm.org/D82196
Frederik Gossen 2020-06-25 08:50:02 +00:00
parent 24debf5a76
commit e34b88309e
4 changed files with 90 additions and 5 deletions

@@ -1408,7 +1408,9 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
   let builders = [
     OpBuilder<"OpBuilder &builder, OperationState &result, "
-              "Value memrefOrTensor, int64_t index">
+              "Value memrefOrTensor, int64_t index">,
+    OpBuilder<"OpBuilder &builder, OperationState &result, "
+              "Value memrefOrTensor, Value index">
   ];
 
   let extraClassDeclaration = [{
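The second builder makes it possible to create `dim` with a dynamic,
SSA-valued index; the `ShapeOfOp` lowering below uses it with the loop
induction variable. A minimal C++ sketch, assuming `builder`, `loc`,
`tensor`, and `iv` (an induction variable of index type) are in scope:

    // Existing overload: constant dimension, materialized internally
    // as a ConstantIndexOp.
    Value d0 = builder.create<DimOp>(loc, tensor, /*index=*/0);
    // New overload: dimension given as an SSA value, e.g. a loop
    // induction variable.
    Value di = builder.create<DimOp>(loc, tensor, iv);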

@@ -69,6 +69,58 @@ ReduceOpConverter::matchAndRewrite(ReduceOp reduceOp,
   return success();
 }
 
+namespace {
+/// Converts `shape_of` to an `scf.for` loop for unranked tensors.
+class ShapeOfOpConverter : public OpConversionPattern<ShapeOfOp> {
+public:
+  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+                                    ConversionPatternRewriter &rewriter) const {
+  ShapeOfOp::Adaptor transformed(operands);
+  auto tensorVal = transformed.arg();
+  auto tensorTy = tensorVal.getType();
+
+  // For ranked tensors, `shape_of` lowers to `std`; that pattern can be
+  // found in the corresponding pass.
+  if (tensorTy.isa<RankedTensorType>())
+    return failure();
+
+  // Allocate stack memory.
+  auto loc = op.getLoc();
+  auto rankVal = rewriter.create<RankOp>(loc, tensorVal);
+  auto i64Ty = rewriter.getI64Type();
+  auto memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
+  auto memVal = rewriter.create<AllocaOp>(loc, memTy, ValueRange({rankVal}));
+
+  // Copy shape extents to stack-allocated memory.
+  auto zeroVal = rewriter.create<ConstantIndexOp>(loc, 0);
+  auto oneVal = rewriter.create<ConstantIndexOp>(loc, 1);
+  rewriter.create<scf::ForOp>(
+      loc, zeroVal, rankVal, oneVal, ValueRange(),
+      [&](OpBuilder &b, Location loc, Value iVal, ValueRange args) {
+        auto dimVal = b.create<DimOp>(loc, tensorVal, iVal);
+        auto dimIntVal = b.create<IndexCastOp>(loc, dimVal, i64Ty);
+        b.create<StoreOp>(loc, dimIntVal, memVal, ValueRange({iVal}));
+        b.create<scf::YieldOp>(loc);
+      });
+
+  // Load the extents into a tensor value.
+  auto shapeIntVal = rewriter.create<TensorLoadOp>(loc, memVal);
+  auto indexTy = rewriter.getIndexType();
+  auto shapeTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
+  rewriter.replaceOpWithNewOp<IndexCastOp>(op.getOperation(), shapeIntVal,
+                                           shapeTy);
+
+  return success();
+}
+
 namespace {
 struct ConvertShapeToSCFPass
     : public ConvertShapeToSCFBase<ConvertShapeToSCFPass> {
@@ -79,19 +131,23 @@ struct ConvertShapeToSCFPass
 void ConvertShapeToSCFPass::runOnFunction() {
   MLIRContext &ctx = getContext();
 
+  // Populate conversion patterns.
   OwningRewritePatternList patterns;
   populateShapeToSCFConversionPatterns(patterns, &ctx);
 
+  // Setup target legality.
   ConversionTarget target(getContext());
   target.addLegalDialect<ShapeDialect, scf::SCFDialect, StandardOpsDialect>();
-  target.addIllegalOp<ReduceOp>();
-  if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+  target.addIllegalOp<ReduceOp, ShapeOfOp>();
+
+  // Apply conversion.
+  if (failed(applyPartialConversion(getFunction(), target, patterns)))
     signalPassFailure();
 }
 
 void mlir::populateShapeToSCFConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<ReduceOpConverter>(ctx);
+  patterns.insert<ReduceOpConverter, ShapeOfOpConverter>(ctx);
 }
 
 std::unique_ptr<FunctionPass> mlir::createConvertShapeToSCFPass() {
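Since the patterns are exposed via `populateShapeToSCFConversionPatterns`,
other conversion passes can pick them up without running this pass. A minimal
sketch mirroring `runOnFunction` above, assuming a `FuncOp func` and
`MLIRContext *ctx` are available inside a pass:

    OwningRewritePatternList patterns;
    populateShapeToSCFConversionPatterns(patterns, ctx);
    ConversionTarget target(*ctx);
    target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
    target.addIllegalOp<ReduceOp, ShapeOfOp>();
    if (failed(applyPartialConversion(func, target, patterns)))
      signalPassFailure();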

@@ -1273,8 +1273,13 @@ void DimOp::build(OpBuilder &builder, OperationState &result,
                   Value memrefOrTensor, int64_t index) {
   auto loc = result.location;
   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
+  build(builder, result, memrefOrTensor, indexValue);
+}
+
+void DimOp::build(OpBuilder &builder, OperationState &result,
+                  Value memrefOrTensor, Value index) {
   auto indexTy = builder.getIndexType();
-  build(builder, result, indexTy, memrefOrTensor, indexValue);
+  build(builder, result, indexTy, memrefOrTensor, index);
 }
 
 Optional<int64_t> DimOp::getConstantIndex() {

@@ -26,3 +26,25 @@ func @shape_reduce(%shape : !shape.shape) -> !shape.size {
 // CHECK-NEXT:     scf.yield [[NEW_ACC]] : !shape.size
 // CHECK-NEXT:   }
 // CHECK-NEXT:   return [[RESULT]] : !shape.size
+
+// -----
+
+// Lower `shape_of` for unranked tensors.
+
+// CHECK-LABEL: @shape_of_unranked
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+func @shape_of_unranked(%arg : tensor<*xf32>) {
+  // CHECK-DAG: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
+  // CHECK-DAG: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
+  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+  // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
+  // CHECK-DAG: %[[DIM:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+  // CHECK-DAG: %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
+  // CHECK-DAG: store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
+  // CHECK: }
+  // CHECK-DAG: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
+  // CHECK-DAG: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
+  %shape = shape.shape_of %arg : tensor<*xf32>
+  return
+}