mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
If the callback returns a fully replicated global array, return it as is.
Also take the batched_device_put fast path only for values that are not jax.Arrays, since slicing a jax.Array can return arrays on multiple devices, which batched_device_put doesn't support. PiperOrigin-RevId: 624763603
This commit is contained in:
parent
4a6ee78f4f
commit
2c85ca6fec
@ -695,13 +695,8 @@ def make_array_from_callback(
|
||||
>>> arr.addressable_data(0).shape
|
||||
(4, 2)
|
||||
"""
|
||||
has_device_assignment = False
|
||||
if sharding.is_fully_replicated:
|
||||
if isinstance(sharding, XLACompatibleSharding):
|
||||
devices = list(sharding._addressable_device_assignment)
|
||||
has_device_assignment = True
|
||||
else:
|
||||
devices = list(sharding.addressable_devices)
|
||||
devices = list(sharding._internal_device_list.addressable_device_list) # type: ignore
|
||||
per_device_values = [data_callback((slice(None),) * len(shape))] * len(devices)
|
||||
else:
|
||||
device_to_index_map = sharding.addressable_devices_indices_map(shape)
|
||||
@ -716,13 +711,11 @@ def make_array_from_callback(
|
||||
first_value = xla.canonicalize_dtype(per_device_values[0])
|
||||
aval = core.ShapedArray(shape, first_value.dtype, weak_type=False)
|
||||
|
||||
# TODO(yashkatariya): Look into taking this path for non-fully replicated
|
||||
# shardings too.
|
||||
if (sharding.is_fully_replicated and has_device_assignment and
|
||||
not dtypes.issubdtype(aval.dtype, dtypes.extended)):
|
||||
# first value can be numpy array, python scalar, etc.
|
||||
if (sharding.is_fully_replicated and not isinstance(first_value, ArrayImpl)
|
||||
and not dtypes.issubdtype(aval.dtype, dtypes.extended)):
|
||||
# Do this check outside because `batched_device_put` won't do these checks
|
||||
# like ArrayImpl. This is a fast path for fully replicated arrays with
|
||||
# xla compatible sharding.
|
||||
# like ArrayImpl.
|
||||
if shape != first_value.shape:
|
||||
raise ValueError(
|
||||
f"Expected shard shape {shape} doesn't match the single device "
|
||||
@ -731,6 +724,11 @@ def make_array_from_callback(
|
||||
return pxla.batched_device_put(
|
||||
aval, sharding, per_device_values, devices, committed=True)
|
||||
|
||||
if (sharding.is_fully_replicated and isinstance(first_value, ArrayImpl) and
|
||||
first_value.is_fully_replicated and
|
||||
first_value.sharding._device_assignment == devices):
|
||||
return first_value
|
||||
|
||||
arrays = api.device_put(per_device_values, devices)
|
||||
if dtypes.issubdtype(aval.dtype, dtypes.extended):
|
||||
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays,
|
||||
|
@ -794,6 +794,24 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
for shard in x.addressable_shards:
|
||||
self.assertEqual(shard.data.dtype, dtype)
|
||||
|
||||
def test_make_array_from_callback_global_array(self):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
sharding = jax.sharding.NamedSharding(mesh, P())
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
arr = jax.device_put(np_inp, sharding)
|
||||
|
||||
out = jax.make_array_from_callback(np_inp.shape, sharding,
|
||||
lambda idx: arr[idx])
|
||||
self.assertArraysEqual(out, arr)
|
||||
self.assertEqual(out.sharding, sharding)
|
||||
|
||||
sharding2 = NamedSharding(mesh, P('x', 'y'))
|
||||
arr2 = jax.device_put(np_inp, sharding2)
|
||||
out2 = jax.make_array_from_callback(np_inp.shape, sharding2,
|
||||
lambda idx: arr2[idx])
|
||||
self.assertArraysEqual(out2, arr2)
|
||||
self.assertEqual(out2.sharding, sharding2)
|
||||
|
||||
|
||||
class ShardingTest(jtu.JaxTestCase):
|
||||
|
||||
|
@ -4062,6 +4062,7 @@ class TempSharding(Sharding):
|
||||
if xla_extension_version >= 235:
|
||||
super().__init__()
|
||||
self._devices = devices
|
||||
self._internal_device_list = xc.DeviceList(tuple(self._devices))
|
||||
|
||||
@property
|
||||
def device_set(self):
|
||||
@ -4073,6 +4074,10 @@ class TempSharding(Sharding):
|
||||
def shard_shape(self, global_shape):
|
||||
return global_shape
|
||||
|
||||
@property
|
||||
def memory_kind(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def is_fully_replicated(self):
|
||||
return True
|
||||
|
Loading…
x
Reference in New Issue
Block a user