[Easy][Mosaic] Tiny refactor for clarity in getTypeBitwidth

PiperOrigin-RevId: 730906329
This commit is contained in:
jax authors 2025-02-25 08:57:42 -08:00
parent 3d87a01bea
commit 083ffd3717

View File

@ -20,6 +20,7 @@
#include "absl/status/status.h"
#include "absl/types/span.h"
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/Diagnostics.h"
#include "mlir/include/mlir/IR/Value.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
@ -156,14 +157,9 @@ FailureOr<int8_t> getTypeBitwidth(Type ty) {
return width;
}
}
if (auto f32_ty = dyn_cast<Float32Type>(ty)) {
return 32;
}
if (auto bf16_ty = dyn_cast<BFloat16Type>(ty)) {
return 16;
}
if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E4M3B11FNUZType>(ty)) {
return 8;
if (isa<IntegerType, Float32Type, BFloat16Type, Float8E5M2Type,
Float8E4M3FNType, Float8E4M3B11FNUZType>(ty)) {
return ty.getIntOrFloatBitWidth();
}
return emitError(UnknownLoc::get(ty.getContext()),
"Unsupported type in mosaic dialect: ")