2020-10-12 14:03:09 -07:00
|
|
|
//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-11-26 13:26:08 +01:00
|
|
|
#include "PassDetail.h"
|
2021-11-29 13:45:01 +01:00
|
|
|
|
2022-01-20 18:14:59 +09:00
|
|
|
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
2021-11-25 11:42:16 +01:00
|
|
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
2021-11-29 13:45:01 +01:00
|
|
|
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
|
|
|
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
|
2020-10-12 14:03:09 -07:00
|
|
|
#include "mlir/IR/Operation.h"
|
2022-01-20 18:14:59 +09:00
|
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
2020-10-12 14:03:09 -07:00
|
|
|
|
|
|
|
using namespace mlir;
|
2021-11-29 13:45:01 +01:00
|
|
|
using namespace mlir::bufferization;
|
2020-10-12 14:03:09 -07:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
2020-10-12 14:32:38 -07:00
|
|
|
// BufferizeTypeConverter
|
2020-10-12 14:03:09 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2021-11-29 13:45:01 +01:00
|
|
|
/// Materialize a tensor value of type `type` from a single memref value by
/// wrapping it in a bufferization.to_tensor op at `loc`.
/// Exactly one input is expected, and it must be a (ranked or unranked) memref.
static Value materializeToTensor(OpBuilder &builder, TensorType type,
                                 ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  Value source = inputs[0];
  assert(source.getType().isa<BaseMemRefType>());
  auto toTensor = builder.create<bufferization::ToTensorOp>(loc, type, source);
  return toTensor;
}
|
|
|
|
|
2020-10-12 14:32:38 -07:00
|
|
|
/// Registers conversions into BufferizeTypeConverter
|
|
|
|
BufferizeTypeConverter::BufferizeTypeConverter() {
|
2020-10-12 14:03:09 -07:00
|
|
|
// Keep all types unchanged.
|
|
|
|
addConversion([](Type type) { return type; });
|
|
|
|
// Convert RankedTensorType to MemRefType.
|
2020-10-12 14:47:31 -07:00
|
|
|
addConversion([](RankedTensorType type) -> Type {
|
|
|
|
return MemRefType::get(type.getShape(), type.getElementType());
|
2020-10-12 14:03:09 -07:00
|
|
|
});
|
|
|
|
// Convert UnrankedTensorType to UnrankedMemRefType.
|
2020-10-12 14:47:31 -07:00
|
|
|
addConversion([](UnrankedTensorType type) -> Type {
|
|
|
|
return UnrankedMemRefType::get(type.getElementType(), 0);
|
2020-10-12 14:03:09 -07:00
|
|
|
});
|
2021-11-29 13:45:01 +01:00
|
|
|
addArgumentMaterialization(materializeToTensor);
|
|
|
|
addSourceMaterialization(materializeToTensor);
|
2020-11-02 15:12:55 -08:00
|
|
|
addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
|
2020-10-14 11:26:22 -07:00
|
|
|
ValueRange inputs, Location loc) -> Value {
|
|
|
|
assert(inputs.size() == 1);
|
|
|
|
assert(inputs[0].getType().isa<TensorType>());
|
2021-11-25 11:42:16 +01:00
|
|
|
return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
|
2020-10-14 11:26:22 -07:00
|
|
|
});
|
2020-10-12 14:03:09 -07:00
|
|
|
}
|
|
|
|
|
2021-11-29 13:45:01 +01:00
|
|
|
/// Mark the two bufferization materialization ops (to_tensor and to_memref)
/// as legal in `target`, so a partial conversion may leave them in the IR.
void mlir::bufferization::populateBufferizeMaterializationLegality(
    ConversionTarget &target) {
  target.addLegalOp<bufferization::ToTensorOp>();
  target.addLegalOp<bufferization::ToMemrefOp>();
}
|
2020-10-15 20:17:25 -07:00
|
|
|
|
2020-10-26 12:52:28 -07:00
|
|
|
namespace {
|
|
|
|
// In a finalizing bufferize conversion, we know that all tensors have been
|
|
|
|
// converted to memrefs, thus, this op becomes an identity.
|
2021-11-29 13:45:01 +01:00
|
|
|
class BufferizeToTensorOp
|
2021-11-25 11:42:16 +01:00
|
|
|
: public OpConversionPattern<bufferization::ToTensorOp> {
|
2020-10-26 12:52:28 -07:00
|
|
|
public:
|
|
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
2021-11-25 11:42:16 +01:00
|
|
|
matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
|
2020-10-26 12:52:28 -07:00
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
rewriter.replaceOp(op, adaptor.memref());
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
// In a finalizing bufferize conversion, we know that all tensors have been
|
|
|
|
// converted to memrefs, thus, this op becomes an identity.
|
2021-11-29 13:45:01 +01:00
|
|
|
class BufferizeToMemrefOp
|
|
|
|
: public OpConversionPattern<bufferization::ToMemrefOp> {
|
2020-10-26 12:52:28 -07:00
|
|
|
public:
|
|
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
2021-11-25 11:42:16 +01:00
|
|
|
matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
|
2020-10-26 12:52:28 -07:00
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
rewriter.replaceOp(op, adaptor.tensor());
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2021-11-29 13:45:01 +01:00
|
|
|
void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
|
2021-03-22 16:58:34 -07:00
|
|
|
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
2021-11-29 13:45:01 +01:00
|
|
|
patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
|
|
|
|
patterns.getContext());
|
2020-10-26 12:52:28 -07:00
|
|
|
}
|
2020-11-26 13:26:08 +01:00
|
|
|
|
|
|
|
namespace {
|
|
|
|
struct FinalizingBufferizePass
|
|
|
|
: public FinalizingBufferizeBase<FinalizingBufferizePass> {
|
|
|
|
using FinalizingBufferizeBase<
|
|
|
|
FinalizingBufferizePass>::FinalizingBufferizeBase;
|
|
|
|
|
2022-01-04 15:41:17 -08:00
|
|
|
void runOnOperation() override {
|
|
|
|
auto func = getOperation();
|
2020-11-26 13:26:08 +01:00
|
|
|
auto *context = &getContext();
|
|
|
|
|
|
|
|
BufferizeTypeConverter typeConverter;
|
2021-03-22 16:58:34 -07:00
|
|
|
RewritePatternSet patterns(context);
|
2020-11-26 13:26:08 +01:00
|
|
|
ConversionTarget target(*context);
|
|
|
|
|
2021-03-20 16:29:41 -07:00
|
|
|
populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);
|
2020-11-26 13:26:08 +01:00
|
|
|
|
|
|
|
// If all result types are legal, and all block arguments are legal (ensured
|
|
|
|
// by func conversion above), then all types in the program are legal.
|
2020-11-30 15:20:30 -08:00
|
|
|
//
|
|
|
|
// We also check that the operand types are legal to avoid creating invalid
|
|
|
|
// IR. For example, this prevents
|
|
|
|
// populateEliminateBufferizeMaterializationsPatterns from updating the
|
|
|
|
// types of the operands to a return op without updating the enclosing
|
|
|
|
// function.
|
|
|
|
target.markUnknownOpDynamicallyLegal(
|
|
|
|
[&](Operation *op) { return typeConverter.isLegal(op); });
|
2020-11-26 13:26:08 +01:00
|
|
|
|
|
|
|
if (failed(applyFullConversion(func, target, std::move(patterns))))
|
|
|
|
signalPassFailure();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-01-04 15:41:17 -08:00
|
|
|
/// Create an instance of the finalizing-bufferize pass, which folds away the
/// remaining to_tensor / to_memref materializations in a function.
std::unique_ptr<OperationPass<FuncOp>>
mlir::bufferization::createFinalizingBufferizePass() {
  auto pass = std::make_unique<FinalizingBufferizePass>();
  return pass;
}
|
2022-01-20 18:14:59 +09:00
|
|
|
|
2022-01-24 23:16:29 +09:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// BufferizableOpInterface-based Bufferization
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-01-20 18:14:59 +09:00
|
|
|
static bool isaTensor(Type t) { return t.isa<TensorType>(); }
|
|
|
|
|
|
|
|
/// Return true if the given op has a tensor result or a tensor operand.
|
|
|
|
static bool hasTensorSemantics(Operation *op) {
|
|
|
|
bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
|
|
|
|
bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
|
|
|
|
return hasTensorResult || hasTensorOperand;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Rewrite pattern that bufferizes bufferizable ops.
|
|
|
|
struct BufferizationPattern
|
|
|
|
: public OpInterfaceRewritePattern<BufferizableOpInterface> {
|
|
|
|
BufferizationPattern(MLIRContext *context, const BufferizationState &state,
|
|
|
|
PatternBenefit benefit = 1)
|
|
|
|
: OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit),
|
|
|
|
state(state) {}
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
// No tensors => no buffers.
|
|
|
|
if (!hasTensorSemantics(bufferizableOp.getOperation()))
|
|
|
|
return failure();
|
|
|
|
if (!state.getOptions().isOpAllowed(bufferizableOp.getOperation()))
|
|
|
|
return failure();
|
|
|
|
return bufferizableOp.bufferize(rewriter, state);
|
|
|
|
}
|
|
|
|
|
|
|
|
private:
|
|
|
|
const BufferizationState &state;
|
|
|
|
};
|
|
|
|
|
|
|
|
/// Check the result of bufferization. Return an error if an op was not
|
|
|
|
/// bufferized, unless partial bufferization is allowed.
|
|
|
|
static LogicalResult
|
|
|
|
checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
|
|
|
|
if (!options.allowUnknownOps) {
|
|
|
|
// Check if all ops were bufferized.
|
|
|
|
LogicalResult status = success();
|
|
|
|
op->walk([&](Operation *op) {
|
|
|
|
if (!hasTensorSemantics(op))
|
|
|
|
return WalkResult::advance();
|
|
|
|
|
|
|
|
// Bufferization dialect ops will canonicalize away if all other ops are
|
|
|
|
// bufferized.
|
|
|
|
if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op))
|
|
|
|
return WalkResult::advance();
|
|
|
|
|
|
|
|
// Ops that are not in the allow list can be ignored.
|
|
|
|
if (!options.isOpAllowed(op))
|
|
|
|
return WalkResult::advance();
|
|
|
|
|
|
|
|
// Ops without any uses and no side effects will fold away.
|
|
|
|
if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
|
|
|
|
return WalkResult::advance();
|
|
|
|
|
|
|
|
status = op->emitError("op was not bufferized");
|
|
|
|
return WalkResult::interrupt();
|
|
|
|
});
|
|
|
|
|
|
|
|
if (failed(status))
|
|
|
|
return status;
|
|
|
|
}
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Bufferize `op` and all ops nested in it using `state`.
///
/// The bufferization pattern is applied greedily. The result is then checked
/// for ops that were left unbufferized.
LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationState &state) {
  RewritePatternSet bufferizationPatterns(op->getContext());
  populateBufferizationPattern(state, bufferizationPatterns);

  LogicalResult rewriteStatus =
      applyPatternsAndFoldGreedily(op, std::move(bufferizationPatterns));
  if (failed(rewriteStatus))
    return failure();

  // Report any op that should have been bufferized but was not.
  return checkBufferizationResult(op, state.getOptions());
}
|
2022-01-27 19:18:59 +09:00
|
|
|
|
|
|
|
namespace {
|
|
|
|
/// This a "no analysis, always copy" BufferizationState. In the absence of an
|
|
|
|
/// analysis, a buffer must be copied each time it is written to. Therefore, all
|
|
|
|
/// OpOperands that bufferize to a memory write must bufferize out-of-place.
|
|
|
|
class AlwaysCopyBufferizationState : public BufferizationState {
|
|
|
|
public:
|
|
|
|
AlwaysCopyBufferizationState(const BufferizationOptions &options)
|
|
|
|
: BufferizationState(options) {}
|
|
|
|
|
|
|
|
AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete;
|
|
|
|
|
|
|
|
virtual ~AlwaysCopyBufferizationState() = default;
|
|
|
|
|
|
|
|
/// Return `true` if the given OpResult has been decided to bufferize inplace.
|
|
|
|
bool isInPlace(OpOperand &opOperand) const override {
|
|
|
|
// OpOperands that bufferize to a memory write are out-of-place, i.e., an
|
|
|
|
// alloc and copy is inserted.
|
|
|
|
return !bufferizesToMemoryWrite(opOperand);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
|
|
|
|
bool areEquivalentBufferizedValues(Value v1, Value v2) const override {
|
|
|
|
// There is no analysis, so we do not know if the values are equivalent. The
|
|
|
|
// conservative answer is "false".
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
/// Bufferize `op` without any prior analysis.
///
/// Uses an "always copy" state, in which every OpOperand that bufferizes to a
/// memory write is bufferized out-of-place.
LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options) {
  AlwaysCopyBufferizationState alwaysCopyState(options);
  return bufferizeOp(op, alwaysCopyState);
}
|
|
|
|
|
|
|
|
void bufferization::populateBufferizationPattern(
|
|
|
|
const BufferizationState &state, RewritePatternSet &patterns) {
|
|
|
|
patterns.add<BufferizationPattern>(patterns.getContext(), state);
|
|
|
|
}
|
|
|
|
|
|
|
|
std::unique_ptr<BufferizationOptions>
|
|
|
|
bufferization::getPartialBufferizationOptions() {
|
|
|
|
auto options = std::make_unique<BufferizationOptions>();
|
|
|
|
options->allowReturnMemref = true;
|
|
|
|
options->allowUnknownOps = true;
|
|
|
|
options->createDeallocs = false;
|
|
|
|
options->fullyDynamicLayoutMaps = false;
|
|
|
|
return options;
|
|
|
|
}
|