[Mosaic TPU] Add optimized casts for bf16->s4 in TPUv6

PiperOrigin-RevId: 723455843
This commit is contained in:
Adam Paszke 2025-02-05 04:21:11 -08:00 committed by jax authors
parent 1fbc4a15dd
commit e7a4f89343

View File

@ -561,9 +561,11 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
if (dst_bitwidth > 32) {
return op.emitOpError("Target bitwidth too large");
}
// We have low-level optimized code for bf16->s8 and bf16->s4 casts on v6.
if (ctx.hardware_generation >= 6 && is_vector &&
src_vty.getElementType().isBF16() &&
dst_vty.getElementType().isSignlessInteger(8)) {
(dst_vty.getElementType().isSignlessInteger(8) ||
dst_vty.getElementType().isSignlessInteger(4))) {
auto new_op = builder.create<tpu::FPToSIOp>(
op.getType(), op.getIn(), tpu::RoundingMode::kTowardsZero);
op.replaceAllUsesWith(new_op.getResult());