From 8c351917256ffbf48e34d983104b58d2fa2f3e92 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Mon, 17 Mar 2025 16:48:57 -0700 Subject: [PATCH] Enable `jax.device_put` to a sharding with no local devices. PiperOrigin-RevId: 737797815 --- jax/_src/dispatch.py | 13 ++++++++----- jax/_src/interpreters/pxla.py | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 050d6c394..2330f7628 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -466,11 +466,14 @@ def _device_put_sharding_impl(x, aval, device, copy): if not s.is_fully_addressable: if ((isinstance(x, array.ArrayImpl) and not x._committed) or type(x) in array_types): - multihost_utils.assert_equal( - x, fail_message=( - f"{type(x)} passed to device_put is not the same on each" - " process. Make sure you are passing the same value of" - f" {type(x)} on each process.")) + # TODO(emilyaf): Remove this condition when jit works when a sharding + # has no local devices. + if not config.enable_empty_arrays.value: + multihost_utils.assert_equal( + x, fail_message=( + f"{type(x)} passed to device_put is not the same on each" + " process. Make sure you are passing the same value of" + f" {type(x)} on each process.")) return _DeferredShardArg(x, s, aval, True, copy) # TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array. raise ValueError( diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index e28896802..c06eda521 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -237,7 +237,7 @@ def batched_device_put(aval: core.ShapedArray, if (isinstance(x, array.ArrayImpl) and dispatch.is_single_device_sharding(x.sharding) and x.devices() == {d})] - if len(bufs) == len(xs): + if len(bufs) == len(xs) > 0: return array.ArrayImpl( aval, sharding, bufs, committed=committed, _skip_checks=True) return xc.batched_device_put(aval, sharding, xs, list(devices), committed)