From 493e2f8ae2c028263d735eefc995f363af8af24a Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 17 Nov 2023 20:48:22 -0800 Subject: [PATCH] If a function returns no output, xla_executable.get_output_shardings() returns 1 sharding because for XLA the output is an empty tuple which has a tuple sharding. PiperOrigin-RevId: 583555384 --- jax/_src/interpreters/pxla.py | 5 +++++ tests/pjit_test.py | 10 ++++++++++ 2 files changed, 15 insertions(+) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 18195a813..01bbcb4b0 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2303,6 +2303,11 @@ def get_gspmd_shardings_from_executable( if num_ordered_effects > 0: out_op_shardings = out_op_shardings[num_ordered_effects:] + # This means that there are no outputs for JAX but for XLA there is an empty + # tuple output which gets a replicated sharding. + if num_out_avals == 0 and len(out_op_shardings) == 1: + return None + # This condition happens when all the elements in the output tuple have the # same sharding, so XLA decides to run the `FusionTupleDeduplicator` to # put the sharding on ROOT instead of the tuple. diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 37e175dd4..ece381cdc 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3666,6 +3666,16 @@ class ArrayPjitTest(jtu.JaxTestCase): ' platform.*'): f(jnp.arange(8)) + def test_no_output_multiple_devices(self): + mesh = jtu.create_global_mesh((2,), ('x',)) + + @pjit + def f(): + return + + with mesh: + f() # doesn't crash + class TempSharding(Sharding):