[mlir][IR] Remove builder API + caching for low-precision FP types (#123321)

Remove builder API (e.g., `b.getFloat4E2M1FNType()`) and caching in
`MLIRContext` for low-precision FP types. Types are still cached in the
type uniquer.

For details, see:
https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361/28

Note for LLVM integration: Use `b.getType<Float4E2M1FNType>()` or
`Float4E2M1FNType::get(b.getContext())` instead of
`b.getFloat4E2M1FNType()`.
This commit is contained in:
Matthias Springer 2025-01-18 10:38:51 +01:00 committed by GitHub
parent 67c3f2b430
commit f4943464d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 60 additions and 158 deletions

View File

@ -61,17 +61,6 @@ public:
Attribute metadata = Attribute());
// Types.
FloatType getFloat4E2M1FNType();
FloatType getFloat6E2M3FNType();
FloatType getFloat6E3M2FNType();
FloatType getFloat8E5M2Type();
FloatType getFloat8E4M3Type();
FloatType getFloat8E4M3FNType();
FloatType getFloat8E5M2FNUZType();
FloatType getFloat8E4M3FNUZType();
FloatType getFloat8E4M3B11FNUZType();
FloatType getFloat8E3M4Type();
FloatType getFloat8E8M0FNUType();
FloatType getBF16Type();
FloatType getF16Type();
FloatType getTF32Type();

View File

