[Pallas] Enable int8->fp32 conversions

PiperOrigin-RevId: 562969276
This commit is contained in:
Sharad Vikram 2023-09-05 20:25:09 -07:00 committed by jax authors
parent 4f805c2d8f
commit b2e5a1cf6a

View File

@ -780,6 +780,16 @@ def _convert_element_type_lowering_rule(
elif jnp.issubdtype(old_dtype, jnp.signedinteger) and jnp.issubdtype(
new_dtype, jnp.floating
):
# TODO(sharadmv,apaszke): remove this when Mosaic handles SIToFP with
# differing element bitwidths
if old_dtype.itemsize < new_dtype.itemsize:
ext_dtype = _INT_DTYPES[new_dtype.itemsize * 8]
ext_type = aval_to_ir_type(out_aval.update(dtype=ext_dtype))
x = arith.ExtSIOp(ext_type, x).result
elif old_dtype.itemsize > new_dtype.itemsize:
ext_dtype = _INT_DTYPES[new_dtype.itemsize * 8]
ext_type = aval_to_ir_type(out_aval.update(dtype=ext_dtype))
x = arith.TruncIOp(ext_type, x).result
return arith.SIToFPOp(out_type, x).result
elif jnp.issubdtype(old_dtype, jnp.signedinteger) and jnp.issubdtype(
new_dtype, jnp.signedinteger