Merge pull request #18515 from gnecula:export_call_bool

PiperOrigin-RevId: 582311324
This commit is contained in:
jax authors 2023-11-14 07:13:43 -08:00
commit 2356d7afd0
3 changed files with 15 additions and 3 deletions

View File

@ -1941,8 +1941,9 @@ def convert_hlo(ctx: LoweringRuleContext, x, aval_in, aval_out):
compare_type = "SIGNED"
else:
compare_type = "UNSIGNED"
return compare_hlo(x, full_like_aval(ctx, 0, aval_in), "NE",
compare_type).result
x = compare_hlo(x, full_like_aval(ctx, 0, aval_in), "NE",
compare_type).result
# continue, to adjust the shape if needed
return hlo.ConvertOp(aval_to_ir_type(aval_out), x).result
def _wrap_with_spmd_op(name: str,

View File

@ -1166,7 +1166,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.AbstractValue) -> ir.Value:
new_ir_type = mlir.aval_to_ir_type(new_aval)
if x.type != new_ir_type:
return mlir.convert_hlo(ctx, x, x_aval, new_aval)
return hlo.ConvertOp(mlir.aval_to_ir_type(new_aval), x).result
else:
return x

View File

@ -671,6 +671,17 @@ class JaxExportTest(jtu.JaxTestCase):
export.poly_spec(x.shape, x.dtype, poly_spec))
export.call_exported(exp)(x)
def test_poly_booleans(self):
# For booleans we use a special case ConvertOp to cast to and from
# dynamic shapes arguments.
def f_jax(x): # x: bool[b]
return jnp.logical_not(x)
x = np.array([True, False, True, False], dtype=np.bool_)
exp = export.export(f_jax)(export.poly_spec(x.shape, x.dtype, "b"))
res = export.call_exported(exp)(x)
self.assertAllClose(f_jax(x), res)
def test_with_sharding(self):
nr_devices = 2
if len(jax.devices()) < nr_devices: