2022-01-24 23:16:29 +09:00
|
|
|
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
|
2021-11-24 18:20:00 +09:00
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-01-24 23:16:29 +09:00
|
|
|
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
|
2022-01-20 18:14:59 +09:00
|
|
|
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
2021-11-24 18:20:00 +09:00
|
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
|
|
#include "mlir/IR/Dialect.h"
|
|
|
|
#include "mlir/IR/Operation.h"
|
|
|
|
|
2021-11-26 22:13:28 +09:00
|
|
|
using namespace mlir;
|
2022-01-20 18:14:59 +09:00
|
|
|
using namespace mlir::bufferization;
|
2022-01-24 23:16:29 +09:00
|
|
|
using namespace mlir::tensor;
|
2021-11-26 22:13:28 +09:00
|
|
|
|
2021-11-24 18:20:00 +09:00
|
|
|
namespace mlir {
|
2022-01-24 23:16:29 +09:00
|
|
|
namespace tensor {
|
|
|
|
namespace {
|
2021-11-24 18:20:00 +09:00
|
|
|
|
|
|
|
struct CastOpInterface
|
|
|
|
: public BufferizableOpInterface::ExternalModel<CastOpInterface,
|
|
|
|
tensor::CastOp> {
|
2021-12-16 11:42:41 +09:00
|
|
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
2022-01-07 00:16:16 +09:00
|
|
|
const BufferizationState &state) const {
|
2021-11-24 18:20:00 +09:00
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2021-12-16 11:42:41 +09:00
|
|
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
2022-01-07 00:16:16 +09:00
|
|
|
const BufferizationState &state) const {
|
2021-11-24 18:20:00 +09:00
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2021-12-16 11:42:41 +09:00
|
|
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
2022-01-07 00:16:16 +09:00
|
|
|
const BufferizationState &state) const {
|
2021-11-24 18:20:00 +09:00
|
|
|
return op->getResult(0);
|
|
|
|
}
|
|
|
|
|
2021-12-04 11:49:07 +09:00
|
|
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
2022-01-07 00:16:16 +09:00
|
|
|
const BufferizationState &state) const {
|
2021-11-24 18:20:00 +09:00
|
|
|
return BufferRelation::Equivalent;
|
|
|
|
}
|
|
|
|
|
2022-01-05 20:36:05 +09:00
|
|
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
2022-01-07 00:16:16 +09:00
|
|
|
const BufferizationState &state) const {
|
2021-11-24 18:20:00 +09:00
|
|
|
auto castOp = cast<tensor::CastOp>(op);
|
|
|
|
|
2022-01-07 06:21:46 +09:00
|
|
|
// The result buffer still has the old (pre-cast) type.
|
|
|
|
FailureOr<Value> resultBuffer =
|
2022-01-08 01:10:15 +09:00
|
|
|
state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
|
2022-01-07 06:32:35 +09:00
|
|
|
if (failed(resultBuffer))
|
|
|
|
return failure();
|
2022-01-07 06:21:46 +09:00
|
|
|
auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
|
|
|
|
Attribute memorySpace = sourceMemRefType.getMemorySpace();
|
|
|
|
TensorType resultTensorType =
|
|
|
|
castOp.getResult().getType().cast<TensorType>();
|
|
|
|
MemRefLayoutAttrInterface layout;
|
|
|
|
|
|
|
|
if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
|
|
|
|
if (resultTensorType.isa<RankedTensorType>())
|
|
|
|
layout = rankedMemRefType.getLayout();
|
|
|
|
|
|
|
|
// Compute the new memref type.
|
|
|
|
Type resultMemRefType;
|
2022-01-09 16:09:24 -05:00
|
|
|
if (resultTensorType.isa<RankedTensorType>()) {
|
2022-01-07 06:21:46 +09:00
|
|
|
resultMemRefType =
|
|
|
|
getContiguousMemRefType(resultTensorType, layout, memorySpace);
|
|
|
|
} else {
|
|
|
|
resultMemRefType =
|
|
|
|
getUnrankedMemRefType(resultTensorType.getElementType(), memorySpace);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Replace the op with a memref.cast.
|
2022-01-09 11:51:58 -05:00
|
|
|
assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
|
|
|
|
resultMemRefType) &&
|
|
|
|
"CallOp::bufferize: cast incompatible");
|
2022-01-07 06:21:46 +09:00
|
|
|
replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
|
|
|
|
*resultBuffer);
|
|
|
|
|
2021-11-24 18:20:00 +09:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2022-01-07 06:21:46 +09:00
|
|
|
/// Bufferization of tensor.dim. Replace with memref.dim.
|
2021-11-24 18:20:00 +09:00
|
|
|
struct DimOpInterface
|
|
|
|
: public BufferizableOpInterface::ExternalModel<DimOpInterface,
|
|
|
|
tensor::DimOp> {
|
2021-12-16 11:42:41 +09:00
|
|
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
2022-01-07 00:16:16 +09:00
|
|
|
const BufferizationState &state) const {
|
2021-11-24 18:20:00 +09:00
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2021-12-16 11:42:41 +09:00
|
|
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
2022-01-07 00:16:16 +09:00
|
|
|
const BufferizationState &state) const {
|
2021-11-24 18:20:00 +09:00
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2021-12-16 11:42:41 +09:00
|
|
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
2022-01-07 00:16:16 +09:00
|
|
|
const BufferizationState &state) const {
|
2021-11-24 18:20:00 +09:00
|
|
|
return OpResult();
|
|
|
|
}
|
|
|
|
|
2022-01-05 20:36:05 +09:00
|
|
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
2022-01-07 00:16:16 +09:00
|
|
|
const BufferizationState &state) const {
|
2021-11-24 18:20:00 +09:00
|
|
|
auto dimOp = cast<tensor::DimOp>(op);
|
2022-01-08 01:10:15 +09:00
|
|
|
Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
|
2022-01-07 05:23:56 +09:00
|
|
|
replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
|
2021-11-24 18:20:00 +09:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2022-01-07 06:21:46 +09:00
|
|
|
/// Bufferization of tensor.extract_slice. Replace with memref.subview.
|
2021-11-24 18:20:00 +09:00
|
|
|
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  // The subview itself reads nothing; readers of the result do.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  // Taking a slice never writes to memory.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  // Only the source operand aliases the result (a view into it).
  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(0) /*source*/
               ? op->getResult(0)
               : OpResult();
  }

  // The result is a (possibly strided, rank-reduced) sub-buffer, not the
  // whole source buffer, so the relation is not "Equivalent".
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();
    // forceInPlace: the subview is taken directly on the source buffer; an
    // out-of-place copy (if needed) is made below, after the subview.
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         /*forceInPlace=*/true);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace = state.isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          createAlloc(rewriter, loc, extractSliceOp.result(),
                      state.getOptions().createDeallocs, state.getOptions());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        // For each missing dimension, produce its size: a constant attr for
        // static dims, a memref.dim op for dynamic ones.
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      // The replacement value is the private allocation, not the view.
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};
|
|
|
|
|
2022-01-07 06:21:46 +09:00
|
|
|
/// Bufferization of tensor.extract. Replace with memref.load.
|
2021-11-24 18:20:00 +09:00
|
|
|
struct ExtractOpInterface
|
|
|
|
: public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
|
|
|
|
tensor::ExtractOp> {
|
2021-12-16 11:42:41 +09:00
|
|
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
2022-01-07 00:16:16 +09:00
|
|
|
const BufferizationState &state) const {
|
2021-11-24 18:20:00 +09:00
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2021-12-16 11:42:41 +09:00
|
|
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
2022-01-07 00:16:16 +09:00
|
|
|
const BufferizationState &state) const {
|
2021-11-24 18:20:00 +09:00
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2021-12-16 11:42:41 +09:00
|
|
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
2022-01-07 00:16:16 +09:00
|
|
|
const BufferizationState &state) const {
|
2021-11-24 18:20:00 +09:00
|
|
|
return OpResult();
|
|
|
|
}
|
|
|
|
|
2022-01-05 20:36:05 +09:00
|
|
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
2022-01-07 00:16:16 +09:00
|
|
|
const BufferizationState &state) const {
|
2021-11-24 18:20:00 +09:00
|
|
|
auto extractOp = cast<tensor::ExtractOp>(op);
|
2022-01-08 01:10:15 +09:00
|
|
|
Value srcMemref =
|
|
|
|
*state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
|
2022-01-07 05:23:56 +09:00
|
|
|
replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
|
|
|
|
extractOp.indices());
|
2021-11-24 18:20:00 +09:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2022-01-07 06:21:46 +09:00
|
|
|
/// Bufferization of tensor.insert. Replace with memref.store.
|
2021-12-02 11:57:26 +09:00
|
|
|
struct InsertOpInterface
|
|
|
|
: public BufferizableOpInterface::ExternalModel<InsertOpInterface,
|
|
|
|
tensor::InsertOp> {
|
2021-12-16 11:42:41 +09:00
|
|
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
2022-01-07 00:16:16 +09:00
|
|
|
const BufferizationState &state) const {
|
2021-12-02 11:57:26 +09:00
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2021-12-16 11:42:41 +09:00
|
|
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
2022-01-07 00:16:16 +09:00
|
|
|
const BufferizationState &state) const {
|
2021-12-02 11:57:26 +09:00
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2021-12-16 11:42:41 +09:00
|
|
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
2022-01-07 00:16:16 +09:00
|
|
|
const BufferizationState &state) const {
|
2021-12-02 11:57:26 +09:00
|
|
|
assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
|
|
|
|
"expected dest OpOperand");
|
|
|
|
return op->getOpResult(0);
|
|
|
|
}
|
|
|
|
|
2021-12-16 11:42:41 +09:00
|
|
|
SmallVector<OpOperand *>
|
|
|
|
getAliasingOpOperand(Operation *op, OpResult opResult,
|
2022-01-07 00:16:16 +09:00
|
|
|
const BufferizationState &state) const {
|
2021-12-02 11:57:26 +09:00
|
|
|
return {&op->getOpOperand(1) /*dest*/};
|
|
|
|
}
|
|
|
|
|
2022-01-05 20:36:05 +09:00
|
|
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
2022-01-07 00:16:16 +09:00
|
|
|
const BufferizationState &state) const {
|
2021-12-02 11:57:26 +09:00
|
|
|
auto insertOp = cast<tensor::InsertOp>(op);
|
2022-01-07 06:21:46 +09:00
|
|
|
FailureOr<Value> destMemref =
|
2022-01-08 01:10:15 +09:00
|
|
|
state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
|
2022-01-07 06:32:35 +09:00
|
|
|
if (failed(destMemref))
|
|
|
|
return failure();
|
2022-01-07 06:21:46 +09:00
|
|
|
rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
|
|
|
|
*destMemref, insertOp.indices());
|
|
|
|
replaceOpWithBufferizedValues(rewriter, op, *destMemref);
|
2021-12-02 11:57:26 +09:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2021-12-04 11:49:07 +09:00
|
|
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
2022-01-07 00:16:16 +09:00
|
|
|
const BufferizationState &state) const {
|
2021-12-02 11:57:26 +09:00
|
|
|
return BufferRelation::Equivalent;
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2021-11-24 18:20:00 +09:00
|
|
|
/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
|
|
|
|
/// equivalent operand / result and same offset/sizes/strides specification).
|
|
|
|
///
|
|
|
|
/// This is one particular type of relationship between ops on tensors that
|
|
|
|
/// reduce to an equivalence on buffers. This should be generalized and
|
|
|
|
/// exposed as interfaces on the proper types.
|
2022-01-19 18:58:36 +09:00
|
|
|
static bool areEquivalentExtractSliceOps(const BufferizationState &state,
|
|
|
|
ExtractSliceOp st, InsertSliceOp sti) {
|
2021-11-24 18:20:00 +09:00
|
|
|
if (!st || !sti)
|
|
|
|
return false;
|
2022-01-19 18:58:36 +09:00
|
|
|
if (sti != sti &&
|
|
|
|
!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
|
2021-11-24 18:20:00 +09:00
|
|
|
return false;
|
|
|
|
if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
|
|
|
|
return false;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Return true if `value` is originating from an ExtractSliceOp that matches
|
|
|
|
/// the given InsertSliceOp.
|
2022-01-19 18:58:36 +09:00
|
|
|
static bool hasMatchingExtractSliceOp(const BufferizationState &state,
|
2022-01-07 00:16:16 +09:00
|
|
|
Value value, InsertSliceOp insertOp) {
|
2021-11-24 18:20:00 +09:00
|
|
|
auto condition = [&](Value val) {
|
|
|
|
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
|
2022-01-19 18:58:36 +09:00
|
|
|
if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
|
2021-11-24 18:20:00 +09:00
|
|
|
return true;
|
|
|
|
return false;
|
|
|
|
};
|
|
|
|
|
2021-12-16 11:42:41 +09:00
|
|
|
return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
|
2021-11-24 18:20:00 +09:00
|
|
|
condition);
|
|
|
|
}
|
|
|
|
|
2022-01-07 06:21:46 +09:00
|
|
|
/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
|
|
|
|
/// certain circumstances, this op can also be a no-op.
|
2021-11-24 18:20:00 +09:00
|
|
|
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  // Both the source slice and the (partially surviving) dest are read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  // Only the dest buffer is written into.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  // The result aliases the dest operand only.
  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/
               ? op->getResult(0)
               : OpResult();
  }

  // Result and dest share the same buffer.
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  /// Return true if the given read/write pair is known NOT to be a real
  /// conflict, thanks to the matching-slice special cases below. A `false`
  /// return means "no exception applies", not "definitely conflicting".
  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const BufferizationState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    // No exception applies; treat as potentially conflicting.
    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getResultBuffer` allocates.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        // For each missing dimension, produce its size: a constant attr for
        // static dims, a memref.dim op for dynamic ones.
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};
|
|
|
|
|
2022-01-25 00:09:36 +09:00
|
|
|
/// Bufferization of tensor.rank. Replace with memref.rank.
|
|
|
|
struct RankOpInterface
|
|
|
|
: public BufferizableOpInterface::ExternalModel<RankOpInterface,
|
|
|
|
tensor::RankOp> {
|
|
|
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
|
|
|
const BufferizationState &state) const {
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
|
|
|
const BufferizationState &state) const {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
|
|
|
const BufferizationState &state) const {
|
|
|
|
return OpResult();
|
|
|
|
}
|
|
|
|
|
|
|
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
|
|
|
const BufferizationState &state) const {
|
|
|
|
auto rankOp = cast<tensor::RankOp>(op);
|
|
|
|
Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
|
|
|
|
replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
|
|
|
|
v);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2022-01-24 23:16:29 +09:00
|
|
|
} // namespace
|
|
|
|
} // namespace tensor
|
2021-11-24 18:20:00 +09:00
|
|
|
} // namespace mlir
|
|
|
|
|
2022-01-24 23:16:29 +09:00
|
|
|
void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  // Attach the BufferizableOpInterface external models defined in this file
  // to their respective tensor dialect ops.
  registry.addOpInterface<CastOp, CastOpInterface>();
  registry.addOpInterface<DimOp, DimOpInterface>();
  registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
  registry.addOpInterface<ExtractOp, ExtractOpInterface>();
  registry.addOpInterface<InsertOp, InsertOpInterface>();
  registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
  registry.addOpInterface<RankOp, RankOpInterface>();
}
|