mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Enable jax.device_put
to a sharding with no local devices.
PiperOrigin-RevId: 737797815
This commit is contained in:
parent
051687dc4c
commit
8c35191725
@ -466,6 +466,9 @@ def _device_put_sharding_impl(x, aval, device, copy):
|
|||||||
if not s.is_fully_addressable:
|
if not s.is_fully_addressable:
|
||||||
if ((isinstance(x, array.ArrayImpl) and not x._committed) or
|
if ((isinstance(x, array.ArrayImpl) and not x._committed) or
|
||||||
type(x) in array_types):
|
type(x) in array_types):
|
||||||
|
# TODO(emilyaf): Remove this condition when jit works when a sharding
|
||||||
|
# has no local devices.
|
||||||
|
if not config.enable_empty_arrays.value:
|
||||||
multihost_utils.assert_equal(
|
multihost_utils.assert_equal(
|
||||||
x, fail_message=(
|
x, fail_message=(
|
||||||
f"{type(x)} passed to device_put is not the same on each"
|
f"{type(x)} passed to device_put is not the same on each"
|
||||||
|
@ -237,7 +237,7 @@ def batched_device_put(aval: core.ShapedArray,
|
|||||||
if (isinstance(x, array.ArrayImpl) and
|
if (isinstance(x, array.ArrayImpl) and
|
||||||
dispatch.is_single_device_sharding(x.sharding) and
|
dispatch.is_single_device_sharding(x.sharding) and
|
||||||
x.devices() == {d})]
|
x.devices() == {d})]
|
||||||
if len(bufs) == len(xs):
|
if len(bufs) == len(xs) > 0:
|
||||||
return array.ArrayImpl(
|
return array.ArrayImpl(
|
||||||
aval, sharding, bufs, committed=committed, _skip_checks=True)
|
aval, sharding, bufs, committed=committed, _skip_checks=True)
|
||||||
return xc.batched_device_put(aval, sharding, xs, list(devices), committed)
|
return xc.batched_device_put(aval, sharding, xs, list(devices), committed)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user