@ -85,6 +85,12 @@ class Builtin_FloatType<string name, string mnemonic,
DeclareTypeInterfaceMethods<
FloatTypeInterface,
["getFloatSemantics"] # declaredInterfaceMethods>]> {
}
// Float types that are cached in MLIRContext.
class Builtin_CachedFloatType<string name, string mnemonic,
list<string> declaredInterfaceMethods = []>
: Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
}];
@ -326,7 +332,7 @@ def Builtin_Float8E8M0FNU : Builtin_FloatType<"Float8E8M0FNU", "f8E8M0FNU"> {
//===----------------------------------------------------------------------===//
// BFloat16Type
def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16",
def Builtin_BFloat16 : Builtin_CachedFloatType<"BFloat16", "bf16",
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "bfloat16 floating-point type";
}
@ -334,7 +340,7 @@ def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16",
//===----------------------------------------------------------------------===//
// Float16Type
def Builtin_Float16 : Builtin_FloatType<"Float16", "f16",
def Builtin_Float16 : Builtin_CachedFloatType<"Float16", "f16",
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "16-bit floating-point type";
}
@ -342,14 +348,14 @@ def Builtin_Float16 : Builtin_FloatType<"Float16", "f16",
//===----------------------------------------------------------------------===//
// FloatTF32Type
def Builtin_FloatTF32 : Builtin_FloatType<"FloatTF32", "tf32"> {
def Builtin_FloatTF32 : Builtin_CachedFloatType<"FloatTF32", "tf32"> {
let summary = "TF32 floating-point type";
}
//===----------------------------------------------------------------------===//
// Float32Type
def Builtin_Float32 : Builtin_FloatType<"Float32", "f32",
def Builtin_Float32 : Builtin_CachedFloatType<"Float32", "f32",
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "32-bit floating-point type";
}
@ -357,21 +363,21 @@ def Builtin_Float32 : Builtin_FloatType<"Float32", "f32",
//===----------------------------------------------------------------------===//
// Float64Type
def Builtin_Float64 : Builtin_FloatType<"Float64", "f64"> {
def Builtin_Float64 : Builtin_CachedFloatType<"Float64", "f64"> {
let summary = "64-bit floating-point type";
}
//===----------------------------------------------------------------------===//
// Float80Type
def Builtin_Float80 : Builtin_FloatType<"Float80", "f80"> {
def Builtin_Float80 : Builtin_CachedFloatType<"Float80", "f80"> {
let summary = "80-bit floating-point type";
}
//===----------------------------------------------------------------------===//
// Float128Type
def Builtin_Float128 : Builtin_FloatType<"Float128", "f128"> {
def Builtin_Float128 : Builtin_CachedFloatType<"Float128", "f128"> {
let summary = "128-bit floating-point type";
}

View File

@ -330,31 +330,31 @@ def F80 : F<80>;
def F128 : F<128>;
def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
BuildableType<"$_builder.getBF16Type()">;
BuildableType<"$_builder.getType<BFloat16Type>()">;
def TF32 : Type<CPred<"$_self.isTF32()">, "tf32 type">,
BuildableType<"$_builder.getTF32Type()">;
BuildableType<"$_builder.getType<FloatTF32Type>()">;
def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
BuildableType<"$_builder.getFloat8E4M3FNType()">;
BuildableType<"$_builder.getType<Float8E4M3FNType>()">;
def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
BuildableType<"$_builder.getFloat8E5M2Type()">;
BuildableType<"$_builder.getType<Float8E5M2Type>()">;
def F8E4M3 : Type<CPred<"$_self.isFloat8E4M3()">, "f8E4M3 type">,
BuildableType<"$_builder.getFloat8E4M3Type()">;
BuildableType<"$_builder.getType<Float8E4M3Type>()">;
def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
BuildableType<"$_builder.getFloat8E4M3FNUZType()">;
BuildableType<"$_builder.getType<Float8E4M3FNUZType>()">;
def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ type">,
BuildableType<"$_builder.getFloat8E4M3B11FNUZType()">;
BuildableType<"$_builder.getType<Float8E4M3B11FNUZType>()">;
def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
BuildableType<"$_builder.getType<Float8E5M2FNUZType>()">;
def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
BuildableType<"$_builder.getFloat8E3M4Type()">;
BuildableType<"$_builder.getType<Float8E3M4Type>()">;
def F4E2M1FN : Type<CPred<"$_self.isFloat4E2M1FN()">, "f4E2M1FN type">,
BuildableType<"$_builder.getFloat4E2M1FNType()">;
BuildableType<"$_builder.getType<Float4E2M1FNType>()">;
def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
BuildableType<"$_builder.getFloat6E2M3FNType()">;
BuildableType<"$_builder.getType<Float6E2M3FNType>()">;
def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
BuildableType<"$_builder.getFloat6E3M2FNType()">;
BuildableType<"$_builder.getType<Float6E3M2FNType>()">;
def F8E8M0FNU : Type<CPred<"$_self.isFloat8E8M0FNU()">, "f8E8M0FNU type">,
BuildableType<"$_builder.getFloat8E8M0FNUType()">;
BuildableType<"$_builder.getType<Float8E8M0FNUType>()">;
def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
"complex-type", "::mlir::ComplexType">;

View File

@ -309,58 +309,58 @@ Type Parser::parseNonFunctionType() {
// float-type
case Token::kw_f4E2M1FN:
consumeToken(Token::kw_f4E2M1FN);
return builder.getFloat4E2M1FNType();
return builder.getType<Float4E2M1FNType>();
case Token::kw_f6E2M3FN:
consumeToken(Token::kw_f6E2M3FN);
return builder.getFloat6E2M3FNType();
return builder.getType<Float6E2M3FNType>();
case Token::kw_f6E3M2FN:
consumeToken(Token::kw_f6E3M2FN);
return builder.getFloat6E3M2FNType();
return builder.getType<Float6E3M2FNType>();
case Token::kw_f8E5M2:
consumeToken(Token::kw_f8E5M2);
return builder.getFloat8E5M2Type();
return builder.getType<Float8E5M2Type>();
case Token::kw_f8E4M3:
consumeToken(Token::kw_f8E4M3);
return builder.getFloat8E4M3Type();
return builder.getType<Float8E4M3Type>();
case Token::kw_f8E4M3FN:
consumeToken(Token::kw_f8E4M3FN);
return builder.getFloat8E4M3FNType();
return builder.getType<Float8E4M3FNType>();
case Token::kw_f8E5M2FNUZ:
consumeToken(Token::kw_f8E5M2FNUZ);
return builder.getFloat8E5M2FNUZType();
return builder.getType<Float8E5M2FNUZType>();
case Token::kw_f8E4M3FNUZ:
consumeToken(Token::kw_f8E4M3FNUZ);
return builder.getFloat8E4M3FNUZType();
return builder.getType<Float8E4M3FNUZType>();
case Token::kw_f8E4M3B11FNUZ:
consumeToken(Token::kw_f8E4M3B11FNUZ);
return builder.getFloat8E4M3B11FNUZType();
return builder.getType<Float8E4M3B11FNUZType>();
case Token::kw_f8E3M4:
consumeToken(Token::kw_f8E3M4);
return builder.getFloat8E3M4Type();
return builder.getType<Float8E3M4Type>();
case Token::kw_f8E8M0FNU:
consumeToken(Token::kw_f8E8M0FNU);
return builder.getFloat8E8M0FNUType();
return builder.getType<Float8E8M0FNUType>();
case Token::kw_bf16:
consumeToken(Token::kw_bf16);
return builder.getBF16Type();
return builder.getType<BFloat16Type>();
case Token::kw_f16:
consumeToken(Token::kw_f16);
return builder.getF16Type();
return builder.getType<Float16Type>();
case Token::kw_tf32:
consumeToken(Token::kw_tf32);
return builder.getTF32Type();
return builder.getType<FloatTF32Type>();
case Token::kw_f32:
consumeToken(Token::kw_f32);
return builder.getF32Type();
return builder.getType<Float32Type>();
case Token::kw_f64:
consumeToken(Token::kw_f64);
return builder.getF64Type();
return builder.getType<Float64Type>();
case Token::kw_f80:
consumeToken(Token::kw_f80);
return builder.getF80Type();
return builder.getType<Float80Type>();
case Token::kw_f128:
consumeToken(Token::kw_f128);
return builder.getF128Type();
return builder.getType<Float128Type>();
// index-type
case Token::kw_index:

View File

@ -361,22 +361,22 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
Builder b(ctx);
return llvm::StringSwitch<std::optional<FloatType>>(name)
.Case("f4E2M1FN", b.getFloat4E2M1FNType())
.Case("f6E2M3FN", b.getFloat6E2M3FNType())
.Case("f6E3M2FN", b.getFloat6E3M2FNType())
.Case("f8E5M2", b.getFloat8E5M2Type())
.Case("f8E4M3", b.getFloat8E4M3Type())
.Case("f8E4M3FN", b.getFloat8E4M3FNType())
.Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
.Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
.Case("f8E3M4", b.getFloat8E3M4Type())
.Case("f8E8M0FNU", b.getFloat8E8M0FNUType())
.Case("bf16", b.getBF16Type())
.Case("f16", b.getF16Type())
.Case("f32", b.getF32Type())
.Case("f64", b.getF64Type())
.Case("f80", b.getF80Type())
.Case("f128", b.getF128Type())
.Case("f4E2M1FN", b.getType<Float4E2M1FNType>())
.Case("f6E2M3FN", b.getType<Float6E2M3FNType>())
.Case("f6E3M2FN", b.getType<Float6E3M2FNType>())
.Case("f8E5M2", b.getType<Float8E5M2Type>())
.Case("f8E4M3", b.getType<Float8E4M3Type>())
.Case("f8E4M3FN", b.getType<Float8E4M3FNType>())
.Case("f8E5M2FNUZ", b.getType<Float8E5M2FNUZType>())
.Case("f8E4M3FNUZ", b.getType<Float8E4M3FNUZType>())
.Case("f8E3M4", b.getType<Float8E3M4Type>())
.Case("f8E8M0FNU", b.getType<Float8E8M0FNUType>())
.Case("bf16", b.getType<BFloat16Type>())
.Case("f16", b.getType<Float16Type>())
.Case("f32", b.getType<Float32Type>())
.Case("f64", b.getType<Float64Type>())
.Case("f80", b.getType<Float80Type>())
.Case("f128", b.getType<Float128Type>())
.Default(std::nullopt);
}

View File

@ -34,44 +34,6 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
// Types.
//===----------------------------------------------------------------------===//
FloatType Builder::getFloat4E2M1FNType() {
return Float4E2M1FNType::get(context);
}
FloatType Builder::getFloat6E2M3FNType() {
return Float6E2M3FNType::get(context);
}
FloatType Builder::getFloat6E3M2FNType() {
return Float6E3M2FNType::get(context);
}
FloatType Builder::getFloat8E5M2Type() { return Float8E5M2Type::get(context); }
FloatType Builder::getFloat8E4M3Type() { return Float8E4M3Type::get(context); }
FloatType Builder::getFloat8E4M3FNType() {
return Float8E4M3FNType::get(context);
}
FloatType Builder::getFloat8E5M2FNUZType() {
return Float8E5M2FNUZType::get(context);
}
FloatType Builder::getFloat8E4M3FNUZType() {
return Float8E4M3FNUZType::get(context);
}
FloatType Builder::getFloat8E4M3B11FNUZType() {
return Float8E4M3B11FNUZType::get(context);
}
FloatType Builder::getFloat8E3M4Type() { return Float8E3M4Type::get(context); }
FloatType Builder::getFloat8E8M0FNUType() {
return Float8E8M0FNUType::get(context);
}
FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }
FloatType Builder::getF16Type() { return Float16Type::get(context); }

View File

@ -221,17 +221,6 @@ public:
llvm::DenseMap<StringRef, AbstractType *> nameToType;
/// Cached Type Instances.
Float4E2M1FNType f4E2M1FNTy;
Float6E2M3FNType f6E2M3FNTy;
Float6E3M2FNType f6E3M2FNTy;
Float8E5M2Type f8E5M2Ty;
Float8E4M3Type f8E4M3Ty;
Float8E4M3FNType f8E4M3FNTy;
Float8E5M2FNUZType f8E5M2FNUZTy;
Float8E4M3FNUZType f8E4M3FNUZTy;
Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
Float8E3M4Type f8E3M4Ty;
Float8E8M0FNUType f8E8M0FNUTy;
BFloat16Type bf16Ty;
Float16Type f16Ty;
FloatTF32Type tf32Ty;
@ -317,17 +306,6 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
//// Types.
/// Floating-point Types.
impl->f4E2M1FNTy = TypeUniquer::get<Float4E2M1FNType>(this);
impl->f6E2M3FNTy = TypeUniquer::get<Float6E2M3FNType>(this);
impl->f6E3M2FNTy = TypeUniquer::get<Float6E3M2FNType>(this);
impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
impl->f8E4M3Ty = TypeUniquer::get<Float8E4M3Type>(this);
impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
impl->f8E5M2FNUZTy = TypeUniquer::get<Float8E5M2FNUZType>(this);
impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
impl->f8E4M3B11FNUZTy = TypeUniquer::get<Float8E4M3B11FNUZType>(this);
impl->f8E3M4Ty = TypeUniquer::get<Float8E3M4Type>(this);
impl->f8E8M0FNUTy = TypeUniquer::get<Float8E8M0FNUType>(this);
impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
impl->f16Ty = TypeUniquer::get<Float16Type>(this);
impl->tf32Ty = TypeUniquer::get<FloatTF32Type>(this);
@ -1044,39 +1022,6 @@ AbstractType::lookup(StringRef name, MLIRContext *context) {
/// This should not be used directly.
StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
Float4E2M1FNType Float4E2M1FNType::get(MLIRContext *context) {
return context->getImpl().f4E2M1FNTy;
}
Float6E2M3FNType Float6E2M3FNType::get(MLIRContext *context) {
return context->getImpl().f6E2M3FNTy;
}
Float6E3M2FNType Float6E3M2FNType::get(MLIRContext *context) {
return context->getImpl().f6E3M2FNTy;
}
Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
return context->getImpl().f8E5M2Ty;
}
Float8E4M3Type Float8E4M3Type::get(MLIRContext *context) {
return context->getImpl().f8E4M3Ty;
}
Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) {
return context->getImpl().f8E4M3FNTy;
}
Float8E5M2FNUZType Float8E5M2FNUZType::get(MLIRContext *context) {
return context->getImpl().f8E5M2FNUZTy;
}
Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) {
return context->getImpl().f8E4M3FNUZTy;
}
Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) {
return context->getImpl().f8E4M3B11FNUZTy;
}
Float8E3M4Type Float8E3M4Type::get(MLIRContext *context) {
return context->getImpl().f8E3M4Ty;
}
Float8E8M0FNUType Float8E8M0FNUType::get(MLIRContext *context) {
return context->getImpl().f8E8M0FNUTy;
}
BFloat16Type BFloat16Type::get(MLIRContext *context) {
return context->getImpl().bf16Ty;
}