mirror of https://github.com/llvm/llvm-project.git

commit e34b88309e (parent 24debf5a76)

[MLIR][Shape] Lower `shape_of` for unranked tensors

Lower `shape_of` for unranked tensors. Materializes the shape in
stack-allocated memory.

Differential Revision: https://reviews.llvm.org/D82196
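
The gist of the change: for an unranked tensor the number of extents is unknown at compile time, so the new pattern computes the rank at runtime, alloca's a rank-sized buffer, and fills it with the extents in an `scf.for` loop. A sketch of the IR the lowering produces (value names are illustrative; the regression test at the end of this commit holds the authoritative CHECK lines):

  %rank = rank %arg : tensor<*xf32>
  %shape_mem = alloca(%rank) : memref<?xi64>
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  scf.for %i = %c0 to %rank step %c1 {
    %dim = dim %arg, %i : tensor<*xf32>
    %dim_int = index_cast %dim : index to i64
    store %dim_int, %shape_mem[%i] : memref<?xi64>
  }
  %shape_int = tensor_load %shape_mem : memref<?xi64>
  %shape = index_cast %shape_int : tensor<?xi64> to tensor<?xindex>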
@ -1408,7 +1408,9 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
 
   let builders = [
     OpBuilder<"OpBuilder &builder, OperationState &result, "
-              "Value memrefOrTensor, int64_t index">
+              "Value memrefOrTensor, int64_t index">,
+    OpBuilder<"OpBuilder &builder, OperationState &result, "
+              "Value memrefOrTensor, Value index">
   ];
 
   let extraClassDeclaration = [{
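
The second `OpBuilder` entry allows the dimension index to be passed as an SSA value instead of a compile-time constant. In the string-based ODS builder syntax, each string is the generated method's parameter list, so these entries expand to roughly the following C++ overloads on `DimOp` (reconstructed, not copied from the generated header):

  static void build(OpBuilder &builder, OperationState &result,
                    Value memrefOrTensor, int64_t index);
  static void build(OpBuilder &builder, OperationState &result,
                    Value memrefOrTensor, Value index);

Their definitions are in the Ops.cpp hunk further down.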
@ -69,6 +69,58 @@ ReduceOpConverter::matchAndRewrite(ReduceOp reduceOp,
   return success();
 }
 
+namespace {
+/// Converts `shape_of` to a for loop for unranked tensors.
+class ShapeOfOpConverter : public OpConversionPattern<ShapeOfOp> {
+public:
+  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+                                    ConversionPatternRewriter &rewriter) const {
+  ShapeOfOp::Adaptor transformed(operands);
+  auto tensorVal = transformed.arg();
+  auto tensorTy = tensorVal.getType();
+
+  // For ranked tensors `shape_of` lowers to `std` and the pattern can be
+  // found in the corresponding pass.
+  if (tensorTy.isa<RankedTensorType>())
+    return failure();
+
+  // Allocate stack memory.
+  auto loc = op.getLoc();
+  auto rankVal = rewriter.create<RankOp>(loc, tensorVal);
+  auto i64Ty = rewriter.getI64Type();
+  auto memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
+  auto memVal = rewriter.create<AllocaOp>(loc, memTy, ValueRange({rankVal}));
+
+  // Copy shape extents to stack-allocated memory.
+  auto zeroVal = rewriter.create<ConstantIndexOp>(loc, 0);
+  auto oneVal = rewriter.create<ConstantIndexOp>(loc, 1);
+  rewriter.create<scf::ForOp>(
+      loc, zeroVal, rankVal, oneVal, ValueRange(),
+      [&](OpBuilder &b, Location loc, Value iVal, ValueRange args) {
+        auto dimVal = b.create<DimOp>(loc, tensorVal, iVal);
+        auto dimIntVal = b.create<IndexCastOp>(loc, dimVal, i64Ty);
+        b.create<StoreOp>(loc, dimIntVal, memVal, ValueRange({iVal}));
+        b.create<scf::YieldOp>(loc);
+      });
+
+  // Load extents to tensor value.
+  auto shapeIntVal = rewriter.create<TensorLoadOp>(loc, memVal);
+  auto indexTy = rewriter.getIndexType();
+  auto shapeTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
+  rewriter.replaceOpWithNewOp<IndexCastOp>(op.getOperation(), shapeIntVal,
+                                           shapeTy);
+  return success();
+}
+
 namespace {
 struct ConvertShapeToSCFPass
     : public ConvertShapeToSCFBase<ConvertShapeToSCFPass> {
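
Two details of the converter worth noting. First, the extents take a round trip through i64: `dim` yields an `index`, which is cast to i64 for storage in the memref, and the loaded `tensor<?xi64>` is cast back to `tensor<?xindex>` as the final result, so the emitted IR contains a matching pair of casts (names illustrative):

  %dim_int = index_cast %dim : index to i64
  ...
  %shape = index_cast %shape_int : tensor<?xi64> to tensor<?xindex>

Second, the `scf.for` is built with an empty `ValueRange()` of iter_args: the loop communicates exclusively through the stack-allocated memref, which is why the body builder only has to emit an operand-less `scf.yield`.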
@ -79,19 +131,23 @@ struct ConvertShapeToSCFPass
 void ConvertShapeToSCFPass::runOnFunction() {
   MLIRContext &ctx = getContext();
 
+  // Populate conversion patterns.
   OwningRewritePatternList patterns;
   populateShapeToSCFConversionPatterns(patterns, &ctx);
 
+  // Setup target legality.
   ConversionTarget target(getContext());
   target.addLegalDialect<ShapeDialect, scf::SCFDialect, StandardOpsDialect>();
-  target.addIllegalOp<ReduceOp>();
-  if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+  target.addIllegalOp<ReduceOp, ShapeOfOp>();
+
+  // Apply conversion.
+  if (failed(applyPartialConversion(getFunction(), target, patterns)))
     signalPassFailure();
 }
 
 void mlir::populateShapeToSCFConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<ReduceOpConverter>(ctx);
+  patterns.insert<ReduceOpConverter, ShapeOfOpConverter>(ctx);
 }
 
 std::unique_ptr<FunctionPass> mlir::createConvertShapeToSCFPass() {
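
With `ShapeOfOp` marked illegal and `ShapeOfOpConverter` registered in `populateShapeToSCFConversionPatterns`, the existing pass picks up the new lowering with no further wiring. A minimal sketch of adding the pass programmatically (assuming a FuncOp-nested pipeline; `ctx` and includes elided):

  mlir::PassManager pm(&ctx);
  pm.addNestedPass<mlir::FuncOp>(mlir::createConvertShapeToSCFPass());

On the command line the pass runs under its registered mlir-opt flag; the flag name is not part of this diff, but it would look like `mlir-opt -convert-shape-to-scf input.mlir`.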
@ -1273,8 +1273,13 @@ void DimOp::build(OpBuilder &builder, OperationState &result,
                   Value memrefOrTensor, int64_t index) {
   auto loc = result.location;
   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
+  build(builder, result, memrefOrTensor, indexValue);
+}
+
+void DimOp::build(OpBuilder &builder, OperationState &result,
+                  Value memrefOrTensor, Value index) {
   auto indexTy = builder.getIndexType();
-  build(builder, result, indexTy, memrefOrTensor, indexValue);
+  build(builder, result, indexTy, memrefOrTensor, index);
 }
 
 Optional<int64_t> DimOp::getConstantIndex() {
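
A quick illustration of what the new overload buys callers (hypothetical snippet; `builder`, `loc`, `tensor`, and `iv` are assumed to be in scope):

  // Static dimension: the int64_t overload materializes the ConstantIndexOp.
  Value d0 = builder.create<DimOp>(loc, tensor, /*index=*/0);
  // Dynamic dimension, e.g. a loop induction variable: the new Value overload.
  Value di = builder.create<DimOp>(loc, tensor, iv);

The second form is exactly what the `shape_of` lowering above uses inside its loop body.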
@ -26,3 +26,25 @@ func @shape_reduce(%shape : !shape.shape) -> !shape.size {
 // CHECK-NEXT: scf.yield [[NEW_ACC]] : !shape.size
 // CHECK-NEXT: }
 // CHECK-NEXT: return [[RESULT]] : !shape.size
+
+// -----
+
+// Lower `shape_of` for unranked tensors.
+// CHECK-LABEL: @shape_of_unranked
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+func @shape_of_unranked(%arg : tensor<*xf32>) {
+  // CHECK-DAG: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
+  // CHECK-DAG: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
+  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+  // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
+  // CHECK-DAG: %[[DIM:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+  // CHECK-DAG: %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
+  // CHECK-DAG: store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
+  // CHECK: }
+  // CHECK-DAG: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
+  // CHECK-DAG: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
+  %shape = shape.shape_of %arg : tensor<*xf32>
+  return
+}
+
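
As with the `shape_reduce` case above, the new test is driven by the file's RUN line (not part of this hunk); for this pass it would be along the lines of:

  // RUN: mlir-opt -convert-shape-to-scf <%s | FileCheck %s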