From 9041b02dff85bcb2f45294881510e6ab4b290e4e Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 18 Dec 2024 17:50:43 -0800 Subject: [PATCH] Account for tokens in `allow_spmd_sharding_propagation_to_parameters` and `allow_spmd_sharding_propagation_to_output` compile options PiperOrigin-RevId: 707723232 --- jax/_src/interpreters/pxla.py | 6 +++--- tests/pjit_test.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index a5cd193b5..97df2c4bd 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2909,9 +2909,9 @@ class UnloadedMeshExecutable: da = _create_da_object(tuple(device_assignment)) del device_assignment - allow_prop_to_inputs = tuple(isinstance(i, (UnspecifiedValue, AUTO)) - for i in in_shardings) - allow_prop_to_outputs = tuple( + allow_prop_to_inputs = (False,) * len(ordered_effects) + tuple( + isinstance(i, (UnspecifiedValue, AUTO)) for i in in_shardings) + allow_prop_to_outputs = (False,) * len(ordered_effects) + tuple( isinstance(o, (UnspecifiedValue, AUTO)) or mlir.contains_unconstrained(o) for o in out_shardings) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 30ec7fb8f..8a7c0b2e6 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4305,6 +4305,24 @@ class ArrayPjitTest(jtu.JaxTestCase): self.assertLen(traced.in_avals[0], 1) self.assertLen(traced.in_avals[1], 0) # empty kwarg + def test_empty_io_callback_under_shard_map(self): + if config.use_shardy_partitioner.value: + self.skipTest("Shardy errors out on empty callbacks.") + mesh = jtu.create_mesh((4,), 'i') + + def empty_callback(x): + return + + def _f(x, y): + jax.experimental.io_callback( + empty_callback, (), x, ordered=True) + return x + y[..., jnp.newaxis] + + f = jax.jit(shard_map( + _f, mesh, in_specs=(P(None, 'i'), P(None)), + out_specs=P(None, 'i'))) + f(jnp.zeros((2, 16)), jnp.ones(2)) + def test_jit_trace_lower_and_compile(self): def f(x): return x * 2