mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Pallas][Mosaic] Support float8_e4m3b11fnuz
PiperOrigin-RevId: 729169181
This commit is contained in:
parent
ddcb7deeaf
commit
b7968474c2
@ -1851,6 +1851,7 @@ def _dot_general_lowering_rule(
|
||||
ir.F32Type,
|
||||
ir.Float8E5M2Type,
|
||||
ir.Float8E4M3FNType,
|
||||
ir.Float8E4M3B11FNUZType,
|
||||
]
|
||||
):
|
||||
val = ir.FloatAttr.get(val_type, 0.0)
|
||||
|
@ -162,13 +162,11 @@ FailureOr<int8_t> getTypeBitwidth(Type ty) {
|
||||
if (auto bf16_ty = dyn_cast<BFloat16Type>(ty)) {
|
||||
return 16;
|
||||
}
|
||||
if (auto f8e5m2_ty = dyn_cast<Float8E5M2Type>(ty)) {
|
||||
if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E4M3B11FNUZType>(ty)) {
|
||||
return 8;
|
||||
}
|
||||
if (auto f8e4m3fn_ty = dyn_cast<Float8E4M3FNType>(ty)) {
|
||||
return 8;
|
||||
}
|
||||
return emitError(UnknownLoc::get(ty.getContext()), "Unsupported type: ")
|
||||
return emitError(UnknownLoc::get(ty.getContext()),
|
||||
"Unsupported type in mosaic dialect: ")
|
||||
<< ty;
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user