[mlir][LLVM] Delete getFixedVectorType and getScalableVectorType (#135051)

The LLVM dialect no longer has its own vector types. It uses
`mlir::VectorType` everywhere. Remove
`LLVM::getFixedVectorType/getScalableVectorType` and use
`VectorType::get` instead. This commit addresses a
[comment](https://github.com/llvm/llvm-project/pull/133286#discussion_r2022192500)
on the PR that deleted the LLVM vector types.
This commit is contained in:
Matthias Springer 2025-04-10 10:36:21 +02:00 committed by GitHub
parent 923da2b843
commit 85742f7642
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 39 additions and 64 deletions

View File

@ -336,10 +336,6 @@ compatible with the LLVM dialect:
vector type compatible with the LLVM dialect;
- `llvm::ElementCount LLVM::getVectorNumElements(Type)` - returns the number
of elements in any vector type compatible with the LLVM dialect;
- `Type LLVM::getFixedVectorType(Type, unsigned)` - gets a fixed vector type
with the given element type and size; the resulting type is either a
built-in or an LLVM dialect vector type depending on which one supports the
given element type.
#### Examples of Compatible Vector Types

View File

@ -126,14 +126,6 @@ Type getVectorType(Type elementType, unsigned numElements,
/// and length.
Type getVectorType(Type elementType, const llvm::ElementCount &numElements);
/// Creates an LLVM dialect-compatible type with the given element type and
/// length.
Type getFixedVectorType(Type elementType, unsigned numElements);
/// Creates an LLVM dialect-compatible type with the given element type and
/// length.
Type getScalableVectorType(Type elementType, unsigned numElements);
/// Returns the size of the given primitive LLVM dialect-compatible type
/// (including vectors) in bits, for example, the size of i16 is 16 and
/// the size of vector<4xi16> is 64. Returns 0 for non-primitive

View File

@ -61,13 +61,13 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
static Type inferIntrinsicResultType(Type vectorResultType) {
MLIRContext *ctx = vectorResultType.getContext();
auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
auto i32Ty = IntegerType::get(ctx, 32);
auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
auto i32x2Ty = VectorType::get(2, i32Ty);
Type f64Ty = Float64Type::get(ctx);
Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
Type f64x2Ty = VectorType::get(2, f64Ty);
Type f32Ty = Float32Type::get(ctx);
Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
Type f32x2Ty = VectorType::get(2, f32Ty);
if (a.getElementType() == f16x2Ty) {
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
@ -85,7 +85,7 @@ static Type inferIntrinsicResultType(Type vectorResultType) {
ctx,
SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
}
if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
if (a.getElementType() == VectorType::get(1, f32Ty)) {
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
}
@ -106,11 +106,11 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
Type i32Ty = rewriter.getI32Type();
Type f32Ty = rewriter.getF32Type();
Type f64Ty = rewriter.getF64Type();
Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
Type i32x2Ty = VectorType::get(2, i32Ty);
Type f64x2Ty = VectorType::get(2, f64Ty);
Type f32x2Ty = VectorType::get(2, f32Ty);
Type f32x1Ty = VectorType::get(1, f32Ty);
auto makeConst = [&](int32_t index) -> Value {
return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
@ -181,9 +181,9 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
Type f64Ty = b.getF64Type();
Type f32Ty = b.getF32Type();
Type i64Ty = b.getI64Type();
Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
Type i8x4Ty = VectorType::get(4, b.getI8Type());
Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
Type f32x1Ty = VectorType::get(1, f32Ty);
auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
@ -268,8 +268,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
if (!vectorResultType) {
return failure();
}
Type innerVectorType = LLVM::getFixedVectorType(
vectorResultType.getElementType(), vectorResultType.getDimSize(1));
Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
vectorResultType.getElementType());
int64_t num32BitRegs = vectorResultType.getDimSize(0);
@ -627,8 +627,7 @@ struct NVGPUMmaSparseSyncLowering
// Bitcast the sparse metadata from vector<2xf16> to an i32.
Value sparseMetadata = adaptor.getSparseMetadata();
if (sparseMetadata.getType() !=
LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
return op->emitOpError() << "Expected metadata type to be LLVM "
"VectorType of 2 i16 elements";
sparseMetadata =

View File

@ -851,18 +851,6 @@ Type mlir::LLVM::getVectorType(Type elementType,
/*isScalable=*/false);
}
Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
assert(VectorType::isValidElementType(elementType) &&
"incompatible element type");
return VectorType::get(numElements, elementType);
}
Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
// LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
// scalable/non-scalable.
return VectorType::get(numElements, elementType, /*scalableDims=*/true);
}
llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
assert(isCompatibleType(type) &&
"expected a type compatible with the LLVM dialect");

View File

@ -144,7 +144,7 @@ LogicalResult BulkStoreOp::verify() {
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
auto half2Type =
LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
VectorType::get(2, Float16Type::get(operandElType.getContext()));
if (operandElType.isF64())
return NVVM::MMATypes::f64;
if (operandElType.isF16() || operandElType == half2Type)
@ -243,7 +243,8 @@ void MmaOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
// Print the types of the operands and result.
p << " : " << "(";
p << " : "
<< "(";
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
frags[1].regs[0].getType(),
frags[2].regs[0].getType()},
@ -404,7 +405,7 @@ LogicalResult MmaOp::verify() {
MLIRContext *context = getContext();
auto f16Ty = Float16Type::get(context);
auto i32Ty = IntegerType::get(context, 32);
auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
auto f16x2Ty = VectorType::get(2, f16Ty);
auto f32Ty = Float32Type::get(context);
auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
@ -506,7 +507,7 @@ LogicalResult MmaOp::verify() {
expectedA.emplace_back(1, f64Ty);
expectedB.emplace_back(1, f64Ty);
expectedC.emplace_back(2, f64Ty);
// expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
// expectedC.emplace_back(1, VectorType::get(2, f64Ty));
expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
context, SmallVector<Type>(2, f64Ty)));
allowedShapes.push_back({8, 8, 4});
@ -992,7 +993,9 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
ss << "},";
// Need to map read/write registers correctly.
regCnt = (regCnt * 2);
ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
ss << " $" << (regCnt) << ","
<< " $" << (regCnt + 1) << ","
<< " p";
if (getTypeD() != WGMMATypes::s32) {
ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
}

View File

@ -103,16 +103,15 @@ nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
Type elType = type.vectorType.getElementType();
if (elType.isF16()) {
return FragmentElementInfo{
LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
inferNumRegistersPerMatrixFragment(type)};
return FragmentElementInfo{VectorType::get(2, Float16Type::get(ctx)), 2, 32,
inferNumRegistersPerMatrixFragment(type)};
}
// f64 operand
Type f64Ty = Float64Type::get(ctx);
if (elType.isF64()) {
return isAccum
? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
? FragmentElementInfo{VectorType::get(2, f64Ty), 2, 128,
inferNumRegistersPerMatrixFragment(type)}
: FragmentElementInfo{f64Ty, 1, 64,
inferNumRegistersPerMatrixFragment(type)};
@ -120,30 +119,27 @@ nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
// int8 operand
if (elType.isInteger(8)) {
return FragmentElementInfo{
LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
inferNumRegistersPerMatrixFragment(type)};
return FragmentElementInfo{VectorType::get(4, IntegerType::get(ctx, 8)), 4,
32, inferNumRegistersPerMatrixFragment(type)};
}
// int4 operand
if (elType.isInteger(4)) {
return FragmentElementInfo{
LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
inferNumRegistersPerMatrixFragment(type)};
return FragmentElementInfo{VectorType::get(8, IntegerType::get(ctx, 4)), 8,
32, inferNumRegistersPerMatrixFragment(type)};
}
// Integer 32bit acc operands
if (elType.isInteger(32)) {
return FragmentElementInfo{
LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
inferNumRegistersPerMatrixFragment(type)};
return FragmentElementInfo{VectorType::get(2, IntegerType::get(ctx, 32)), 2,
64, inferNumRegistersPerMatrixFragment(type)};
}
// Floating point 32bit operands
if (elType.isF32()) {
Type f32Ty = Float32Type::get(ctx);
return isAccum
? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
? FragmentElementInfo{VectorType::get(2, f32Ty), 2, 64,
inferNumRegistersPerMatrixFragment(type)}
: FragmentElementInfo{f32Ty, 1, 32,
inferNumRegistersPerMatrixFragment(type)};

View File

@ -124,14 +124,15 @@ private:
/// Translates the given fixed-vector type.
Type translate(llvm::FixedVectorType *type) {
return LLVM::getFixedVectorType(translateType(type->getElementType()),
type->getNumElements());
return VectorType::get(type->getNumElements(),
translateType(type->getElementType()));
}
/// Translates the given scalable-vector type.
Type translate(llvm::ScalableVectorType *type) {
return LLVM::getScalableVectorType(translateType(type->getElementType()),
type->getMinNumElements());
return VectorType::get(type->getMinNumElements(),
translateType(type->getElementType()),
/*scalable=*/true);
}
/// Translates the given target extension type.