mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix process_allgather of global jax.Arrays with shardy
PiperOrigin-RevId: 738823617
This commit is contained in:
parent
dad1b41f7b
commit
1ec0585361
@ -99,8 +99,11 @@ def _identity_fn(x):
|
||||
|
||||
def _handle_array_process_allgather(inp, tiled):
|
||||
if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable:
|
||||
reps = sharding_impls.GSPMDSharding.get_replicated(
|
||||
inp.sharding._device_assignment)
|
||||
if isinstance(inp.sharding, sharding_impls.NamedSharding):
|
||||
reps = inp.sharding.with_spec(P())
|
||||
else:
|
||||
reps = sharding_impls.GSPMDSharding.get_replicated(
|
||||
inp.sharding._device_assignment, memory_kind=inp.sharding.memory_kind)
|
||||
out = jax.jit(_identity_fn, out_shardings=reps)(inp)
|
||||
else:
|
||||
# All inputs here will be fully addressable.
|
||||
|
Loading…
x
Reference in New Issue
Block a user