Enable jax.device_put to a sharding with no local devices.

PiperOrigin-RevId: 737797815
This commit is contained in:
Emily Fertig 2025-03-17 16:48:57 -07:00 committed by jax authors
parent 051687dc4c
commit 8c35191725
2 changed files with 9 additions and 6 deletions

View File

@ -466,6 +466,9 @@ def _device_put_sharding_impl(x, aval, device, copy):
if not s.is_fully_addressable: if not s.is_fully_addressable:
if ((isinstance(x, array.ArrayImpl) and not x._committed) or if ((isinstance(x, array.ArrayImpl) and not x._committed) or
type(x) in array_types): type(x) in array_types):
# TODO(emilyaf): Remove this condition when jit works when a sharding
# has no local devices.
if not config.enable_empty_arrays.value:
multihost_utils.assert_equal( multihost_utils.assert_equal(
x, fail_message=( x, fail_message=(
f"{type(x)} passed to device_put is not the same on each" f"{type(x)} passed to device_put is not the same on each"

View File

@ -237,7 +237,7 @@ def batched_device_put(aval: core.ShapedArray,
if (isinstance(x, array.ArrayImpl) and if (isinstance(x, array.ArrayImpl) and
dispatch.is_single_device_sharding(x.sharding) and dispatch.is_single_device_sharding(x.sharding) and
x.devices() == {d})] x.devices() == {d})]
if len(bufs) == len(xs): if len(bufs) == len(xs) > 0:
return array.ArrayImpl( return array.ArrayImpl(
aval, sharding, bufs, committed=committed, _skip_checks=True) aval, sharding, bufs, committed=committed, _skip_checks=True)
return xc.batched_device_put(aval, sharding, xs, list(devices), committed) return xc.batched_device_put(aval, sharding, xs, list(devices), committed)