If a function returns no output, xla_executable.get_output_shardings() still returns one sharding, because for XLA the output is an empty tuple, which has a tuple sharding.

PiperOrigin-RevId: 583555384
This commit is contained in:
Yash Katariya 2023-11-17 20:48:22 -08:00 committed by jax authors
parent 41f0b336e3
commit 493e2f8ae2
2 changed files with 15 additions and 0 deletions

View File

@ -2303,6 +2303,11 @@ def get_gspmd_shardings_from_executable(
if num_ordered_effects > 0:
out_op_shardings = out_op_shardings[num_ordered_effects:]
# This means that there are no outputs for JAX but for XLA there is an empty
# tuple output which gets a replicated sharding.
if num_out_avals == 0 and len(out_op_shardings) == 1:
return None
# This condition happens when all the elements in the output tuple have the
# same sharding, so XLA decides to run the `FusionTupleDeduplicator` to
# put the sharding on ROOT instead of the tuple.

View File

@ -3666,6 +3666,16 @@ class ArrayPjitTest(jtu.JaxTestCase):
' platform.*'):
f(jnp.arange(8))
def test_no_output_multiple_devices(self):
  # A pjit-ed function that returns nothing must still run when a
  # two-device mesh (single axis 'x') is active.
  two_device_mesh = jtu.create_global_mesh((2,), ('x',))

  def f():
    return None

  no_output_fn = pjit(f)
  with two_device_mesh:
    no_output_fn()  # doesn't crash
class TempSharding(Sharding):