Account for tokens in allow_spmd_sharding_propagation_to_parameters and allow_spmd_sharding_propagation_to_output compile options

PiperOrigin-RevId: 707723232
This commit is contained in:
Yash Katariya 2024-12-18 17:50:43 -08:00 committed by jax authors
parent 46b18d272c
commit 9041b02dff
2 changed files with 21 additions and 3 deletions

View File

@ -2909,9 +2909,9 @@ class UnloadedMeshExecutable:
da = _create_da_object(tuple(device_assignment))
del device_assignment
allow_prop_to_inputs = tuple(isinstance(i, (UnspecifiedValue, AUTO))
for i in in_shardings)
allow_prop_to_outputs = tuple(
allow_prop_to_inputs = (False,) * len(ordered_effects) + tuple(
isinstance(i, (UnspecifiedValue, AUTO)) for i in in_shardings)
allow_prop_to_outputs = (False,) * len(ordered_effects) + tuple(
isinstance(o, (UnspecifiedValue, AUTO)) or mlir.contains_unconstrained(o)
for o in out_shardings)

View File

@ -4305,6 +4305,24 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertLen(traced.in_avals[0], 1)
self.assertLen(traced.in_avals[1], 0) # empty kwarg
def test_empty_io_callback_under_shard_map(self):
if config.use_shardy_partitioner.value:
self.skipTest("Shardy errors out on empty callbacks.")
mesh = jtu.create_mesh((4,), 'i')
def empty_callback(x):
return
def _f(x, y):
jax.experimental.io_callback(
empty_callback, (), x, ordered=True)
return x + y[..., jnp.newaxis]
f = jax.jit(shard_map(
_f, mesh, in_specs=(P(None, 'i'), P(None)),
out_specs=P(None, 'i')))
f(jnp.zeros((2, 16)), jnp.ones(2))
def test_jit_trace_lower_and_compile(self):
def f(x):
return x * 2