From b7968474c20b61c9aa85b2ab3dd08446218a1258 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 20 Feb 2025 10:43:46 -0800 Subject: [PATCH] [Pallas][Mosaic] Support float8_e4m3b11fnuz PiperOrigin-RevId: 729169181 --- jax/_src/pallas/mosaic/lowering.py | 1 + jaxlib/mosaic/dialect/tpu/util.h | 8 +++----- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index f0fb23b1f..0cea22150 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1851,6 +1851,7 @@ def _dot_general_lowering_rule( ir.F32Type, ir.Float8E5M2Type, ir.Float8E4M3FNType, + ir.Float8E4M3B11FNUZType, ] ): val = ir.FloatAttr.get(val_type, 0.0) diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index c25872668..3277ff0d9 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -162,13 +162,11 @@ FailureOr getTypeBitwidth(Type ty) { if (auto bf16_ty = dyn_cast(ty)) { return 16; } - if (auto f8e5m2_ty = dyn_cast(ty)) { + if (isa(ty)) { return 8; } - if (auto f8e4m3fn_ty = dyn_cast(ty)) { - return 8; - } - return emitError(UnknownLoc::get(ty.getContext()), "Unsupported type: ") + return emitError(UnknownLoc::get(ty.getContext()), + "Unsupported type in mosaic dialect: ") << ty; }