mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Pallas] Enable int8->fp32 conversions
PiperOrigin-RevId: 562969276
This commit is contained in:
parent
4f805c2d8f
commit
b2e5a1cf6a
@ -780,6 +780,16 @@ def _convert_element_type_lowering_rule(
|
||||
elif jnp.issubdtype(old_dtype, jnp.signedinteger) and jnp.issubdtype(
|
||||
new_dtype, jnp.floating
|
||||
):
|
||||
# TODO(sharadmv,apaszke): remove this when Mosaic handles SIToFP with
|
||||
# differing element bitwidths
|
||||
if old_dtype.itemsize < new_dtype.itemsize:
|
||||
ext_dtype = _INT_DTYPES[new_dtype.itemsize * 8]
|
||||
ext_type = aval_to_ir_type(out_aval.update(dtype=ext_dtype))
|
||||
x = arith.ExtSIOp(ext_type, x).result
|
||||
elif old_dtype.itemsize > new_dtype.itemsize:
|
||||
ext_dtype = _INT_DTYPES[new_dtype.itemsize * 8]
|
||||
ext_type = aval_to_ir_type(out_aval.update(dtype=ext_dtype))
|
||||
x = arith.TruncIOp(ext_type, x).result
|
||||
return arith.SIToFPOp(out_type, x).result
|
||||
elif jnp.issubdtype(old_dtype, jnp.signedinteger) and jnp.issubdtype(
|
||||
new_dtype, jnp.signedinteger
|
||||
|
Loading…
x
Reference in New Issue
Block a user