mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-27 22:36:06 +00:00
[mlir][bufferize][NFC] Make PostAnalysisSteps a function
They used to be classes with a virtual `run` function. This was inconvenient because post analysis steps are stored in BufferizationOptions. Because of this design choice, BufferizationOptions were not copyable. Differential Revision: https://reviews.llvm.org/D119258
This commit is contained in:
parent
50bccf2297
commit
cdb7675c26
@ -47,9 +47,6 @@ struct BufferizationOptions {
|
||||
|
||||
BufferizationOptions();
|
||||
|
||||
// BufferizationOptions cannot be copied.
|
||||
BufferizationOptions(const BufferizationOptions &other) = delete;
|
||||
|
||||
/// Return `true` if the op is allowed to be bufferized.
|
||||
bool isOpAllowed(Operation *op) const {
|
||||
if (!hasFilter)
|
||||
|
@ -82,7 +82,7 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options);
|
||||
void populateBufferizationPattern(const BufferizationState &state,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
std::unique_ptr<BufferizationOptions> getPartialBufferizationOptions();
|
||||
BufferizationOptions getPartialBufferizationOptions();
|
||||
|
||||
} // namespace bufferization
|
||||
} // namespace mlir
|
||||
|
@ -20,35 +20,25 @@ class AnalysisBufferizationState;
|
||||
class BufferizationAliasInfo;
|
||||
struct AnalysisBufferizationOptions;
|
||||
|
||||
/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
|
||||
/// PostAnalysisStepFns can be registered with `BufferizationOptions` and are
|
||||
/// executed after the analysis, but before bufferization. They can be used to
|
||||
/// implement custom dialect-specific optimizations.
|
||||
struct PostAnalysisStep {
|
||||
virtual ~PostAnalysisStep() = default;
|
||||
/// implement custom dialect-specific optimizations. They may modify the IR, but
|
||||
/// must keep `aliasInfo` consistent. Newly created operations and operations
|
||||
/// that should be re-analyzed must be added to `newOps`.
|
||||
using PostAnalysisStepFn = std::function<LogicalResult(
|
||||
Operation *, BufferizationState &, BufferizationAliasInfo &,
|
||||
SmallVector<Operation *> &)>;
|
||||
|
||||
/// Run the post analysis step. This function may modify the IR, but must keep
|
||||
/// `aliasInfo` consistent. Newly created operations and operations that
|
||||
/// should be re-analyzed must be added to `newOps`.
|
||||
virtual LogicalResult run(Operation *op, BufferizationState &state,
|
||||
BufferizationAliasInfo &aliasInfo,
|
||||
SmallVector<Operation *> &newOps) = 0;
|
||||
};
|
||||
|
||||
using PostAnalysisStepList = std::vector<std::unique_ptr<PostAnalysisStep>>;
|
||||
using PostAnalysisStepList = SmallVector<PostAnalysisStepFn>;
|
||||
|
||||
/// Options for analysis-enabled bufferization.
|
||||
struct AnalysisBufferizationOptions : public BufferizationOptions {
|
||||
AnalysisBufferizationOptions() = default;
|
||||
|
||||
// AnalysisBufferizationOptions cannot be copied.
|
||||
AnalysisBufferizationOptions(const AnalysisBufferizationOptions &) = delete;
|
||||
|
||||
/// Register a "post analysis" step. Such steps are executed after the
|
||||
/// analysis, but before bufferization.
|
||||
template <typename Step, typename... Args>
|
||||
void addPostAnalysisStep(Args... args) {
|
||||
postAnalysisSteps.emplace_back(
|
||||
std::make_unique<Step>(std::forward<Args>(args)...));
|
||||
void addPostAnalysisStep(PostAnalysisStepFn fn) {
|
||||
postAnalysisSteps.push_back(fn);
|
||||
}
|
||||
|
||||
/// Registered post analysis steps.
|
||||
|
@ -18,42 +18,38 @@ namespace linalg {
|
||||
namespace comprehensive_bufferize {
|
||||
namespace linalg_ext {
|
||||
|
||||
struct InitTensorEliminationStep : public bufferization::PostAnalysisStep {
|
||||
/// A function that matches anchor OpOperands for InitTensorOp elimination.
|
||||
/// If an OpOperand is matched, the function should populate the SmallVector
|
||||
/// with all values that are needed during `RewriteFn` to produce the
|
||||
/// replacement value.
|
||||
using AnchorMatchFn = std::function<bool(OpOperand &, SmallVector<Value> &)>;
|
||||
/// A function that matches anchor OpOperands for InitTensorOp elimination.
|
||||
/// If an OpOperand is matched, the function should populate the SmallVector
|
||||
/// with all values that are needed during `RewriteFn` to produce the
|
||||
/// replacement value.
|
||||
using AnchorMatchFn = std::function<bool(OpOperand &, SmallVector<Value> &)>;
|
||||
|
||||
/// A function that rewrites matched anchors.
|
||||
using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
|
||||
/// A function that rewrites matched anchors.
|
||||
using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
|
||||
|
||||
/// Try to eliminate InitTensorOps inside `op`.
|
||||
///
|
||||
/// * `rewriteFunc` generates the replacement for the InitTensorOp.
|
||||
/// * Only InitTensorOps that are anchored on a matching OpOperand as per
|
||||
/// `anchorMatchFunc` are considered. "Anchored" means that there is a path
|
||||
/// on the reverse SSA use-def chain, starting from the OpOperand and always
|
||||
/// following the aliasing OpOperand, that eventually ends at a single
|
||||
/// InitTensorOp.
|
||||
/// * The result of `rewriteFunc` must usually be analyzed for inplacability.
|
||||
/// This analysis can be skipped with `skipAnalysis`.
|
||||
LogicalResult
|
||||
eliminateInitTensors(Operation *op, bufferization::BufferizationState &state,
|
||||
bufferization::BufferizationAliasInfo &aliasInfo,
|
||||
AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
|
||||
SmallVector<Operation *> &newOps);
|
||||
};
|
||||
/// Try to eliminate InitTensorOps inside `op`.
|
||||
///
|
||||
/// * `rewriteFunc` generates the replacement for the InitTensorOp.
|
||||
/// * Only InitTensorOps that are anchored on a matching OpOperand as per
|
||||
/// `anchorMatchFunc` are considered. "Anchored" means that there is a path
|
||||
/// on the reverse SSA use-def chain, starting from the OpOperand and always
|
||||
/// following the aliasing OpOperand, that eventually ends at a single
|
||||
/// InitTensorOp.
|
||||
/// * The result of `rewriteFunc` must usually be analyzed for inplacability.
|
||||
/// This analysis can be skipped with `skipAnalysis`.
|
||||
LogicalResult
|
||||
eliminateInitTensors(Operation *op, bufferization::BufferizationState &state,
|
||||
bufferization::BufferizationAliasInfo &aliasInfo,
|
||||
AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
|
||||
SmallVector<Operation *> &newOps);
|
||||
|
||||
/// Try to eliminate InitTensorOps inside `op` that are anchored on an
|
||||
/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
|
||||
/// (and some other conditions are met).
|
||||
struct InsertSliceAnchoredInitTensorEliminationStep
|
||||
: public InitTensorEliminationStep {
|
||||
LogicalResult run(Operation *op, bufferization::BufferizationState &state,
|
||||
bufferization::BufferizationAliasInfo &aliasInfo,
|
||||
SmallVector<Operation *> &newOps) override;
|
||||
};
|
||||
LogicalResult insertSliceAnchoredInitTensorEliminationStep(
|
||||
Operation *op, bufferization::BufferizationState &state,
|
||||
bufferization::BufferizationAliasInfo &aliasInfo,
|
||||
SmallVector<Operation *> &newOps);
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
|
||||
|
||||
|
@ -14,16 +14,21 @@
|
||||
namespace mlir {
|
||||
class DialectRegistry;
|
||||
|
||||
namespace bufferization {
|
||||
class BufferizationState;
|
||||
class BufferizationAliasInfo;
|
||||
} // namespace bufferization
|
||||
|
||||
namespace scf {
|
||||
/// Assert that yielded values of an scf.for op are aliasing their corresponding
|
||||
/// bbArgs. This is required because the i-th OpResult of an scf.for op is
|
||||
/// currently assumed to alias with the i-th iter_arg (in the absence of
|
||||
/// conflicts).
|
||||
struct AssertScfForAliasingProperties : public bufferization::PostAnalysisStep {
|
||||
LogicalResult run(Operation *op, bufferization::BufferizationState &state,
|
||||
bufferization::BufferizationAliasInfo &aliasInfo,
|
||||
SmallVector<Operation *> &newOps) override;
|
||||
};
|
||||
LogicalResult
|
||||
assertScfForAliasingProperties(Operation *op,
|
||||
bufferization::BufferizationState &state,
|
||||
bufferization::BufferizationAliasInfo &aliasInfo,
|
||||
SmallVector<Operation *> &newOps);
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
|
||||
} // namespace scf
|
||||
|
@ -29,16 +29,15 @@ struct ArithmeticBufferizePass
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
std::unique_ptr<BufferizationOptions> options =
|
||||
getPartialBufferizationOptions();
|
||||
BufferizationOptions options = getPartialBufferizationOptions();
|
||||
if (constantOpOnly) {
|
||||
options->addToOperationFilter<arith::ConstantOp>();
|
||||
options.addToOperationFilter<arith::ConstantOp>();
|
||||
} else {
|
||||
options->addToDialectFilter<arith::ArithmeticDialect>();
|
||||
options.addToDialectFilter<arith::ArithmeticDialect>();
|
||||
}
|
||||
options->bufferAlignment = alignment;
|
||||
options.bufferAlignment = alignment;
|
||||
|
||||
if (failed(bufferizeOp(getOperation(), *options)))
|
||||
if (failed(bufferizeOp(getOperation(), options)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
|
@ -253,12 +253,11 @@ void bufferization::populateBufferizationPattern(
|
||||
patterns.add<BufferizationPattern>(patterns.getContext(), state);
|
||||
}
|
||||
|
||||
std::unique_ptr<BufferizationOptions>
|
||||
bufferization::getPartialBufferizationOptions() {
|
||||
auto options = std::make_unique<BufferizationOptions>();
|
||||
options->allowReturnMemref = true;
|
||||
options->allowUnknownOps = true;
|
||||
options->createDeallocs = false;
|
||||
options->fullyDynamicLayoutMaps = false;
|
||||
BufferizationOptions bufferization::getPartialBufferizationOptions() {
|
||||
BufferizationOptions options;
|
||||
options.allowReturnMemref = true;
|
||||
options.allowUnknownOps = true;
|
||||
options.createDeallocs = false;
|
||||
options.fullyDynamicLayoutMaps = false;
|
||||
return options;
|
||||
}
|
||||
|
@ -698,52 +698,51 @@ annotateOpsWithBufferizationMarkers(Operation *op,
|
||||
// aliasing values, which is stricter than needed. We can currently not check
|
||||
// for aliasing values because the analysis is a maybe-alias analysis and we
|
||||
// need a must-alias analysis here.
|
||||
struct AssertDestinationPassingStyle : public PostAnalysisStep {
|
||||
LogicalResult run(Operation *op, BufferizationState &state,
|
||||
BufferizationAliasInfo &aliasInfo,
|
||||
SmallVector<Operation *> &newOps) override {
|
||||
LogicalResult status = success();
|
||||
DominanceInfo domInfo(op);
|
||||
op->walk([&](Operation *returnOp) {
|
||||
if (!isRegionReturnLike(returnOp))
|
||||
return WalkResult::advance();
|
||||
|
||||
for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
|
||||
Value returnVal = returnValOperand.get();
|
||||
// Skip non-tensor values.
|
||||
if (!returnVal.getType().isa<TensorType>())
|
||||
continue;
|
||||
|
||||
bool foundEquivValue = false;
|
||||
aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
|
||||
if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
|
||||
Operation *definingOp = bbArg.getOwner()->getParentOp();
|
||||
if (definingOp->isProperAncestor(returnOp))
|
||||
foundEquivValue = true;
|
||||
return;
|
||||
}
|
||||
|
||||
Operation *definingOp = equivVal.getDefiningOp();
|
||||
if (definingOp->getBlock()->findAncestorOpInBlock(
|
||||
*returnOp->getParentOp()))
|
||||
// Skip ops that happen after `returnOp` and parent ops.
|
||||
if (happensBefore(definingOp, returnOp, domInfo))
|
||||
foundEquivValue = true;
|
||||
});
|
||||
|
||||
if (!foundEquivValue)
|
||||
status =
|
||||
returnOp->emitError()
|
||||
<< "operand #" << returnValOperand.getOperandNumber()
|
||||
<< " of ReturnLike op does not satisfy destination passing style";
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
assertDestinationPassingStyle(Operation *op, BufferizationState &state,
|
||||
BufferizationAliasInfo &aliasInfo,
|
||||
SmallVector<Operation *> &newOps) {
|
||||
LogicalResult status = success();
|
||||
DominanceInfo domInfo(op);
|
||||
op->walk([&](Operation *returnOp) {
|
||||
if (!isRegionReturnLike(returnOp))
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
return status;
|
||||
}
|
||||
};
|
||||
for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
|
||||
Value returnVal = returnValOperand.get();
|
||||
// Skip non-tensor values.
|
||||
if (!returnVal.getType().isa<TensorType>())
|
||||
continue;
|
||||
|
||||
bool foundEquivValue = false;
|
||||
aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
|
||||
if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
|
||||
Operation *definingOp = bbArg.getOwner()->getParentOp();
|
||||
if (definingOp->isProperAncestor(returnOp))
|
||||
foundEquivValue = true;
|
||||
return;
|
||||
}
|
||||
|
||||
Operation *definingOp = equivVal.getDefiningOp();
|
||||
if (definingOp->getBlock()->findAncestorOpInBlock(
|
||||
*returnOp->getParentOp()))
|
||||
// Skip ops that happen after `returnOp` and parent ops.
|
||||
if (happensBefore(definingOp, returnOp, domInfo))
|
||||
foundEquivValue = true;
|
||||
});
|
||||
|
||||
if (!foundEquivValue)
|
||||
status =
|
||||
returnOp->emitError()
|
||||
<< "operand #" << returnValOperand.getOperandNumber()
|
||||
<< " of ReturnLike op does not satisfy destination passing style";
|
||||
}
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
LogicalResult bufferization::analyzeOp(Operation *op,
|
||||
AnalysisBufferizationState &state) {
|
||||
@ -761,12 +760,11 @@ LogicalResult bufferization::analyzeOp(Operation *op,
|
||||
return failure();
|
||||
equivalenceAnalysis(op, aliasInfo, state);
|
||||
|
||||
for (const std::unique_ptr<PostAnalysisStep> &step :
|
||||
options.postAnalysisSteps) {
|
||||
for (const PostAnalysisStepFn &fn : options.postAnalysisSteps) {
|
||||
SmallVector<Operation *> newOps;
|
||||
if (failed(step->run(op, state, aliasInfo, newOps)))
|
||||
if (failed(fn(op, state, aliasInfo, newOps)))
|
||||
return failure();
|
||||
// Analyze ops that were created by the PostAnalysisStep.
|
||||
// Analyze ops that were created by the PostAnalysisStepFn.
|
||||
if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))
|
||||
return failure();
|
||||
equivalenceAnalysis(newOps, aliasInfo, state);
|
||||
@ -774,8 +772,7 @@ LogicalResult bufferization::analyzeOp(Operation *op,
|
||||
|
||||
if (!options.allowReturnMemref) {
|
||||
SmallVector<Operation *> newOps;
|
||||
if (failed(
|
||||
AssertDestinationPassingStyle().run(op, state, aliasInfo, newOps)))
|
||||
if (failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
@ -524,11 +524,10 @@ findValidInsertionPoint(Operation *initTensorOp,
|
||||
/// chain, starting from the OpOperand and always following the aliasing
|
||||
/// OpOperand, that eventually ends at a single InitTensorOp.
|
||||
LogicalResult
|
||||
mlir::linalg::comprehensive_bufferize::linalg_ext::InitTensorEliminationStep::
|
||||
eliminateInitTensors(Operation *op, BufferizationState &state,
|
||||
BufferizationAliasInfo &aliasInfo,
|
||||
AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
|
||||
SmallVector<Operation *> &newOps) {
|
||||
mlir::linalg::comprehensive_bufferize::linalg_ext::eliminateInitTensors(
|
||||
Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
|
||||
AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
|
||||
SmallVector<Operation *> &newOps) {
|
||||
OpBuilder b(op->getContext());
|
||||
|
||||
WalkResult status = op->walk([&](Operation *op) {
|
||||
@ -628,7 +627,7 @@ mlir::linalg::comprehensive_bufferize::linalg_ext::InitTensorEliminationStep::
|
||||
/// Note that the newly inserted ExtractSliceOp may have to bufferize
|
||||
/// out-of-place due to RaW conflicts.
|
||||
LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
|
||||
InsertSliceAnchoredInitTensorEliminationStep::run(
|
||||
insertSliceAnchoredInitTensorEliminationStep(
|
||||
Operation *op, BufferizationState &state,
|
||||
BufferizationAliasInfo &aliasInfo, SmallVector<Operation *> &newOps) {
|
||||
return eliminateInitTensors(
|
||||
|
@ -16,11 +16,12 @@
|
||||
// their respective callers.
|
||||
//
|
||||
// After analyzing a FuncOp, additional information about its bbArgs is
|
||||
// gathered through PostAnalysisSteps and stored in `ModuleBufferizationState`.
|
||||
// gathered through PostAnalysisStepFns and stored in
|
||||
// `ModuleBufferizationState`.
|
||||
//
|
||||
// * `EquivalentFuncOpBBArgsAnalysis` determines the equivalent bbArg for each
|
||||
// * `equivalentFuncOpBBArgsAnalysis` determines the equivalent bbArg for each
|
||||
// tensor return value (if any).
|
||||
// * `FuncOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
|
||||
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
|
||||
// read/written.
|
||||
//
|
||||
// Only tensors that are equivalent to some FuncOp bbArg may be returned.
|
||||
@ -47,7 +48,7 @@
|
||||
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
|
||||
// out-of-place because `%t0` is modified by the callee but read by the
|
||||
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
|
||||
// bufferize out-of-place based on results of `FuncOpBbArgReadWriteAnalysis`.
|
||||
// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
|
||||
// ```
|
||||
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
|
||||
// %f = ... : f32
|
||||
@ -62,7 +63,7 @@
|
||||
// }
|
||||
// ```
|
||||
//
|
||||
// Note: If a function is external, `FuncOpBbArgReadWriteAnalysis` cannot
|
||||
// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
|
||||
// analyze the function body. In such a case, the CallOp analysis conservatively
|
||||
// assumes that each tensor OpOperand is both read and written.
|
||||
//
|
||||
@ -159,55 +160,55 @@ static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// Annotate IR with the results of the analysis. For testing purposes only.
|
||||
static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
|
||||
BlockArgument bbArg) {
|
||||
const char *kEquivalentArgsAttr = "__equivalent_func_args__";
|
||||
Operation *op = returnVal.getOwner();
|
||||
|
||||
SmallVector<int64_t> equivBbArgs;
|
||||
if (op->hasAttr(kEquivalentArgsAttr)) {
|
||||
auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>();
|
||||
equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
|
||||
return a.cast<IntegerAttr>().getValue().getSExtValue();
|
||||
}));
|
||||
} else {
|
||||
equivBbArgs.append(op->getNumOperands(), -1);
|
||||
}
|
||||
equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();
|
||||
|
||||
OpBuilder b(op->getContext());
|
||||
op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
|
||||
}
|
||||
|
||||
/// Store function BlockArguments that are equivalent to a returned value in
|
||||
/// ModuleBufferizationState.
|
||||
struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep {
|
||||
/// Annotate IR with the results of the analysis. For testing purposes only.
|
||||
static void annotateReturnOp(OpOperand &returnVal, BlockArgument bbArg) {
|
||||
const char *kEquivalentArgsAttr = "__equivalent_func_args__";
|
||||
Operation *op = returnVal.getOwner();
|
||||
static LogicalResult
|
||||
equivalentFuncOpBBArgsAnalysis(Operation *op, BufferizationState &state,
|
||||
BufferizationAliasInfo &aliasInfo,
|
||||
SmallVector<Operation *> &newOps) {
|
||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
||||
|
||||
SmallVector<int64_t> equivBbArgs;
|
||||
if (op->hasAttr(kEquivalentArgsAttr)) {
|
||||
auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>();
|
||||
equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
|
||||
return a.cast<IntegerAttr>().getValue().getSExtValue();
|
||||
}));
|
||||
} else {
|
||||
equivBbArgs.append(op->getNumOperands(), -1);
|
||||
}
|
||||
equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();
|
||||
// Support only single return-terminated block in the function.
|
||||
auto funcOp = cast<FuncOp>(op);
|
||||
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
|
||||
assert(returnOp && "expected func with single return op");
|
||||
|
||||
OpBuilder b(op->getContext());
|
||||
op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
|
||||
}
|
||||
for (OpOperand &returnVal : returnOp->getOpOperands())
|
||||
if (returnVal.get().getType().isa<RankedTensorType>())
|
||||
for (BlockArgument bbArg : funcOp.getArguments())
|
||||
if (bbArg.getType().isa<RankedTensorType>())
|
||||
if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
|
||||
moduleState
|
||||
.equivalentFuncArgs[funcOp][returnVal.getOperandNumber()] =
|
||||
bbArg.getArgNumber();
|
||||
if (state.getOptions().testAnalysisOnly)
|
||||
annotateEquivalentReturnBbArg(returnVal, bbArg);
|
||||
}
|
||||
|
||||
LogicalResult run(Operation *op, BufferizationState &state,
|
||||
BufferizationAliasInfo &aliasInfo,
|
||||
SmallVector<Operation *> &newOps) override {
|
||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
||||
|
||||
// Support only single return-terminated block in the function.
|
||||
auto funcOp = cast<FuncOp>(op);
|
||||
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
|
||||
assert(returnOp && "expected func with single return op");
|
||||
|
||||
for (OpOperand &returnVal : returnOp->getOpOperands())
|
||||
if (returnVal.get().getType().isa<RankedTensorType>())
|
||||
for (BlockArgument bbArg : funcOp.getArguments())
|
||||
if (bbArg.getType().isa<RankedTensorType>())
|
||||
if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(),
|
||||
bbArg)) {
|
||||
moduleState
|
||||
.equivalentFuncArgs[funcOp][returnVal.getOperandNumber()] =
|
||||
bbArg.getArgNumber();
|
||||
if (state.getOptions().testAnalysisOnly)
|
||||
annotateReturnOp(returnVal, bbArg);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Return true if the buffer of the given tensor value is written to. Must not
|
||||
/// be called for values inside not yet analyzed functions. (Post-analysis
|
||||
@ -239,38 +240,37 @@ static bool isValueWritten(Value value, const BufferizationState &state,
|
||||
}
|
||||
|
||||
/// Determine which FuncOp bbArgs are read and which are written. If this
|
||||
/// PostAnalysisStep is run on a function with unknown ops, it will
|
||||
/// PostAnalysisStepFn is run on a function with unknown ops, it will
|
||||
/// conservatively assume that such ops bufferize to a read + write.
|
||||
struct FuncOpBbArgReadWriteAnalysis : public PostAnalysisStep {
|
||||
LogicalResult run(Operation *op, BufferizationState &state,
|
||||
BufferizationAliasInfo &aliasInfo,
|
||||
SmallVector<Operation *> &newOps) override {
|
||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
||||
auto funcOp = cast<FuncOp>(op);
|
||||
|
||||
// If the function has no body, conservatively assume that all args are
|
||||
// read + written.
|
||||
if (funcOp.getBody().empty()) {
|
||||
for (BlockArgument bbArg : funcOp.getArguments()) {
|
||||
moduleState.readBbArgs.insert(bbArg);
|
||||
moduleState.writtenBbArgs.insert(bbArg);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
static LogicalResult
|
||||
funcOpBbArgReadWriteAnalysis(Operation *op, BufferizationState &state,
|
||||
BufferizationAliasInfo &aliasInfo,
|
||||
SmallVector<Operation *> &newOps) {
|
||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
||||
auto funcOp = cast<FuncOp>(op);
|
||||
|
||||
// If the function has no body, conservatively assume that all args are
|
||||
// read + written.
|
||||
if (funcOp.getBody().empty()) {
|
||||
for (BlockArgument bbArg : funcOp.getArguments()) {
|
||||
if (!bbArg.getType().isa<TensorType>())
|
||||
continue;
|
||||
if (state.isValueRead(bbArg))
|
||||
moduleState.readBbArgs.insert(bbArg);
|
||||
if (isValueWritten(bbArg, state, aliasInfo))
|
||||
moduleState.writtenBbArgs.insert(bbArg);
|
||||
moduleState.readBbArgs.insert(bbArg);
|
||||
moduleState.writtenBbArgs.insert(bbArg);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
for (BlockArgument bbArg : funcOp.getArguments()) {
|
||||
if (!bbArg.getType().isa<TensorType>())
|
||||
continue;
|
||||
if (state.isValueRead(bbArg))
|
||||
moduleState.readBbArgs.insert(bbArg);
|
||||
if (isValueWritten(bbArg, state, aliasInfo))
|
||||
moduleState.writtenBbArgs.insert(bbArg);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
static bool isaTensor(Type t) { return t.isa<TensorType>(); }
|
||||
@ -983,10 +983,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
|
||||
return failure();
|
||||
|
||||
// Collect bbArg/return value information after the analysis.
|
||||
options->postAnalysisSteps.emplace_back(
|
||||
std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
|
||||
options->postAnalysisSteps.emplace_back(
|
||||
std::make_unique<FuncOpBbArgReadWriteAnalysis>());
|
||||
options->postAnalysisSteps.push_back(equivalentFuncOpBBArgsAnalysis);
|
||||
options->postAnalysisSteps.push_back(funcOpBbArgReadWriteAnalysis);
|
||||
|
||||
// Analyze ops.
|
||||
for (FuncOp funcOp : moduleState.orderedFuncOps) {
|
||||
|
@ -125,12 +125,12 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
|
||||
|
||||
// Enable InitTensorOp elimination.
|
||||
if (initTensorElimination) {
|
||||
options->addPostAnalysisStep<
|
||||
linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
|
||||
options->addPostAnalysisStep(
|
||||
linalg_ext::insertSliceAnchoredInitTensorEliminationStep);
|
||||
}
|
||||
|
||||
// Only certain scf.for ops are supported by the analysis.
|
||||
options->addPostAnalysisStep<scf::AssertScfForAliasingProperties>();
|
||||
options->addPostAnalysisStep(scf::assertScfForAliasingProperties);
|
||||
|
||||
ModuleOp moduleOp = getOperation();
|
||||
applyEnablingTransformations(moduleOp);
|
||||
|
@ -432,7 +432,7 @@ struct YieldOpInterface
|
||||
} // namespace scf
|
||||
} // namespace mlir
|
||||
|
||||
LogicalResult mlir::scf::AssertScfForAliasingProperties::run(
|
||||
LogicalResult mlir::scf::assertScfForAliasingProperties(
|
||||
Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
|
||||
SmallVector<Operation *> &newOps) {
|
||||
LogicalResult status = success();
|
||||
|
@ -30,11 +30,10 @@ using namespace bufferization;
|
||||
namespace {
|
||||
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
|
||||
void runOnOperation() override {
|
||||
std::unique_ptr<BufferizationOptions> options =
|
||||
getPartialBufferizationOptions();
|
||||
options->addToDialectFilter<tensor::TensorDialect>();
|
||||
BufferizationOptions options = getPartialBufferizationOptions();
|
||||
options.addToDialectFilter<tensor::TensorDialect>();
|
||||
|
||||
if (failed(bufferizeOp(getOperation(), *options)))
|
||||
if (failed(bufferizeOp(getOperation(), options)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
|
@ -104,7 +104,7 @@ void TestComprehensiveFunctionBufferize::runOnOperation() {
|
||||
auto options = std::make_unique<AnalysisBufferizationOptions>();
|
||||
|
||||
if (!allowReturnMemref)
|
||||
options->addPostAnalysisStep<scf::AssertScfForAliasingProperties>();
|
||||
options->addPostAnalysisStep(scf::assertScfForAliasingProperties);
|
||||
|
||||
options->allowReturnMemref = allowReturnMemref;
|
||||
options->allowUnknownOps = allowUnknownOps;
|
||||
|
Loading…
x
Reference in New Issue
Block a user