Enable jax.device_put to a sharding with no local devices.

PiperOrigin-RevId: 737797815
This commit is contained in:
Emily Fertig 2025-03-17 16:48:57 -07:00 committed by jax authors
parent 051687dc4c
commit 8c35191725
2 changed files with 9 additions and 6 deletions

View File

@ -466,11 +466,14 @@ def _device_put_sharding_impl(x, aval, device, copy):
if not s.is_fully_addressable:
if ((isinstance(x, array.ArrayImpl) and not x._committed) or
type(x) in array_types):
multihost_utils.assert_equal(
x, fail_message=(
f"{type(x)} passed to device_put is not the same on each"
" process. Make sure you are passing the same value of"
f" {type(x)} on each process."))
# TODO(emilyaf): Remove this condition when jit works when a sharding
# has no local devices.
if not config.enable_empty_arrays.value:
multihost_utils.assert_equal(
x, fail_message=(
f"{type(x)} passed to device_put is not the same on each"
" process. Make sure you are passing the same value of"
f" {type(x)} on each process."))
return _DeferredShardArg(x, s, aval, True, copy)
# TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array.
raise ValueError(

View File

@ -237,7 +237,7 @@ def batched_device_put(aval: core.ShapedArray,
if (isinstance(x, array.ArrayImpl) and
dispatch.is_single_device_sharding(x.sharding) and
x.devices() == {d})]
if len(bufs) == len(xs):
if len(bufs) == len(xs) > 0:
return array.ArrayImpl(
aval, sharding, bufs, committed=committed, _skip_checks=True)
return xc.batched_device_put(aval, sharding, xs, list(devices), committed)