llvm-project/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
Matthias Springer 0b2197b0cf
[mlir][Interfaces] Clean up DestinationStyleOpInterface (#67015)
* "init" operands are specified with `MutableOperandRange` (which gives
access to the underlying `OpOperand *`). No more magic numbers.
* Remove most interface methods and make them helper functions. Only
`getInitsMutable` should be implemented.
* Provide separate helper functions for accessing mutable/immutable
operands (`OpOperand`/`Value`, in line with #66515): `getInitsMutable`
and `getInits` (same naming convention as auto-generated op accessors).
`getInputOperands` was not renamed because this function cannot return a
`MutableOperandRange` (because the operands are not necessarily
consecutive). `OpOperandVector` is no longer needed.
* The new `getDpsInits`/`getDpsInitsMutable` are more efficient than the
old `getDpsInitOperands` because no `SmallVector` is created. The new
functions return a range of operands.
* Fix a bug in `getDpsInputOperands`: out-of-bounds operands were
potentially returned.
2023-09-21 18:04:08 +02:00

63 lines
2.2 KiB
C++

//===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
using namespace mlir;
namespace mlir {
#include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc"
} // namespace mlir
namespace {
/// Counts the results of `op` whose type is a `TensorType`.
///
/// Every entry of `op->getResultTypes()` is inspected; a result contributes
/// one to the count when `isa<TensorType>` holds for its type, and zero
/// otherwise.
size_t getNumTensorResults(Operation *op) {
  size_t tensorCount = 0;
  for (Type resultType : op->getResultTypes())
    tensorCount += isa<TensorType>(resultType) ? 1 : 0;
  return tensorCount;
}
} // namespace
/// Verifies the structural invariants of an op implementing
/// `DestinationStyleOpInterface`:
///   1. every "init" operand (as enumerated by `getDpsInitsMutable`) has either
///      a `RankedTensorType` or a `MemRefType`;
///   2. the number of results whose type is a `TensorType` equals the number
///      of init operands with a `RankedTensorType`;
///   3. each ranked-tensor init operand has exactly the same type as the result
///      returned for it by `getTiedOpResult`.
/// Memref inits are only checked by step 1; they are not matched to results
/// here.
///
/// Returns `success()` when all checks pass. Otherwise it emits an op error on
/// `op` that describes the first violation found, and returns that diagnostic
/// as a failure.
LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
  DestinationStyleOpInterface dstStyleOp =
      cast<DestinationStyleOpInterface>(op);

  // Gather the inits with tensor semantics. Any init that is neither a ranked
  // tensor nor a memref is rejected immediately.
  SmallVector<OpOperand *> outputTensorOperands;
  for (OpOperand &operand : dstStyleOp.getDpsInitsMutable()) {
    Type type = operand.get().getType();
    if (isa<RankedTensorType>(type)) {
      outputTensorOperands.push_back(&operand);
    } else if (!isa<MemRefType>(type)) {
      return op->emitOpError("expected that operand #")
             << operand.getOperandNumber()
             << " is a ranked tensor or a ranked memref";
    }
  }

  // Verify the number of tensor results matches the number of output tensors.
  // The count is computed once and reused in the diagnostic rather than
  // rescanning the result types a second time.
  // NOTE(review): results are counted with `isa<TensorType>`, while inits are
  // collected above with `isa<RankedTensorType>`; confirm this asymmetry is
  // intended.
  size_t numTensorResults = getNumTensorResults(op);
  if (numTensorResults != outputTensorOperands.size())
    return op->emitOpError("expected the number of tensor results (")
           << numTensorResults
           << ") to be equal to the number of output tensors ("
           << outputTensorOperands.size() << ")";

  // Each ranked-tensor init must have exactly the type of its tied result.
  for (OpOperand *opOperand : outputTensorOperands) {
    OpResult result = dstStyleOp.getTiedOpResult(opOperand);
    if (result.getType() != opOperand->get().getType())
      return op->emitOpError("expected type of operand #")
             << opOperand->getOperandNumber() << " ("
             << opOperand->get().getType() << ")"
             << " to match type of corresponding result (" << result.getType()
             << ")";
  }
  return success();
}