From 85742f764270c701d2245615c590702c5110b030 Mon Sep 17 00:00:00 2001 From: Matthias Springer <me@m-sp.org> Date: Thu, 10 Apr 2025 10:36:21 +0200 Subject: [PATCH] [mlir][LLVM] Delete `getFixedVectorType` and `getScalableVectorType` (#135051) The LLVM dialect no longer has its own vector types. It uses `mlir::VectorType` everywhere. Remove `LLVM::getFixedVectorType/getScalableVectorType` and use `VectorType::get` instead. This commit addresses a [comment](https://github.com/llvm/llvm-project/pull/133286#discussion_r2022192500) on the PR that deleted the LLVM vector types. --- mlir/docs/Dialects/LLVM.md | 4 --- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h | 8 ----- .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 33 +++++++++---------- mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 12 ------- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 13 +++++--- mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp | 24 ++++++-------- mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp | 9 ++--- 7 files changed, 39 insertions(+), 64 deletions(-) diff --git a/mlir/docs/Dialects/LLVM.md b/mlir/docs/Dialects/LLVM.md index 468f69c41907..4b5d518ca4ea 100644 --- a/mlir/docs/Dialects/LLVM.md +++ b/mlir/docs/Dialects/LLVM.md @@ -336,10 +336,6 @@ compatible with the LLVM dialect: vector type compatible with the LLVM dialect; - `llvm::ElementCount LLVM::getVectorNumElements(Type)` - returns the number of elements in any vector type compatible with the LLVM dialect; -- `Type LLVM::getFixedVectorType(Type, unsigned)` - gets a fixed vector type - with the given element type and size; the resulting type is either a - built-in or an LLVM dialect vector type depending on which one supports the - given element type. 
#### Examples of Compatible Vector Types diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h index a2a76c49a2bd..17561f79d135 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -126,14 +126,6 @@ Type getVectorType(Type elementType, unsigned numElements, /// and length. Type getVectorType(Type elementType, const llvm::ElementCount &numElements); -/// Creates an LLVM dialect-compatible type with the given element type and -/// length. -Type getFixedVectorType(Type elementType, unsigned numElements); - -/// Creates an LLVM dialect-compatible type with the given element type and -/// length. -Type getScalableVectorType(Type elementType, unsigned numElements); - /// Returns the size of the given primitive LLVM dialect-compatible type /// (including vectors) in bits, for example, the size of i16 is 16 and /// the size of vector<4xi16> is 64. Returns 0 for non-primitive diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 51507c6507b6..69fa62c8196e 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -61,13 +61,13 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) { static Type inferIntrinsicResultType(Type vectorResultType) { MLIRContext *ctx = vectorResultType.getContext(); auto a = cast<LLVM::LLVMArrayType>(vectorResultType); - auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2); + auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx)); auto i32Ty = IntegerType::get(ctx, 32); - auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); + auto i32x2Ty = VectorType::get(2, i32Ty); Type f64Ty = Float64Type::get(ctx); - Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); + Type f64x2Ty = VectorType::get(2, f64Ty); Type f32Ty = Float32Type::get(ctx); - Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2); + Type 
f32x2Ty = VectorType::get(2, f32Ty); if (a.getElementType() == f16x2Ty) { return LLVM::LLVMStructType::getLiteral( ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty)); @@ -85,7 +85,7 @@ static Type inferIntrinsicResultType(Type vectorResultType) { ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty)); } - if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) { + if (a.getElementType() == VectorType::get(1, f32Ty)) { return LLVM::LLVMStructType::getLiteral( ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty)); } @@ -106,11 +106,11 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, Type i32Ty = rewriter.getI32Type(); Type f32Ty = rewriter.getF32Type(); Type f64Ty = rewriter.getF64Type(); - Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2); - Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); - Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); - Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2); - Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1); + Type f16x2Ty = VectorType::get(2, rewriter.getF16Type()); + Type i32x2Ty = VectorType::get(2, i32Ty); + Type f64x2Ty = VectorType::get(2, f64Ty); + Type f32x2Ty = VectorType::get(2, f32Ty); + Type f32x1Ty = VectorType::get(1, f32Ty); auto makeConst = [&](int32_t index) -> Value { return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32), @@ -181,9 +181,9 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b, Type f64Ty = b.getF64Type(); Type f32Ty = b.getF32Type(); Type i64Ty = b.getI64Type(); - Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4); - Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8); - Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1); + Type i8x4Ty = VectorType::get(4, b.getI8Type()); + Type i4x8Ty = VectorType::get(8, b.getIntegerType(4)); + Type f32x1Ty = VectorType::get(1, f32Ty); auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType()); for (unsigned i = 
0, e = arrayTy.getNumElements(); i < e; ++i) { @@ -268,8 +268,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> { if (!vectorResultType) { return failure(); } - Type innerVectorType = LLVM::getFixedVectorType( - vectorResultType.getElementType(), vectorResultType.getDimSize(1)); + Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1), + vectorResultType.getElementType()); int64_t num32BitRegs = vectorResultType.getDimSize(0); @@ -627,8 +627,7 @@ struct NVGPUMmaSparseSyncLowering // Bitcast the sparse metadata from vector<2xf16> to an i32. Value sparseMetadata = adaptor.getSparseMetadata(); - if (sparseMetadata.getType() != - LLVM::getFixedVectorType(rewriter.getI16Type(), 2)) + if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type())) return op->emitOpError() << "Expected metadata type to be LLVM " "VectorType of 2 i16 elements"; sparseMetadata = diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index b3c2a2930952..29cf38c1fefe 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -851,18 +851,6 @@ Type mlir::LLVM::getVectorType(Type elementType, /*isScalable=*/false); } -Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) { - assert(VectorType::isValidElementType(elementType) && - "incompatible element type"); - return VectorType::get(numElements, elementType); -} - -Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) { - // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as - // scalable/non-scalable. 
- return VectorType::get(numElements, elementType, /*scalableDims=*/true); -} - llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) { assert(isCompatibleType(type) && "expected a type compatible with the LLVM dialect"); diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 09bff6101edd..593283f14696 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -144,7 +144,7 @@ LogicalResult BulkStoreOp::verify() { std::optional<mlir::NVVM::MMATypes> MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) { auto half2Type = - LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2); + VectorType::get(2, Float16Type::get(operandElType.getContext())); if (operandElType.isF64()) return NVVM::MMATypes::f64; if (operandElType.isF16() || operandElType == half2Type) @@ -243,7 +243,8 @@ void MmaOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames); // Print the types of the operands and result. 
- p << " : " << "("; + p << " : " + << "("; llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(), frags[1].regs[0].getType(), frags[2].regs[0].getType()}, @@ -404,7 +405,7 @@ LogicalResult MmaOp::verify() { MLIRContext *context = getContext(); auto f16Ty = Float16Type::get(context); auto i32Ty = IntegerType::get(context, 32); - auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2); + auto f16x2Ty = VectorType::get(2, f16Ty); auto f32Ty = Float32Type::get(context); auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral( context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); @@ -506,7 +507,7 @@ LogicalResult MmaOp::verify() { expectedA.emplace_back(1, f64Ty); expectedB.emplace_back(1, f64Ty); expectedC.emplace_back(2, f64Ty); - // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2)); + // expectedC.emplace_back(1, VectorType::get(2, f64Ty)); expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral( context, SmallVector<Type>(2, f64Ty))); allowedShapes.push_back({8, 8, 4}); @@ -992,7 +993,9 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() { ss << "},"; // Need to map read/write registers correctly. 
regCnt = (regCnt * 2); - ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p"; + ss << " $" << (regCnt) << "," + << " $" << (regCnt + 1) << "," + << " p"; if (getTypeD() != WGMMATypes::s32) { ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4); } diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp index 39cca7d363e0..e80360aa08ed 100644 --- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp +++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp @@ -103,16 +103,15 @@ nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) { Type elType = type.vectorType.getElementType(); if (elType.isF16()) { - return FragmentElementInfo{ - LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32, - inferNumRegistersPerMatrixFragment(type)}; + return FragmentElementInfo{VectorType::get(2, Float16Type::get(ctx)), 2, 32, + inferNumRegistersPerMatrixFragment(type)}; } // f64 operand Type f64Ty = Float64Type::get(ctx); if (elType.isF64()) { return isAccum - ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128, + ? 
FragmentElementInfo{VectorType::get(2, f64Ty), 2, 128, inferNumRegistersPerMatrixFragment(type)} : FragmentElementInfo{f64Ty, 1, 64, inferNumRegistersPerMatrixFragment(type)}; @@ -120,30 +119,27 @@ nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) { // int8 operand if (elType.isInteger(8)) { - return FragmentElementInfo{ - LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32, - inferNumRegistersPerMatrixFragment(type)}; + return FragmentElementInfo{VectorType::get(4, IntegerType::get(ctx, 8)), 4, + 32, inferNumRegistersPerMatrixFragment(type)}; } // int4 operand if (elType.isInteger(4)) { - return FragmentElementInfo{ - LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32, - inferNumRegistersPerMatrixFragment(type)}; + return FragmentElementInfo{VectorType::get(8, IntegerType::get(ctx, 4)), 8, + 32, inferNumRegistersPerMatrixFragment(type)}; } // Integer 32bit acc operands if (elType.isInteger(32)) { - return FragmentElementInfo{ - LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64, - inferNumRegistersPerMatrixFragment(type)}; + return FragmentElementInfo{VectorType::get(2, IntegerType::get(ctx, 32)), 2, + 64, inferNumRegistersPerMatrixFragment(type)}; } // Floating point 32bit operands if (elType.isF32()) { Type f32Ty = Float32Type::get(ctx); return isAccum - ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64, + ? FragmentElementInfo{VectorType::get(2, f32Ty), 2, 64, inferNumRegistersPerMatrixFragment(type)} : FragmentElementInfo{f32Ty, 1, 32, inferNumRegistersPerMatrixFragment(type)}; diff --git a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp index bc9765fff295..c46aa3e80d51 100644 --- a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp +++ b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp @@ -124,14 +124,15 @@ private: /// Translates the given fixed-vector type. 
Type translate(llvm::FixedVectorType *type) { - return LLVM::getFixedVectorType(translateType(type->getElementType()), - type->getNumElements()); + return VectorType::get(type->getNumElements(), + translateType(type->getElementType())); } /// Translates the given scalable-vector type. Type translate(llvm::ScalableVectorType *type) { - return LLVM::getScalableVectorType(translateType(type->getElementType()), - type->getMinNumElements()); + return VectorType::get(type->getMinNumElements(), + translateType(type->getElementType()), + /*scalableDims=*/true); } /// Translates the given target extension type.