diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index 3277ff0d9..2e19cb820 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -20,6 +20,7 @@ #include "absl/status/status.h" #include "absl/types/span.h" #include "mlir/include/mlir/IR/Attributes.h" +#include "mlir/include/mlir/IR/BuiltinTypes.h" #include "mlir/include/mlir/IR/Diagnostics.h" #include "mlir/include/mlir/IR/Value.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" @@ -156,14 +157,9 @@ FailureOr getTypeBitwidth(Type ty) { return width; } } - if (auto f32_ty = dyn_cast(ty)) { - return 32; - } - if (auto bf16_ty = dyn_cast(ty)) { - return 16; - } - if (isa(ty)) { - return 8; + if (isa(ty)) { + return ty.getIntOrFloatBitWidth(); } return emitError(UnknownLoc::get(ty.getContext()), "Unsupported type in mosaic dialect: ")