mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Enable jax.device_put
to a sharding with no local devices.
PiperOrigin-RevId: 737797815
This commit is contained in:
parent
051687dc4c
commit
8c35191725
@ -466,11 +466,14 @@ def _device_put_sharding_impl(x, aval, device, copy):
|
||||
if not s.is_fully_addressable:
|
||||
if ((isinstance(x, array.ArrayImpl) and not x._committed) or
|
||||
type(x) in array_types):
|
||||
multihost_utils.assert_equal(
|
||||
x, fail_message=(
|
||||
f"{type(x)} passed to device_put is not the same on each"
|
||||
" process. Make sure you are passing the same value of"
|
||||
f" {type(x)} on each process."))
|
||||
# TODO(emilyaf): Remove this condition when jit works when a sharding
|
||||
# has no local devices.
|
||||
if not config.enable_empty_arrays.value:
|
||||
multihost_utils.assert_equal(
|
||||
x, fail_message=(
|
||||
f"{type(x)} passed to device_put is not the same on each"
|
||||
" process. Make sure you are passing the same value of"
|
||||
f" {type(x)} on each process."))
|
||||
return _DeferredShardArg(x, s, aval, True, copy)
|
||||
# TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array.
|
||||
raise ValueError(
|
||||
|
@ -237,7 +237,7 @@ def batched_device_put(aval: core.ShapedArray,
|
||||
if (isinstance(x, array.ArrayImpl) and
|
||||
dispatch.is_single_device_sharding(x.sharding) and
|
||||
x.devices() == {d})]
|
||||
if len(bufs) == len(xs):
|
||||
if len(bufs) == len(xs) > 0:
|
||||
return array.ArrayImpl(
|
||||
aval, sharding, bufs, committed=committed, _skip_checks=True)
|
||||
return xc.batched_device_put(aval, sharding, xs, list(devices), committed)
|
||||
|
Loading…
x
Reference in New Issue
Block a user