mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #18515 from gnecula:export_call_bool
PiperOrigin-RevId: 582311324
This commit is contained in:
commit
2356d7afd0
@ -1941,8 +1941,9 @@ def convert_hlo(ctx: LoweringRuleContext, x, aval_in, aval_out):
|
||||
compare_type = "SIGNED"
|
||||
else:
|
||||
compare_type = "UNSIGNED"
|
||||
return compare_hlo(x, full_like_aval(ctx, 0, aval_in), "NE",
|
||||
compare_type).result
|
||||
x = compare_hlo(x, full_like_aval(ctx, 0, aval_in), "NE",
|
||||
compare_type).result
|
||||
# continue, to adjust the shape if needed
|
||||
return hlo.ConvertOp(aval_to_ir_type(aval_out), x).result
|
||||
|
||||
def _wrap_with_spmd_op(name: str,
|
||||
|
@ -1166,7 +1166,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
|
||||
def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.AbstractValue) -> ir.Value:
|
||||
new_ir_type = mlir.aval_to_ir_type(new_aval)
|
||||
if x.type != new_ir_type:
|
||||
return mlir.convert_hlo(ctx, x, x_aval, new_aval)
|
||||
return hlo.ConvertOp(mlir.aval_to_ir_type(new_aval), x).result
|
||||
else:
|
||||
return x
|
||||
|
||||
|
@ -671,6 +671,17 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
export.poly_spec(x.shape, x.dtype, poly_spec))
|
||||
export.call_exported(exp)(x)
|
||||
|
||||
def test_poly_booleans(self):
|
||||
# For booleans we use a special case ConvertOp to cast to and from
|
||||
# dynamic shapes arguments.
|
||||
def f_jax(x): # x: bool[b]
|
||||
return jnp.logical_not(x)
|
||||
|
||||
x = np.array([True, False, True, False], dtype=np.bool_)
|
||||
exp = export.export(f_jax)(export.poly_spec(x.shape, x.dtype, "b"))
|
||||
res = export.call_exported(exp)(x)
|
||||
self.assertAllClose(f_jax(x), res)
|
||||
|
||||
def test_with_sharding(self):
|
||||
nr_devices = 2
|
||||
if len(jax.devices()) < nr_devices:
|
||||
|
Loading…
x
Reference in New Issue
Block a user