mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
If a function returns no output, xla_executable.get_output_shardings() returns 1 sharding because for XLA the output is an empty tuple which has a tuple sharding.
PiperOrigin-RevId: 583555384
This commit is contained in:
parent
41f0b336e3
commit
493e2f8ae2
@ -2303,6 +2303,11 @@ def get_gspmd_shardings_from_executable(
|
||||
if num_ordered_effects > 0:
|
||||
out_op_shardings = out_op_shardings[num_ordered_effects:]
|
||||
|
||||
# This means that there are no outputs for JAX but for XLA there is an empty
|
||||
# tuple output which gets a replicated sharding.
|
||||
if num_out_avals == 0 and len(out_op_shardings) == 1:
|
||||
return None
|
||||
|
||||
# This condition happens when all the elements in the output tuple have the
|
||||
# same sharding, so XLA decides to run the `FusionTupleDeduplicator` to
|
||||
# put the sharding on ROOT instead of the tuple.
|
||||
|
@ -3666,6 +3666,16 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
' platform.*'):
|
||||
f(jnp.arange(8))
|
||||
|
||||
def test_no_output_multiple_devices(self):
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
|
||||
@pjit
|
||||
def f():
|
||||
return
|
||||
|
||||
with mesh:
|
||||
f() # doesn't crash
|
||||
|
||||
|
||||
class TempSharding(Sharding):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user