[mlir][spirv] Implement vector type legalization for function signatures (#98337)

### Description
This PR implements a minimal version of function signature conversion that
unrolls vector arguments and results into 1-D vectors whose sizes are supported
by SPIR-V (2, 3, or 4, depending on the original dimension). It also adds unit
tests that check only the function signature conversion.
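
As a concrete illustration (mirroring the new `@simple_vector_6` test added below), a function that takes and returns `vector<6xi32>` is rewritten so that its signature only uses `vector<3xi32>`; `vector.insert_strided_slice` ops reassemble the unrolled arguments and `vector.extract_strided_slice` ops split the result. The sketch below shows the before/after IR:

```mlir
// Before: the signature uses a vector size SPIR-V cannot represent directly.
func.func @simple_vector_6(%arg0 : vector<6xi32>) -> vector<6xi32> {
  return %arg0 : vector<6xi32>
}

// After signature conversion (sketch of the IR checked by the new test):
func.func @simple_vector_6(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) {
  // Reassemble the original vector<6xi32> value from the unrolled arguments.
  %cst = arith.constant dense<0> : vector<6xi32>
  %0 = vector.insert_strided_slice %arg0, %cst {offsets = [0], strides = [1]} : vector<3xi32> into vector<6xi32>
  %1 = vector.insert_strided_slice %arg1, %0 {offsets = [3], strides = [1]} : vector<3xi32> into vector<6xi32>
  // Split the result back into legal pieces for the new return type.
  %2 = vector.extract_strided_slice %1 {offsets = [0], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
  %3 = vector.extract_strided_slice %1 {offsets = [3], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
  return %2, %3 : vector<3xi32>, vector<3xi32>
}
```

The insert/extract pairs introduced here are intentionally left in place; cancelling them out is deferred to the additional unrolling patterns listed under Future Plans.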

### Future Plans
- Check for capabilities that support vectors of size 8 or 16.
- Set up `OneToNTypeConversion` and `DialectConversion` to replace the
current implementation that uses `GreedyPatternRewriteDriver`.
- Introduce other vector unrolling patterns to cancel out the
`vector.insert_strided_slice` and `vector.extract_strided_slice` ops and
fully legalize the vector types in the function body.
- Handle `func::CallOp` and declarations.
- Restructure the code in `SPIRVConversion.cpp`.
- Create test passes for testing sets of patterns in isolation.
- Optimize the way the original shape is split into target shapes, e.g.
`vector<5xi32>` could be split into `vector<4xi32>` and
`vector<1xi32>` (see the sketch after this list).
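
For reference, here is a minimal sketch of what the last point means (the function names are illustrative, and the optimized form is not produced by this PR):

```mlir
// Current behavior: getComputeVectorSize tries 4, 3, and 2 and falls back to 1,
// so a vector<5xi32> argument is unrolled into five vector<1xi32> arguments
// (this is exactly what the new @simple_vector_5 test checks).
func.func private @current_split(vector<1xi32>, vector<1xi32>, vector<1xi32>,
                                 vector<1xi32>, vector<1xi32>)

// Hypothetical optimized signature: a mixed split into vector<4xi32> and
// vector<1xi32>, as suggested in the bullet above.
func.func private @optimized_split(vector<4xi32>, vector<1xi32>)
```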

---------

Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
Author: Angel Zhang, 2024-07-17 13:09:15 -04:00 (committed by GitHub)
Parent: c7b08ac01f
Commit: 6867e49fc8
20 changed files with 579 additions and 14 deletions


@@ -40,7 +40,15 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
let description = [{
This is a generic pass to convert to SPIR-V.
}];
let dependentDialects = ["spirv::SPIRVDialect"];
let dependentDialects = [
"spirv::SPIRVDialect",
"vector::VectorDialect",
];
let options = [
Option<"runSignatureConversion", "run-signature-conversion", "bool",
/*default=*/"true",
"Run function signature conversion to convert vector types">
];
}
//===----------------------------------------------------------------------===//


@@ -17,7 +17,9 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
#include "llvm/ADT/SmallSet.h"
namespace mlir {
@@ -134,6 +136,10 @@ private:
void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns);
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns);
namespace spirv {
class AccessChainOp;


@@ -39,18 +39,31 @@ namespace {
/// A pass to perform the SPIR-V conversion.
struct ConvertToSPIRVPass final
: impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
void runOnOperation() override {
MLIRContext *context = &getContext();
Operation *op = getOperation();
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
SPIRVTypeConverter typeConverter(targetAttr);
if (runSignatureConversion) {
// Unroll vectors in function signatures to native vector size.
RewritePatternSet patterns(context);
populateFuncOpVectorRewritePatterns(patterns);
populateReturnOpVectorRewritePatterns(patterns);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
return signalPassFailure();
}
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
RewritePatternSet patterns(context);
ScfToSPIRVContext scfToSPIRVContext;
// Populate patterns.
// Populate patterns for each dialect.
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
@@ -60,9 +73,6 @@ struct ConvertToSPIRVPass final
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}


@@ -16,9 +16,15 @@ add_mlir_dialect_library(MLIRSPIRVConversion
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
LINK_LIBS PUBLIC
MLIRArithDialect
MLIRDialectUtils
MLIRFuncDialect
MLIRIR
MLIRSPIRVDialect
MLIRSupport
MLIRTransformUtils
MLIRVectorDialect
MLIRVectorTransforms
)
add_mlir_dialect_library(MLIRSPIRVTransforms


@@ -11,14 +11,24 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
@@ -34,6 +44,43 @@ using namespace mlir;
// Utility functions
//===----------------------------------------------------------------------===//
static int getComputeVectorSize(int64_t size) {
for (int i : {4, 3, 2}) {
if (size % i == 0)
return i;
}
return 1;
}
static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
if (vecType.isScalable()) {
LLVM_DEBUG(llvm::dbgs()
<< "--scalable vectors are not supported -> BAIL\n");
return std::nullopt;
}
SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
std::optional<SmallVector<int64_t>> targetShape =
SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
if (!targetShape) {
LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
return std::nullopt;
}
auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
if (!maybeShapeRatio) {
LLVM_DEBUG(llvm::dbgs()
<< "--could not compute integral shape ratio -> BAIL\n");
return std::nullopt;
}
if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
return std::nullopt;
}
LLVM_DEBUG(llvm::dbgs()
<< "--found an integral shape ratio to unroll to -> SUCCESS\n");
return targetShape;
}
/// Checks that `candidates` extension requirements are possible to be satisfied
/// with the given `targetEnv`.
///
@@ -813,6 +860,249 @@ void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
}
//===----------------------------------------------------------------------===//
// func::FuncOp Conversion Patterns
//===----------------------------------------------------------------------===//
namespace {
/// A pattern for rewriting function signature to convert vector arguments of
/// functions to be of valid types
struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(func::FuncOp funcOp,
PatternRewriter &rewriter) const override {
FunctionType fnType = funcOp.getFunctionType();
// TODO: Handle declarations.
if (funcOp.isDeclaration()) {
LLVM_DEBUG(llvm::dbgs()
<< fnType << " illegal: declarations are unsupported\n");
return failure();
}
// Create a new func op with the original type and copy the function body.
auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(),
funcOp.getName(), fnType);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
Location loc = newFuncOp.getBody().getLoc();
Block &entryBlock = newFuncOp.getBlocks().front();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&entryBlock);
OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
// For arguments that are of illegal types and require unrolling.
// `unrolledInputNums` stores the indices of arguments that result from
// unrolling in the new function signature. `newInputNo` is a counter.
SmallVector<size_t> unrolledInputNums;
size_t newInputNo = 0;
// For arguments that are of legal types and do not require unrolling.
// `tmpOps` stores a mapping from temporary operations that serve as
// placeholders for new arguments that will be added later. These operations
// will be erased once the entry block's argument list is updated.
llvm::SmallDenseMap<Operation *, size_t> tmpOps;
// This counts the number of new operations created.
size_t newOpCount = 0;
// Enumerate through the arguments.
for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
// Check whether the argument is of vector type.
auto origVecType = dyn_cast<VectorType>(origType);
if (!origVecType) {
// We need a placeholder for the old argument that will be erased later.
Value result = rewriter.create<arith::ConstantOp>(
loc, origType, rewriter.getZeroAttr(origType));
rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
tmpOps.insert({result.getDefiningOp(), newInputNo});
oneToNTypeMapping.addInputs(origInputNo, origType);
++newInputNo;
++newOpCount;
continue;
}
// Check whether the vector needs unrolling.
auto targetShape = getTargetShape(origVecType);
if (!targetShape) {
// We need a placeholder for the old argument that will be erased later.
Value result = rewriter.create<arith::ConstantOp>(
loc, origType, rewriter.getZeroAttr(origType));
rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
tmpOps.insert({result.getDefiningOp(), newInputNo});
oneToNTypeMapping.addInputs(origInputNo, origType);
++newInputNo;
++newOpCount;
continue;
}
VectorType unrolledType =
VectorType::get(*targetShape, origVecType.getElementType());
auto originalShape =
llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
// Prepare the result vector.
Value result = rewriter.create<arith::ConstantOp>(
loc, origVecType, rewriter.getZeroAttr(origVecType));
++newOpCount;
// Prepare the placeholder for the new arguments that will be added later.
Value dummy = rewriter.create<arith::ConstantOp>(
loc, unrolledType, rewriter.getZeroAttr(unrolledType));
++newOpCount;
// Create the `vector.insert_strided_slice` ops.
SmallVector<int64_t> strides(targetShape->size(), 1);
SmallVector<Type> newTypes;
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalShape, *targetShape)) {
result = rewriter.create<vector::InsertStridedSliceOp>(
loc, dummy, result, offsets, strides);
newTypes.push_back(unrolledType);
unrolledInputNums.push_back(newInputNo);
++newInputNo;
++newOpCount;
}
rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
oneToNTypeMapping.addInputs(origInputNo, newTypes);
}
// Change the function signature.
auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
rewriter.modifyOpInPlace(newFuncOp,
[&] { newFuncOp.setFunctionType(newFnType); });
// Update the arguments in the entry block.
entryBlock.eraseArguments(0, fnType.getNumInputs());
SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
entryBlock.addArguments(convertedTypes, locs);
// Replace the placeholder values with the new arguments. We assume there is
// only one block for now.
size_t unrolledInputIdx = 0;
for (auto [count, op] : enumerate(entryBlock.getOperations())) {
// We first look for operands that are placeholders for initially legal
// arguments.
Operation &curOp = op;
for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
Operation *operandOp = operandVal.getDefiningOp();
if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
size_t idx = operandIdx;
rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
curOp.setOperand(idx, newFuncOp.getArgument(it->second));
});
}
}
// Since all newly created operations are in the beginning, reaching the
// end of them means that any later `vector.insert_strided_slice` should
// not be touched.
if (count >= newOpCount)
continue;
if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
rewriter.modifyOpInPlace(&curOp, [&] {
curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
});
++unrolledInputIdx;
}
}
// Erase the original funcOp. The `tmpOps` do not need to be erased since
// they have no uses and will be handled by dead-code elimination.
rewriter.eraseOp(funcOp);
return success();
}
};
} // namespace
void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
patterns.add<FuncOpVectorUnroll>(patterns.getContext());
}
//===----------------------------------------------------------------------===//
// func::ReturnOp Conversion Patterns
//===----------------------------------------------------------------------===//
namespace {
/// A pattern for rewriting function signature and the return op to convert
/// vectors to be of valid types.
struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(func::ReturnOp returnOp,
PatternRewriter &rewriter) const override {
// Check whether the parent funcOp is valid.
auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
if (!funcOp)
return failure();
FunctionType fnType = funcOp.getFunctionType();
OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
Location loc = returnOp.getLoc();
// For the new return op.
SmallVector<Value> newOperands;
// Enumerate through the results.
for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
// Check whether the argument is of vector type.
auto origVecType = dyn_cast<VectorType>(origType);
if (!origVecType) {
oneToNTypeMapping.addInputs(origResultNo, origType);
newOperands.push_back(returnOp.getOperand(origResultNo));
continue;
}
// Check whether the vector needs unrolling.
auto targetShape = getTargetShape(origVecType);
if (!targetShape) {
// The original argument can be used.
oneToNTypeMapping.addInputs(origResultNo, origType);
newOperands.push_back(returnOp.getOperand(origResultNo));
continue;
}
VectorType unrolledType =
VectorType::get(*targetShape, origVecType.getElementType());
// Create `vector.extract_strided_slice` ops to form legal vectors from
// the original operand of illegal type.
auto originalShape =
llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
SmallVector<int64_t> strides(targetShape->size(), 1);
SmallVector<Type> newTypes;
Value returnValue = returnOp.getOperand(origResultNo);
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalShape, *targetShape)) {
Value result = rewriter.create<vector::ExtractStridedSliceOp>(
loc, returnValue, offsets, *targetShape, strides);
newOperands.push_back(result);
newTypes.push_back(unrolledType);
}
oneToNTypeMapping.addInputs(origResultNo, newTypes);
}
// Change the function signature.
auto newFnType =
FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
TypeRange(oneToNTypeMapping.getConvertedTypes()));
rewriter.modifyOpInPlace(funcOp,
[&] { funcOp.setFunctionType(newFnType); });
// Replace the return op using the new operands. This will automatically
// update the entry block as well.
rewriter.replaceOp(returnOp,
rewriter.create<func::ReturnOp>(loc, newOperands));
return success();
}
};
} // namespace
void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
}
//===----------------------------------------------------------------------===//
// Builtin Variables
//===----------------------------------------------------------------------===//


@@ -1,4 +1,4 @@
// RUN: mlir-opt -convert-to-spirv -split-input-file %s | FileCheck %s
// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" -split-input-file %s | FileCheck %s
//===----------------------------------------------------------------------===//
// arithmetic ops


@@ -1,4 +1,4 @@
// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
// CHECK-LABEL: @combined
// CHECK: %[[C0_F32:.*]] = spirv.Constant 0.000000e+00 : f32


@@ -0,0 +1,147 @@
// RUN: mlir-opt -test-spirv-func-signature-conversion -split-input-file %s | FileCheck %s
// CHECK-LABEL: @simple_scalar
// CHECK-SAME: (%[[ARG0:.+]]: i32)
func.func @simple_scalar(%arg0 : i32) -> i32 {
// CHECK: return %[[ARG0]] : i32
return %arg0 : i32
}
// -----
// CHECK-LABEL: @simple_vector_4
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>)
func.func @simple_vector_4(%arg0 : vector<4xi32>) -> vector<4xi32> {
// CHECK: return %[[ARG0]] : vector<4xi32>
return %arg0 : vector<4xi32>
}
// -----
// CHECK-LABEL: @simple_vector_5
// CHECK-SAME: (%[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: vector<1xi32>, %[[ARG2:.+]]: vector<1xi32>, %[[ARG3:.+]]: vector<1xi32>, %[[ARG4:.+]]: vector<1xi32>)
func.func @simple_vector_5(%arg0 : vector<5xi32>) -> vector<5xi32> {
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<5xi32>
// CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<1xi32> into vector<5xi32>
// CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<5xi32>
// CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [2], strides = [1]} : vector<1xi32> into vector<5xi32>
// CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [3], strides = [1]} : vector<1xi32> into vector<5xi32>
// CHECK: %[[INSERT4:.*]] = vector.insert_strided_slice %[[ARG4]], %[[INSERT3]] {offsets = [4], strides = [1]} : vector<1xi32> into vector<5xi32>
// CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [0], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
// CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [1], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
// CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [2], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
// CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [3], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
// CHECK: %[[EXTRACT4:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [4], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
// CHECK: return %[[EXTRACT0]], %[[EXTRACT1]], %[[EXTRACT2]], %[[EXTRACT3]], %[[EXTRACT4]] : vector<1xi32>, vector<1xi32>, vector<1xi32>, vector<1xi32>, vector<1xi32>
return %arg0 : vector<5xi32>
}
// -----
// CHECK-LABEL: @simple_vector_6
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>)
func.func @simple_vector_6(%arg0 : vector<6xi32>) -> vector<6xi32> {
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<6xi32>
// CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<3xi32> into vector<6xi32>
// CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [3], strides = [1]} : vector<3xi32> into vector<6xi32>
// CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
// CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
// CHECK: return %[[EXTRACT0]], %[[EXTRACT1]] : vector<3xi32>, vector<3xi32>
return %arg0 : vector<6xi32>
}
// -----
// CHECK-LABEL: @simple_vector_8
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>)
func.func @simple_vector_8(%arg0 : vector<8xi32>) -> vector<8xi32> {
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
// CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
// CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
// CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
// CHECK: return %[[EXTRACT0]], %[[EXTRACT1]] : vector<4xi32>, vector<4xi32>
return %arg0 : vector<8xi32>
}
// -----
// CHECK-LABEL: @vector_6and8
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
func.func @vector_6and8(%arg0 : vector<6xi32>, %arg1 : vector<8xi32>) -> (vector<6xi32>, vector<8xi32>) {
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[CST0:.*]] = arith.constant dense<0> : vector<6xi32>
// CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<3xi32> into vector<6xi32>
// CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [3], strides = [1]} : vector<3xi32> into vector<6xi32>
// CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
// CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
// CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
// CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32>
// CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
// CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
// CHECK: return %[[EXTRACT0]], %[[EXTRACT1]], %[[EXTRACT2]], %[[EXTRACT3]] : vector<3xi32>, vector<3xi32>, vector<4xi32>, vector<4xi32>
return %arg0, %arg1 : vector<6xi32>, vector<8xi32>
}
// -----
// CHECK-LABEL: @vector_3and8
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>)
func.func @vector_3and8(%arg0 : vector<3xi32>, %arg1 : vector<8xi32>) -> (vector<3xi32>, vector<8xi32>) {
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG1]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
// CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
// CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
// CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
// CHECK: return %[[ARG0]], %[[EXTRACT0]], %[[EXTRACT1]] : vector<3xi32>, vector<4xi32>, vector<4xi32>
return %arg0, %arg1 : vector<3xi32>, vector<8xi32>
}
// -----
// CHECK-LABEL: @scalar_vector
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: i32)
func.func @scalar_vector(%arg0 : vector<8xi32>, %arg1 : vector<3xi32>, %arg2 : i32) -> (vector<8xi32>, vector<3xi32>, i32) {
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
// CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
// CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
// CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32>
// CHECK: return %[[EXTRACT0]], %[[EXTRACT1]], %[[ARG2]], %[[ARG3]] : vector<4xi32>, vector<4xi32>, vector<3xi32>, i32
return %arg0, %arg1, %arg2 : vector<8xi32>, vector<3xi32>, i32
}
// -----
// CHECK-LABEL: @reduction
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>, %[[ARG4:.+]]: i32)
func.func @reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : i32) -> (i32) {
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
// CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
// CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
// CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
// CHECK: %[[ADDI:.*]] = arith.addi %[[INSERT1]], %[[INSERT3]] : vector<8xi32>
// CHECK: %[[REDUCTION:.*]] = vector.reduction <add>, %[[ADDI]] : vector<8xi32> into i32
// CHECK: %[[RET:.*]] = arith.addi %[[REDUCTION]], %[[ARG4]] : i32
// CHECK: return %[[RET]] : i32
%0 = arith.addi %arg0, %arg1 : vector<8xi32>
%1 = vector.reduction <add>, %0 : vector<8xi32> into i32
%2 = arith.addi %1, %arg2 : i32
return %2 : i32
}
// -----
// CHECK-LABEL: func.func private @unsupported_decl(vector<8xi32>)
func.func private @unsupported_decl(vector<8xi32>)
// -----
// CHECK-LABEL: @unsupported_scalable
// CHECK-SAME: (%[[ARG0:.+]]: vector<[8]xi32>)
func.func @unsupported_scalable(%arg0 : vector<[8]xi32>) -> (vector<[8]xi32>) {
// CHECK: return %[[ARG0]] : vector<[8]xi32>
return %arg0 : vector<[8]xi32>
}


@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-to-spirv | FileCheck %s
// RUN: mlir-opt %s -convert-to-spirv="run-signature-conversion=false" | FileCheck %s
// CHECK-LABEL: @basic
func.func @basic(%a: index, %b: index) {


@@ -1,4 +1,4 @@
// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
// CHECK-LABEL: @if_yield
// CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<f32, Function>


@@ -1,4 +1,4 @@
// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
// CHECK-LABEL: @return_scalar
// CHECK-SAME: %[[ARG0:.*]]: i32


@@ -1,4 +1,4 @@
// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
// CHECK-LABEL: @ub
// CHECK: %[[UNDEF:.*]] = spirv.Undef : i32


@@ -1,4 +1,4 @@
// RUN: mlir-opt -split-input-file -convert-to-spirv %s | FileCheck %s
// RUN: mlir-opt -split-input-file -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
// CHECK-LABEL: @extract
// CHECK-SAME: %[[ARG:.+]]: vector<2xf32>


@@ -1,3 +1,4 @@
add_subdirectory(ConvertToSPIRV)
add_subdirectory(FuncToLLVM)
add_subdirectory(MathToVCIX)
add_subdirectory(OneToNTypeConversion)


@@ -0,0 +1,16 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestConvertToSPIRV
TestSPIRVFuncSignatureConversion.cpp
EXCLUDE_FROM_LIBMLIR
LINK_LIBS PUBLIC
MLIRArithDialect
MLIRFuncDialect
MLIRPass
MLIRSPIRVConversion
MLIRSPIRVDialect
MLIRTransformUtils
MLIRTransforms
MLIRVectorDialect
)


@@ -0,0 +1,57 @@
//===- TestSPIRVFuncSignatureConversion.cpp - Test signature conversion -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===-------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace {
struct TestSPIRVFuncSignatureConversion final
: PassWrapper<TestSPIRVFuncSignatureConversion, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSPIRVFuncSignatureConversion)
StringRef getArgument() const final {
return "test-spirv-func-signature-conversion";
}
StringRef getDescription() const final {
return "Test patterns that convert vector inputs and results in function "
"signatures";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect, func::FuncDialect, spirv::SPIRVDialect,
vector::VectorDialect>();
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateFuncOpVectorRewritePatterns(patterns);
populateReturnOpVectorRewritePatterns(patterns);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config);
}
};
} // namespace
namespace test {
void registerTestSPIRVFuncSignatureConversion() {
PassRegistration<TestSPIRVFuncSignatureConversion>();
}
} // namespace test
} // namespace mlir


@@ -36,6 +36,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRSPIRVTestPasses
MLIRTensorTestPasses
MLIRTestAnalysis
MLIRTestConvertToSPIRV
MLIRTestDialect
MLIRTestDynDialect
MLIRTestIR


@@ -141,6 +141,7 @@ void registerTestSCFWhileOpBuilderPass();
void registerTestSCFWrapInZeroTripCheckPasses();
void registerTestShapeMappingPass();
void registerTestSliceAnalysisPass();
void registerTestSPIRVFuncSignatureConversion();
void registerTestTensorCopyInsertionPass();
void registerTestTensorTransforms();
void registerTestTopologicalSortAnalysisPass();
@@ -273,6 +274,7 @@ void registerTestPasses() {
mlir::test::registerTestSCFWrapInZeroTripCheckPasses();
mlir::test::registerTestShapeMappingPass();
mlir::test::registerTestSliceAnalysisPass();
mlir::test::registerTestSPIRVFuncSignatureConversion();
mlir::test::registerTestTensorCopyInsertionPass();
mlir::test::registerTestTensorTransforms();
mlir::test::registerTestTopologicalSortAnalysisPass();


@@ -7207,10 +7207,15 @@ cc_library(
hdrs = ["include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"],
includes = ["include"],
deps = [
":ArithDialect",
":DialectUtils",
":FuncDialect",
":IR",
":SPIRVDialect",
":Support",
":TransformUtils",
":VectorDialect",
":VectorTransforms",
"//llvm:Support",
],
)
@@ -9588,6 +9593,7 @@ cc_binary(
"//mlir/test:TestArmSME",
"//mlir/test:TestBufferization",
"//mlir/test:TestControlFlow",
"//mlir/test:TestConvertToSPIRV",
"//mlir/test:TestDLTI",
"//mlir/test:TestDialect",
"//mlir/test:TestFunc",


@@ -656,6 +656,21 @@ cc_library(
],
)
cc_library(
name = "TestConvertToSPIRV",
srcs = glob(["lib/Conversion/ConvertToSPIRV/*.cpp"]),
deps = [
"//mlir:ArithDialect",
"//mlir:FuncDialect",
"//mlir:Pass",
"//mlir:SPIRVConversion",
"//mlir:SPIRVDialect",
"//mlir:TransformUtils",
"//mlir:Transforms",
"//mlir:VectorDialect",
],
)
cc_library(
name = "TestAffine",
srcs = glob([