mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
[Mosaic TPU] Add optimized casts for bf16->s4 in TPUv6
PiperOrigin-RevId: 723455843
This commit is contained in:
parent
1fbc4a15dd
commit
e7a4f89343
@ -561,9 +561,11 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
|
||||
if (dst_bitwidth > 32) {
|
||||
return op.emitOpError("Target bitwidth too large");
|
||||
}
|
||||
// We have low-level optimized code for bf16->s8 and bf16->s4 casts on v6.
|
||||
if (ctx.hardware_generation >= 6 && is_vector &&
|
||||
src_vty.getElementType().isBF16() &&
|
||||
dst_vty.getElementType().isSignlessInteger(8)) {
|
||||
(dst_vty.getElementType().isSignlessInteger(8) ||
|
||||
dst_vty.getElementType().isSignlessInteger(4))) {
|
||||
auto new_op = builder.create<tpu::FPToSIOp>(
|
||||
op.getType(), op.getIn(), tpu::RoundingMode::kTowardsZero);
|
||||
op.replaceAllUsesWith(new_op.getResult());
|
||||
|
Loading…
x
Reference in New Issue
Block a user