mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-25 15:36:07 +00:00

Rename interface functions as follows: * `hasTensorSemantics` -> `hasPureTensorSemantics` * `hasBufferSemantics` -> `hasPureBufferSemantics` These two functions return "true" if the op has tensor (respectively buffer) operands but no buffer (respectively tensor) operands. Also drop the "ranked" part from the interface, i.e., do not distinguish between ranked and unranked types. The new function names describe the functions more accurately. They also align the functions' semantics with the notion of "tensor semantics" used by the bufferization framework. (An op is supposed to be bufferized if it has tensor operands, and we don't care if it also has memref operands.) This change is in preparation for #75273, which adds `BufferizableOpInterface::hasTensorSemantics`. By renaming the functions in the `DestinationStyleOpInterface`, we can avoid name clashes between the two interfaces.
63 lines
2.1 KiB
C++
63 lines
2.1 KiB
C++
//===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace mlir {
|
|
#include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc"
|
|
} // namespace mlir
|
|
|
|
namespace {
|
|
size_t getNumTensorResults(Operation *op) {
|
|
size_t numTensorResults = 0;
|
|
for (auto t : op->getResultTypes()) {
|
|
if (isa<TensorType>(t)) {
|
|
++numTensorResults;
|
|
}
|
|
}
|
|
return numTensorResults;
|
|
}
|
|
} // namespace
|
|
|
|
/// Verifies the structural invariants of an op implementing
/// `DestinationStyleOpInterface`:
///   1. every DPS init operand is either a tensor or a memref;
///   2. the number of tensor-typed results equals the number of tensor-typed
///      init operands;
///   3. each tensor-typed init operand has exactly the same type as the result
///      it is tied to.
/// Emits an op error and returns failure on the first violated invariant;
/// returns success otherwise.
LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
  auto dpsOp = cast<DestinationStyleOpInterface>(op);

  // Collect the tensor-typed inits. Memref-typed inits are accepted but are
  // not tied to a result; any other type is rejected right away.
  SmallVector<OpOperand *> tensorInits;
  for (OpOperand &init : dpsOp.getDpsInitsMutable()) {
    Type initType = init.get().getType();
    if (isa<TensorType>(initType)) {
      tensorInits.push_back(&init);
      continue;
    }
    if (isa<BaseMemRefType>(initType))
      continue;
    return op->emitOpError("expected that operand #")
           << init.getOperandNumber() << " is a tensor or a memref";
  }

  // The op must produce exactly one tensor result per tensor init.
  size_t numTensorResults = getNumTensorResults(op);
  size_t numTensorInits = tensorInits.size();
  if (numTensorResults != numTensorInits) {
    return op->emitOpError("expected the number of tensor results (")
           << numTensorResults
           << ") to be equal to the number of output tensors ("
           << numTensorInits << ")";
  }

  // Each tensor init and its tied result must agree on the type.
  for (OpOperand *init : tensorInits) {
    Type initType = init->get().getType();
    Type resultType = dpsOp.getTiedOpResult(init).getType();
    if (resultType == initType)
      continue;
    return op->emitOpError("expected type of operand #")
           << init->getOperandNumber() << " (" << initType << ")"
           << " to match type of corresponding result (" << resultType
           << ")";
  }

  return success();
}
|