Fix process_allgather of global jax.Arrays with shardy

PiperOrigin-RevId: 738823617
This commit is contained in:
Yash Katariya 2025-03-20 09:06:27 -07:00 committed by jax authors
parent dad1b41f7b
commit 1ec0585361

View File

@ -99,8 +99,11 @@ def _identity_fn(x):
def _handle_array_process_allgather(inp, tiled):
if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable:
reps = sharding_impls.GSPMDSharding.get_replicated(
inp.sharding._device_assignment)
if isinstance(inp.sharding, sharding_impls.NamedSharding):
reps = inp.sharding.with_spec(P())
else:
reps = sharding_impls.GSPMDSharding.get_replicated(
inp.sharding._device_assignment, memory_kind=inp.sharding.memory_kind)
out = jax.jit(_identity_fn, out_shardings=reps)(inp)
else:
# All inputs here will be fully addressable.