mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Easy][Mosaic] Tiny refactor for clarity in getTypeBitwidth
PiperOrigin-RevId: 730906329
This commit is contained in:
parent
3d87a01bea
commit
083ffd3717
@ -20,6 +20,7 @@
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "mlir/include/mlir/IR/Attributes.h"
|
||||
#include "mlir/include/mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/include/mlir/IR/Diagnostics.h"
|
||||
#include "mlir/include/mlir/IR/Value.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/layout.h"
|
||||
@ -156,14 +157,9 @@ FailureOr<int8_t> getTypeBitwidth(Type ty) {
|
||||
return width;
|
||||
}
|
||||
}
|
||||
if (auto f32_ty = dyn_cast<Float32Type>(ty)) {
|
||||
return 32;
|
||||
}
|
||||
if (auto bf16_ty = dyn_cast<BFloat16Type>(ty)) {
|
||||
return 16;
|
||||
}
|
||||
if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E4M3B11FNUZType>(ty)) {
|
||||
return 8;
|
||||
if (isa<IntegerType, Float32Type, BFloat16Type, Float8E5M2Type,
|
||||
Float8E4M3FNType, Float8E4M3B11FNUZType>(ty)) {
|
||||
return ty.getIntOrFloatBitWidth();
|
||||
}
|
||||
return emitError(UnknownLoc::get(ty.getContext()),
|
||||
"Unsupported type in mosaic dialect: ")
|
||||
|
Loading…
x
Reference in New Issue
Block a user