If the callback returns a fully replicated global array, return it as is.

Also take the batched_device_put fast path for non-jax.Array values, since slicing can return arrays on multiple devices, which batched_device_put doesn't support.

PiperOrigin-RevId: 624763603
Yash Katariya 2024-04-14 14:35:13 -07:00 committed by jax authors
parent 4a6ee78f4f
commit 2c85ca6fec
3 changed files with 33 additions and 12 deletions
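
Below is a minimal usage sketch of the behavior described above (not part of the commit). It mirrors the new test added in this change and assumes a machine with at least 8 devices for the 4x2 mesh; everything else uses public jax APIs.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: at least 8 devices available, arranged as a 4x2 mesh.
mesh = Mesh(np.asarray(jax.devices()[:8]).reshape(4, 2), ('x', 'y'))
replicated = NamedSharding(mesh, P())  # fully replicated sharding

np_inp = np.arange(16).reshape(8, 2)
arr = jax.device_put(np_inp, replicated)  # fully replicated global jax.Array

# The callback returns a fully replicated global jax.Array, so with this
# change make_array_from_callback can return it as is instead of issuing
# another transfer.
out = jax.make_array_from_callback(np_inp.shape, replicated,
                                   lambda idx: arr[idx])
assert out.sharding == replicated

# A callback that returns non-jax.Array values (numpy arrays, scalars, ...)
# takes the batched_device_put fast path for fully replicated shardings.
out2 = jax.make_array_from_callback(np_inp.shape, replicated,
                                    lambda idx: np_inp[idx])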

@@ -695,13 +695,8 @@ def make_array_from_callback(
   >>> arr.addressable_data(0).shape
   (4, 2)
   """
-  has_device_assignment = False
   if sharding.is_fully_replicated:
-    if isinstance(sharding, XLACompatibleSharding):
-      devices = list(sharding._addressable_device_assignment)
-      has_device_assignment = True
-    else:
-      devices = list(sharding.addressable_devices)
+    devices = list(sharding._internal_device_list.addressable_device_list)  # type: ignore
     per_device_values = [data_callback((slice(None),) * len(shape))] * len(devices)
   else:
     device_to_index_map = sharding.addressable_devices_indices_map(shape)
@@ -716,13 +711,11 @@ def make_array_from_callback(
   first_value = xla.canonicalize_dtype(per_device_values[0])
   aval = core.ShapedArray(shape, first_value.dtype, weak_type=False)
   # TODO(yashkatariya): Look into taking this path for non-fully replicated
   # shardings too.
-  if (sharding.is_fully_replicated and has_device_assignment and
-      not dtypes.issubdtype(aval.dtype, dtypes.extended)):
+  # first value can be numpy array, python scalar, etc.
+  if (sharding.is_fully_replicated and not isinstance(first_value, ArrayImpl)
+      and not dtypes.issubdtype(aval.dtype, dtypes.extended)):
     # Do this check outside because `batched_device_put` won't do these checks
-    # like ArrayImpl. This is a fast path for fully replicated arrays with
-    # xla compatible sharding.
+    # like ArrayImpl.
     if shape != first_value.shape:
       raise ValueError(
           f"Expected shard shape {shape} doesn't match the single device "
@@ -731,6 +724,11 @@ def make_array_from_callback(
     return pxla.batched_device_put(
         aval, sharding, per_device_values, devices, committed=True)
+
+  if (sharding.is_fully_replicated and isinstance(first_value, ArrayImpl) and
+      first_value.is_fully_replicated and
+      first_value.sharding._device_assignment == devices):
+    return first_value
   arrays = api.device_put(per_device_values, devices)
   if dtypes.issubdtype(aval.dtype, dtypes.extended):
     return aval.dtype._rules.make_sharded_array(aval, sharding, arrays,
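
For readers following the control flow across the hunks above, here is a schematic restatement of the fully replicated dispatch that make_array_from_callback performs after this change. It is a simplified sketch, not the actual implementation: the helper name _replicated_dispatch is invented, the shape validation and extended-dtype handling shown in the diff are omitted, and the final device_put branch is condensed.

# Schematic sketch only; names other than the invented helper mirror the diff.
from jax._src import api                   # jax-internal module, as used in array.py
from jax._src.array import ArrayImpl
from jax._src.interpreters import pxla


def _replicated_dispatch(first_value, per_device_values, aval, sharding, devices):
  if sharding.is_fully_replicated and not isinstance(first_value, ArrayImpl):
    # Non-jax.Array values (numpy arrays, python scalars, ...) are sent to all
    # devices in a single batched transfer.
    return pxla.batched_device_put(
        aval, sharding, per_device_values, devices, committed=True)
  if (sharding.is_fully_replicated and isinstance(first_value, ArrayImpl) and
      first_value.is_fully_replicated and
      first_value.sharding._device_assignment == devices):
    # The callback already produced a matching fully replicated global array:
    # return it as is.
    return first_value
  # Everything else, including jax.Arrays that live on other or multiple
  # devices (which batched_device_put does not support), goes through
  # device_put.
  return api.device_put(per_device_values, devices)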

@@ -794,6 +794,24 @@ class JaxArrayTest(jtu.JaxTestCase):
     for shard in x.addressable_shards:
       self.assertEqual(shard.data.dtype, dtype)
 
+  def test_make_array_from_callback_global_array(self):
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    sharding = jax.sharding.NamedSharding(mesh, P())
+    np_inp = np.arange(16).reshape(8, 2)
+    arr = jax.device_put(np_inp, sharding)
+
+    out = jax.make_array_from_callback(np_inp.shape, sharding,
+                                       lambda idx: arr[idx])
+    self.assertArraysEqual(out, arr)
+    self.assertEqual(out.sharding, sharding)
+
+    sharding2 = NamedSharding(mesh, P('x', 'y'))
+    arr2 = jax.device_put(np_inp, sharding2)
+    out2 = jax.make_array_from_callback(np_inp.shape, sharding2,
+                                        lambda idx: arr2[idx])
+    self.assertArraysEqual(out2, arr2)
+    self.assertEqual(out2.sharding, sharding2)
+
 
 class ShardingTest(jtu.JaxTestCase):

@@ -4062,6 +4062,7 @@ class TempSharding(Sharding):
     if xla_extension_version >= 235:
       super().__init__()
     self._devices = devices
+    self._internal_device_list = xc.DeviceList(tuple(self._devices))
 
   @property
   def device_set(self):
@@ -4073,6 +4074,10 @@ class TempSharding(Sharding):
   def shard_shape(self, global_shape):
     return global_shape
 
+  @property
+  def memory_kind(self):
+    return None
+
   @property
   def is_fully_replicated(self):
     return True