From 85742f764270c701d2245615c590702c5110b030 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me@m-sp.org>
Date: Thu, 10 Apr 2025 10:36:21 +0200
Subject: [PATCH] [mlir][LLVM] Delete `getFixedVectorType` and
 `getScalableVectorType` (#135051)

The LLVM dialect no longer has its own vector types. It uses
`mlir::VectorType` everywhere. Remove
`LLVM::getFixedVectorType/getScalableVectorType` and use
`VectorType::get` instead. This commit addresses a
[comment](https://github.com/llvm/llvm-project/pull/133286#discussion_r2022192500)
on the PR that deleted the LLVM vector types.
---
 mlir/docs/Dialects/LLVM.md                    |  4 ---
 mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h  |  8 -----
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    | 33 +++++++++----------
 mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp      | 12 -------
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 13 +++++---
 mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp     | 24 ++++++--------
 mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp       |  9 ++---
 7 files changed, 39 insertions(+), 64 deletions(-)

diff --git a/mlir/docs/Dialects/LLVM.md b/mlir/docs/Dialects/LLVM.md
index 468f69c41907..4b5d518ca4ea 100644
--- a/mlir/docs/Dialects/LLVM.md
+++ b/mlir/docs/Dialects/LLVM.md
@@ -336,10 +336,6 @@ compatible with the LLVM dialect:
     vector type compatible with the LLVM dialect;
 -   `llvm::ElementCount LLVM::getVectorNumElements(Type)` - returns the number
     of elements in any vector type compatible with the LLVM dialect;
--   `Type LLVM::getFixedVectorType(Type, unsigned)` - gets a fixed vector type
-    with the given element type and size; the resulting type is either a
-    built-in or an LLVM dialect vector type depending on which one supports the
-    given element type.
 
 #### Examples of Compatible Vector Types
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index a2a76c49a2bd..17561f79d135 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -126,14 +126,6 @@ Type getVectorType(Type elementType, unsigned numElements,
 /// and length.
 Type getVectorType(Type elementType, const llvm::ElementCount &numElements);
 
-/// Creates an LLVM dialect-compatible type with the given element type and
-/// length.
-Type getFixedVectorType(Type elementType, unsigned numElements);
-
-/// Creates an LLVM dialect-compatible type with the given element type and
-/// length.
-Type getScalableVectorType(Type elementType, unsigned numElements);
-
 /// Returns the size of the given primitive LLVM dialect-compatible type
 /// (including vectors) in bits, for example, the size of i16 is 16 and
 /// the size of vector<4xi16> is 64. Returns 0 for non-primitive
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 51507c6507b6..69fa62c8196e 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -61,13 +61,13 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
 static Type inferIntrinsicResultType(Type vectorResultType) {
   MLIRContext *ctx = vectorResultType.getContext();
   auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
-  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
+  auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
   auto i32Ty = IntegerType::get(ctx, 32);
-  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
+  auto i32x2Ty = VectorType::get(2, i32Ty);
   Type f64Ty = Float64Type::get(ctx);
-  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+  Type f64x2Ty = VectorType::get(2, f64Ty);
   Type f32Ty = Float32Type::get(ctx);
-  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
+  Type f32x2Ty = VectorType::get(2, f32Ty);
   if (a.getElementType() == f16x2Ty) {
     return LLVM::LLVMStructType::getLiteral(
         ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
@@ -85,7 +85,7 @@ static Type inferIntrinsicResultType(Type vectorResultType) {
         ctx,
         SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
   }
-  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
+  if (a.getElementType() == VectorType::get(1, f32Ty)) {
     return LLVM::LLVMStructType::getLiteral(
         ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
   }
@@ -106,11 +106,11 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
   Type i32Ty = rewriter.getI32Type();
   Type f32Ty = rewriter.getF32Type();
   Type f64Ty = rewriter.getF64Type();
-  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
-  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
-  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
-  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
-  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
+  Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
+  Type i32x2Ty = VectorType::get(2, i32Ty);
+  Type f64x2Ty = VectorType::get(2, f64Ty);
+  Type f32x2Ty = VectorType::get(2, f32Ty);
+  Type f32x1Ty = VectorType::get(1, f32Ty);
 
   auto makeConst = [&](int32_t index) -> Value {
     return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
@@ -181,9 +181,9 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
   Type f64Ty = b.getF64Type();
   Type f32Ty = b.getF32Type();
   Type i64Ty = b.getI64Type();
-  Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
-  Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
-  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
+  Type i8x4Ty = VectorType::get(4, b.getI8Type());
+  Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
+  Type f32x1Ty = VectorType::get(1, f32Ty);
   auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
 
   for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
@@ -268,8 +268,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
     if (!vectorResultType) {
       return failure();
     }
-    Type innerVectorType = LLVM::getFixedVectorType(
-        vectorResultType.getElementType(), vectorResultType.getDimSize(1));
+    Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
+                                           vectorResultType.getElementType());
 
     int64_t num32BitRegs = vectorResultType.getDimSize(0);
 
@@ -627,8 +627,7 @@ struct NVGPUMmaSparseSyncLowering
 
     // Bitcast the sparse metadata from vector<2xf16> to an i32.
     Value sparseMetadata = adaptor.getSparseMetadata();
-    if (sparseMetadata.getType() !=
-        LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
+    if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
       return op->emitOpError() << "Expected metadata type to be LLVM "
                                   "VectorType of 2 i16 elements";
     sparseMetadata =
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index b3c2a2930952..29cf38c1fefe 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -851,18 +851,6 @@ Type mlir::LLVM::getVectorType(Type elementType,
                        /*isScalable=*/false);
 }
 
-Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
-  assert(VectorType::isValidElementType(elementType) &&
-         "incompatible element type");
-  return VectorType::get(numElements, elementType);
-}
-
-Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
-  // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
-  // scalable/non-scalable.
-  return VectorType::get(numElements, elementType, /*scalableDims=*/true);
-}
-
 llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
   assert(isCompatibleType(type) &&
          "expected a type compatible with the LLVM dialect");
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 09bff6101edd..593283f14696 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -144,7 +144,7 @@ LogicalResult BulkStoreOp::verify() {
 std::optional<mlir::NVVM::MMATypes>
 MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
   auto half2Type =
-      LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
+      VectorType::get(2, Float16Type::get(operandElType.getContext()));
   if (operandElType.isF64())
     return NVVM::MMATypes::f64;
   if (operandElType.isF16() || operandElType == half2Type)
@@ -243,7 +243,8 @@ void MmaOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
 
   // Print the types of the operands and result.
-  p << " : " << "(";
+  p << " : "
+    << "(";
   llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                              frags[1].regs[0].getType(),
                                              frags[2].regs[0].getType()},
@@ -404,7 +405,7 @@ LogicalResult MmaOp::verify() {
   MLIRContext *context = getContext();
   auto f16Ty = Float16Type::get(context);
   auto i32Ty = IntegerType::get(context, 32);
-  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
+  auto f16x2Ty = VectorType::get(2, f16Ty);
   auto f32Ty = Float32Type::get(context);
   auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
       context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
@@ -506,7 +507,7 @@ LogicalResult MmaOp::verify() {
       expectedA.emplace_back(1, f64Ty);
       expectedB.emplace_back(1, f64Ty);
       expectedC.emplace_back(2, f64Ty);
-      // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
+      // expectedC.emplace_back(1, VectorType::get(2, f64Ty));
       expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
           context, SmallVector<Type>(2, f64Ty)));
       allowedShapes.push_back({8, 8, 4});
@@ -992,7 +993,9 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
   ss << "},";
   // Need to map read/write registers correctly.
   regCnt = (regCnt * 2);
-  ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
+  ss << " $" << (regCnt) << ","
+     << " $" << (regCnt + 1) << ","
+     << " p";
   if (getTypeD() != WGMMATypes::s32) {
     ss << ", $" << (regCnt + 3) << ",  $" << (regCnt + 4);
   }
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index 39cca7d363e0..e80360aa08ed 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -103,16 +103,15 @@ nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
 
   Type elType = type.vectorType.getElementType();
   if (elType.isF16()) {
-    return FragmentElementInfo{
-        LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
-        inferNumRegistersPerMatrixFragment(type)};
+    return FragmentElementInfo{VectorType::get(2, Float16Type::get(ctx)), 2, 32,
+                               inferNumRegistersPerMatrixFragment(type)};
   }
 
   // f64 operand
   Type f64Ty = Float64Type::get(ctx);
   if (elType.isF64()) {
     return isAccum
-               ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
+               ? FragmentElementInfo{VectorType::get(2, f64Ty), 2, 128,
                                      inferNumRegistersPerMatrixFragment(type)}
                : FragmentElementInfo{f64Ty, 1, 64,
                                      inferNumRegistersPerMatrixFragment(type)};
@@ -120,30 +119,27 @@ nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
 
   // int8 operand
   if (elType.isInteger(8)) {
-    return FragmentElementInfo{
-        LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
-        inferNumRegistersPerMatrixFragment(type)};
+    return FragmentElementInfo{VectorType::get(4, IntegerType::get(ctx, 8)), 4,
+                               32, inferNumRegistersPerMatrixFragment(type)};
   }
 
   // int4 operand
   if (elType.isInteger(4)) {
-    return FragmentElementInfo{
-        LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
-        inferNumRegistersPerMatrixFragment(type)};
+    return FragmentElementInfo{VectorType::get(8, IntegerType::get(ctx, 4)), 8,
+                               32, inferNumRegistersPerMatrixFragment(type)};
   }
 
   // Integer 32bit acc operands
   if (elType.isInteger(32)) {
-    return FragmentElementInfo{
-        LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
-        inferNumRegistersPerMatrixFragment(type)};
+    return FragmentElementInfo{VectorType::get(2, IntegerType::get(ctx, 32)), 2,
+                               64, inferNumRegistersPerMatrixFragment(type)};
   }
 
   // Floating point 32bit operands
   if (elType.isF32()) {
     Type f32Ty = Float32Type::get(ctx);
     return isAccum
-               ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
+               ? FragmentElementInfo{VectorType::get(2, f32Ty), 2, 64,
                                      inferNumRegistersPerMatrixFragment(type)}
                : FragmentElementInfo{f32Ty, 1, 32,
                                      inferNumRegistersPerMatrixFragment(type)};
diff --git a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
index bc9765fff295..c46aa3e80d51 100644
--- a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
@@ -124,14 +124,15 @@ private:
 
   /// Translates the given fixed-vector type.
   Type translate(llvm::FixedVectorType *type) {
-    return LLVM::getFixedVectorType(translateType(type->getElementType()),
-                                    type->getNumElements());
+    return VectorType::get(type->getNumElements(),
+                           translateType(type->getElementType()));
   }
 
   /// Translates the given scalable-vector type.
   Type translate(llvm::ScalableVectorType *type) {
-    return LLVM::getScalableVectorType(translateType(type->getElementType()),
-                                       type->getMinNumElements());
+    return VectorType::get(type->getMinNumElements(),
+                           translateType(type->getElementType()),
+                           /*scalable=*/true);
   }
 
   /// Translates the given target extension type.