From b2e5a1cf6a5a9406da82d64ee1c54975dd98e0de Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 5 Sep 2023 20:25:09 -0700 Subject: [PATCH] [Pallas] Enable int8->fp32 conversions PiperOrigin-RevId: 562969276 --- jax/_src/pallas/mosaic/lowering.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 79166f21c..79c003e3b 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -780,6 +780,16 @@ def _convert_element_type_lowering_rule( elif jnp.issubdtype(old_dtype, jnp.signedinteger) and jnp.issubdtype( new_dtype, jnp.floating ): + # TODO(sharadmv,apaszke): remove this when Mosaic handles SIToFP with + # differing element bitwidths + if old_dtype.itemsize < new_dtype.itemsize: + ext_dtype = _INT_DTYPES[new_dtype.itemsize * 8] + ext_type = aval_to_ir_type(out_aval.update(dtype=ext_dtype)) + x = arith.ExtSIOp(ext_type, x).result + elif old_dtype.itemsize > new_dtype.itemsize: + ext_dtype = _INT_DTYPES[new_dtype.itemsize * 8] + ext_type = aval_to_ir_type(out_aval.update(dtype=ext_dtype)) + x = arith.TruncIOp(ext_type, x).result return arith.SIToFPOp(out_type, x).result elif jnp.issubdtype(old_dtype, jnp.signedinteger) and jnp.issubdtype( new_dtype, jnp.signedinteger