DestinationPassingStyle: allow additional non-tensor results

Also some simplifications:

* `outputBufferOperands` was unused.
* The condition that the number of operands equals the number of inputs
  plus the number of inits seemed vacuously true (?).

Differential Revision: https://reviews.llvm.org/D150376
This commit is contained in:
Benoit Jacob 2023-05-11 15:44:12 +00:00
parent 2c52a18925
commit 2bd6077d7f
4 changed files with 108 additions and 18 deletions

View File

@ -22,35 +22,38 @@ OpOperandVector::operator SmallVector<Value>() {
return result;
}
namespace {
size_t getNumTensorResults(Operation *op) {
size_t numTensorResults = 0;
for (auto t : op->getResultTypes()) {
if (isa<TensorType>(t)) {
++numTensorResults;
}
}
return numTensorResults;
}
} // namespace
LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
DestinationStyleOpInterface dstStyleOp =
cast<DestinationStyleOpInterface>(op);
SmallVector<OpOperand *> outputBufferOperands, outputTensorOperands;
SmallVector<OpOperand *> outputTensorOperands;
for (OpOperand *operand : dstStyleOp.getDpsInitOperands()) {
Type type = operand->get().getType();
if (isa<MemRefType>(type)) {
outputBufferOperands.push_back(operand);
} else if (isa<RankedTensorType>(type)) {
if (isa<RankedTensorType>(type)) {
outputTensorOperands.push_back(operand);
} else {
} else if (!isa<MemRefType>(type)) {
return op->emitOpError("expected that operand #")
<< operand->getOperandNumber()
<< " is a ranked tensor or a ranked memref";
}
}
// Expect at least one output operand.
int64_t numInputs = dstStyleOp.getNumDpsInputs();
int64_t numInits = dstStyleOp.getNumDpsInits();
if (numInits == 0)
return op->emitOpError("expected at least one output operand");
if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numInits)))
return failure();
// Verify the number of results matches the number of output tensors.
if (op->getNumResults() != outputTensorOperands.size())
return op->emitOpError("expected the number of results (")
<< op->getNumResults()
// Verify the number of tensor results matches the number of output tensors.
if (getNumTensorResults(op) != outputTensorOperands.size())
return op->emitOpError("expected the number of tensor results (")
<< getNumTensorResults(op)
<< ") to be equal to the number of output tensors ("
<< outputTensorOperands.size() << ")";

View File

@ -326,7 +326,7 @@ func.func @matching_inits(%m: memref<?x?xf32>, %t: tensor<?x?xf32>) {
func.func @illegal_fill_tensor_no_return(%arg0 : index, %arg1 : index, %arg2 : f32)
{
%0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf32>
// expected-error @+1 {{expected the number of results (0) to be equal to the number of output tensors (1)}}
// expected-error @+1 {{expected the number of tensor results (0) to be equal to the number of output tensors (1)}}
linalg.fill ins(%arg2 : f32) outs(%0 : tensor<?x?xf32>)
}
@ -335,7 +335,7 @@ func.func @illegal_fill_tensor_no_return(%arg0 : index, %arg1 : index, %arg2 : f
func.func @illegal_fill_memref_with_tensor_return
(%arg0 : memref<?x?xf32>, %arg1 : f32) -> tensor<?x?xf32>
{
// expected-error @+1 {{expected the number of results (1) to be equal to the number of output tensors (0)}}
// expected-error @+1 {{expected the number of tensor results (1) to be equal to the number of output tensors (0)}}
%0 = linalg.fill ins(%arg1 : f32) outs(%arg0 : memref<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}

View File

@ -0,0 +1,59 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
func.func @ins_1_index_outs_none_results_1_index(%arg0 : index) -> index {
%0 = test.destination_style_op ins(%arg0 : index) -> index
func.return %0 : index
}
// -----
func.func @ins_1_index_outs_1_tensor_results_1_index(%arg0 : index, %arg1 : tensor<2x2xf32>) -> index {
// expected-error @+1 {{op expected the number of tensor results (0) to be equal to the number of output tensors (1)}}
%0 = test.destination_style_op ins(%arg0 : index) outs(%arg1 : tensor<2x2xf32>) -> index
func.return %0 : index
}
// -----
func.func @ins_1_tensor_outs_none_results_1_index(%arg0 :tensor<2x2xf32>) -> index {
%0 = test.destination_style_op ins(%arg0 : tensor<2x2xf32>) -> index
func.return %0 : index
}
// -----
func.func @ins_1_tensor_outs_1_tensor_results_1_index(%arg0 :tensor<2x2xf32>, %arg1 : tensor<2x2xf32>) -> index {
// expected-error @+1 {{op expected the number of tensor results (0) to be equal to the number of output tensors (1)}}
%0 = test.destination_style_op ins(%arg0 : tensor<2x2xf32>) outs(%arg1 : tensor<2x2xf32>) -> index
func.return %0 : index
}
// -----
func.func @ins_1_index_outs_none_results_1_tensor(%arg0 : index) -> tensor<2x2xf32> {
// expected-error @+1 {{op expected the number of tensor results (1) to be equal to the number of output tensors (0)}}
%0 = test.destination_style_op ins(%arg0 : index) -> tensor<2x2xf32>
func.return %0 : tensor<2x2xf32>
}
// -----
func.func @ins_1_index_outs_1_tensor_results_1_tensor(%arg0 : index, %arg1 : tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = test.destination_style_op ins(%arg0 : index) outs(%arg1 : tensor<2x2xf32>) -> tensor<2x2xf32>
func.return %0 : tensor<2x2xf32>
}
// -----
func.func @ins_1_tensor_outs_none_results_1_tensor(%arg0 :tensor<2x2xf32>) -> tensor<2x2xf32> {
// expected-error @+1 {{op expected the number of tensor results (1) to be equal to the number of output tensors (0)}}
%0 = test.destination_style_op ins(%arg0 : tensor<2x2xf32>) -> tensor<2x2xf32>
func.return %0 : tensor<2x2xf32>
}
// -----
func.func @ins_1_tensor_outs_1_tensor_results_1_tensor(%arg0 :tensor<2x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = test.destination_style_op ins(%arg0 : tensor<2x2xf32>) outs(%arg1 : tensor<2x2xf32>) -> tensor<2x2xf32>
func.return %0 : tensor<2x2xf32>
}

View File

@ -2908,6 +2908,34 @@ def OpCrashShort : TEST_Op<"op_crash_short"> {
def : Pat<(OpCrashLong $_, $_, $_), (OpCrashShort)>;
//===----------------------------------------------------------------------===//
// Test DestinationStyleOpInterface.
//===----------------------------------------------------------------------===//
def TestDestinationStyleOp :
TEST_Op<"destination_style_op", [
DestinationStyleOpInterface,
AttrSizedOperandSegments]> {
let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyType>:$outputs,
Variadic<AnyType>:$other_operands);
let results = (outs Variadic<AnyType>:$results);
let assemblyFormat = [{
attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
(`outs` `(` $outputs^ `:` type($outputs) `)`)?
(`(` $other_operands^ `:` type($other_operands) `)`)?
(`->` type($results)^)?
}];
let extraClassDeclaration = [{
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
int64_t numOperands = this->getNumOperands();
return {numOperands - getOutputs().size(), numOperands};
}
}];
}
//===----------------------------------------------------------------------===//
// Test LinalgConvolutionOpInterface.
//===----------------------------------------------------------------------===//