mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-14 15:46:32 +00:00
[mlir][LLVM] Delete getFixedVectorType
and getScalableVectorType
(#135051)
The LLVM dialect no longer has its own vector types. It uses `mlir::VectorType` everywhere. Remove `LLVM::getFixedVectorType/getScalableVectorType` and use `VectorType::get` instead. This commit addresses a [comment](https://github.com/llvm/llvm-project/pull/133286#discussion_r2022192500) on the PR that deleted the LLVM vector types.
This commit is contained in:
parent
923da2b843
commit
85742f7642
@ -336,10 +336,6 @@ compatible with the LLVM dialect:
|
||||
vector type compatible with the LLVM dialect;
|
||||
- `llvm::ElementCount LLVM::getVectorNumElements(Type)` - returns the number
|
||||
of elements in any vector type compatible with the LLVM dialect;
|
||||
- `Type LLVM::getFixedVectorType(Type, unsigned)` - gets a fixed vector type
|
||||
with the given element type and size; the resulting type is either a
|
||||
built-in or an LLVM dialect vector type depending on which one supports the
|
||||
given element type.
|
||||
|
||||
#### Examples of Compatible Vector Types
|
||||
|
||||
|
@ -126,14 +126,6 @@ Type getVectorType(Type elementType, unsigned numElements,
|
||||
/// and length.
|
||||
Type getVectorType(Type elementType, const llvm::ElementCount &numElements);
|
||||
|
||||
/// Creates an LLVM dialect-compatible type with the given element type and
|
||||
/// length.
|
||||
Type getFixedVectorType(Type elementType, unsigned numElements);
|
||||
|
||||
/// Creates an LLVM dialect-compatible type with the given element type and
|
||||
/// length.
|
||||
Type getScalableVectorType(Type elementType, unsigned numElements);
|
||||
|
||||
/// Returns the size of the given primitive LLVM dialect-compatible type
|
||||
/// (including vectors) in bits, for example, the size of i16 is 16 and
|
||||
/// the size of vector<4xi16> is 64. Returns 0 for non-primitive
|
||||
|
@ -61,13 +61,13 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
|
||||
static Type inferIntrinsicResultType(Type vectorResultType) {
|
||||
MLIRContext *ctx = vectorResultType.getContext();
|
||||
auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
|
||||
auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
|
||||
auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
|
||||
auto i32Ty = IntegerType::get(ctx, 32);
|
||||
auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
|
||||
auto i32x2Ty = VectorType::get(2, i32Ty);
|
||||
Type f64Ty = Float64Type::get(ctx);
|
||||
Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
|
||||
Type f64x2Ty = VectorType::get(2, f64Ty);
|
||||
Type f32Ty = Float32Type::get(ctx);
|
||||
Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
|
||||
Type f32x2Ty = VectorType::get(2, f32Ty);
|
||||
if (a.getElementType() == f16x2Ty) {
|
||||
return LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
|
||||
@ -85,7 +85,7 @@ static Type inferIntrinsicResultType(Type vectorResultType) {
|
||||
ctx,
|
||||
SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
|
||||
}
|
||||
if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
|
||||
if (a.getElementType() == VectorType::get(1, f32Ty)) {
|
||||
return LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
|
||||
}
|
||||
@ -106,11 +106,11 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
|
||||
Type i32Ty = rewriter.getI32Type();
|
||||
Type f32Ty = rewriter.getF32Type();
|
||||
Type f64Ty = rewriter.getF64Type();
|
||||
Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
|
||||
Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
|
||||
Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
|
||||
Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
|
||||
Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
|
||||
Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
|
||||
Type i32x2Ty = VectorType::get(2, i32Ty);
|
||||
Type f64x2Ty = VectorType::get(2, f64Ty);
|
||||
Type f32x2Ty = VectorType::get(2, f32Ty);
|
||||
Type f32x1Ty = VectorType::get(1, f32Ty);
|
||||
|
||||
auto makeConst = [&](int32_t index) -> Value {
|
||||
return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
|
||||
@ -181,9 +181,9 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
|
||||
Type f64Ty = b.getF64Type();
|
||||
Type f32Ty = b.getF32Type();
|
||||
Type i64Ty = b.getI64Type();
|
||||
Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
|
||||
Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
|
||||
Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
|
||||
Type i8x4Ty = VectorType::get(4, b.getI8Type());
|
||||
Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
|
||||
Type f32x1Ty = VectorType::get(1, f32Ty);
|
||||
auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
|
||||
|
||||
for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
|
||||
@ -268,8 +268,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
|
||||
if (!vectorResultType) {
|
||||
return failure();
|
||||
}
|
||||
Type innerVectorType = LLVM::getFixedVectorType(
|
||||
vectorResultType.getElementType(), vectorResultType.getDimSize(1));
|
||||
Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
|
||||
vectorResultType.getElementType());
|
||||
|
||||
int64_t num32BitRegs = vectorResultType.getDimSize(0);
|
||||
|
||||
@ -627,8 +627,7 @@ struct NVGPUMmaSparseSyncLowering
|
||||
|
||||
// Bitcast the sparse metadata from vector<2xf16> to an i32.
|
||||
Value sparseMetadata = adaptor.getSparseMetadata();
|
||||
if (sparseMetadata.getType() !=
|
||||
LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
|
||||
if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
|
||||
return op->emitOpError() << "Expected metadata type to be LLVM "
|
||||
"VectorType of 2 i16 elements";
|
||||
sparseMetadata =
|
||||
|
@ -851,18 +851,6 @@ Type mlir::LLVM::getVectorType(Type elementType,
|
||||
/*isScalable=*/false);
|
||||
}
|
||||
|
||||
Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
|
||||
assert(VectorType::isValidElementType(elementType) &&
|
||||
"incompatible element type");
|
||||
return VectorType::get(numElements, elementType);
|
||||
}
|
||||
|
||||
Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
|
||||
// LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
|
||||
// scalable/non-scalable.
|
||||
return VectorType::get(numElements, elementType, /*scalableDims=*/true);
|
||||
}
|
||||
|
||||
llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
|
||||
assert(isCompatibleType(type) &&
|
||||
"expected a type compatible with the LLVM dialect");
|
||||
|
@ -144,7 +144,7 @@ LogicalResult BulkStoreOp::verify() {
|
||||
std::optional<mlir::NVVM::MMATypes>
|
||||
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
|
||||
auto half2Type =
|
||||
LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
|
||||
VectorType::get(2, Float16Type::get(operandElType.getContext()));
|
||||
if (operandElType.isF64())
|
||||
return NVVM::MMATypes::f64;
|
||||
if (operandElType.isF16() || operandElType == half2Type)
|
||||
@ -243,7 +243,8 @@ void MmaOp::print(OpAsmPrinter &p) {
|
||||
p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
|
||||
|
||||
// Print the types of the operands and result.
|
||||
p << " : " << "(";
|
||||
p << " : "
|
||||
<< "(";
|
||||
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
|
||||
frags[1].regs[0].getType(),
|
||||
frags[2].regs[0].getType()},
|
||||
@ -404,7 +405,7 @@ LogicalResult MmaOp::verify() {
|
||||
MLIRContext *context = getContext();
|
||||
auto f16Ty = Float16Type::get(context);
|
||||
auto i32Ty = IntegerType::get(context, 32);
|
||||
auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
|
||||
auto f16x2Ty = VectorType::get(2, f16Ty);
|
||||
auto f32Ty = Float32Type::get(context);
|
||||
auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
|
||||
context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
|
||||
@ -506,7 +507,7 @@ LogicalResult MmaOp::verify() {
|
||||
expectedA.emplace_back(1, f64Ty);
|
||||
expectedB.emplace_back(1, f64Ty);
|
||||
expectedC.emplace_back(2, f64Ty);
|
||||
// expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
|
||||
// expectedC.emplace_back(1, VectorType::get(2, f64Ty));
|
||||
expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
|
||||
context, SmallVector<Type>(2, f64Ty)));
|
||||
allowedShapes.push_back({8, 8, 4});
|
||||
@ -992,7 +993,9 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
|
||||
ss << "},";
|
||||
// Need to map read/write registers correctly.
|
||||
regCnt = (regCnt * 2);
|
||||
ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
|
||||
ss << " $" << (regCnt) << ","
|
||||
<< " $" << (regCnt + 1) << ","
|
||||
<< " p";
|
||||
if (getTypeD() != WGMMATypes::s32) {
|
||||
ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
|
||||
}
|
||||
|
@ -103,16 +103,15 @@ nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
|
||||
|
||||
Type elType = type.vectorType.getElementType();
|
||||
if (elType.isF16()) {
|
||||
return FragmentElementInfo{
|
||||
LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
|
||||
inferNumRegistersPerMatrixFragment(type)};
|
||||
return FragmentElementInfo{VectorType::get(2, Float16Type::get(ctx)), 2, 32,
|
||||
inferNumRegistersPerMatrixFragment(type)};
|
||||
}
|
||||
|
||||
// f64 operand
|
||||
Type f64Ty = Float64Type::get(ctx);
|
||||
if (elType.isF64()) {
|
||||
return isAccum
|
||||
? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
|
||||
? FragmentElementInfo{VectorType::get(2, f64Ty), 2, 128,
|
||||
inferNumRegistersPerMatrixFragment(type)}
|
||||
: FragmentElementInfo{f64Ty, 1, 64,
|
||||
inferNumRegistersPerMatrixFragment(type)};
|
||||
@ -120,30 +119,27 @@ nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
|
||||
|
||||
// int8 operand
|
||||
if (elType.isInteger(8)) {
|
||||
return FragmentElementInfo{
|
||||
LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
|
||||
inferNumRegistersPerMatrixFragment(type)};
|
||||
return FragmentElementInfo{VectorType::get(4, IntegerType::get(ctx, 8)), 4,
|
||||
32, inferNumRegistersPerMatrixFragment(type)};
|
||||
}
|
||||
|
||||
// int4 operand
|
||||
if (elType.isInteger(4)) {
|
||||
return FragmentElementInfo{
|
||||
LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
|
||||
inferNumRegistersPerMatrixFragment(type)};
|
||||
return FragmentElementInfo{VectorType::get(8, IntegerType::get(ctx, 4)), 8,
|
||||
32, inferNumRegistersPerMatrixFragment(type)};
|
||||
}
|
||||
|
||||
// Integer 32bit acc operands
|
||||
if (elType.isInteger(32)) {
|
||||
return FragmentElementInfo{
|
||||
LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
|
||||
inferNumRegistersPerMatrixFragment(type)};
|
||||
return FragmentElementInfo{VectorType::get(2, IntegerType::get(ctx, 32)), 2,
|
||||
64, inferNumRegistersPerMatrixFragment(type)};
|
||||
}
|
||||
|
||||
// Floating point 32bit operands
|
||||
if (elType.isF32()) {
|
||||
Type f32Ty = Float32Type::get(ctx);
|
||||
return isAccum
|
||||
? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
|
||||
? FragmentElementInfo{VectorType::get(2, f32Ty), 2, 64,
|
||||
inferNumRegistersPerMatrixFragment(type)}
|
||||
: FragmentElementInfo{f32Ty, 1, 32,
|
||||
inferNumRegistersPerMatrixFragment(type)};
|
||||
|
@ -124,14 +124,15 @@ private:
|
||||
|
||||
/// Translates the given fixed-vector type.
|
||||
Type translate(llvm::FixedVectorType *type) {
|
||||
return LLVM::getFixedVectorType(translateType(type->getElementType()),
|
||||
type->getNumElements());
|
||||
return VectorType::get(type->getNumElements(),
|
||||
translateType(type->getElementType()));
|
||||
}
|
||||
|
||||
/// Translates the given scalable-vector type.
|
||||
Type translate(llvm::ScalableVectorType *type) {
|
||||
return LLVM::getScalableVectorType(translateType(type->getElementType()),
|
||||
type->getMinNumElements());
|
||||
return VectorType::get(type->getMinNumElements(),
|
||||
translateType(type->getElementType()),
|
||||
/*scalable=*/true);
|
||||
}
|
||||
|
||||
/// Translates the given target extension type.
|
||||
|
Loading…
x
Reference in New Issue
Block a user