Return arrays from ArrayImpl._check_and_rearrange.

This is in preparation for a larger change, so that input buffers can be checked before Array creation in XLA and the user gets more helpful JAX error messages instead of XLA errors.

PiperOrigin-RevId: 721412760
This commit is contained in:
Emily Fertig 2025-01-30 09:10:12 -08:00 committed by jax authors
parent d8f3b33ae4
commit bb951136e9
2 changed files with 90 additions and 38 deletions

View File

@ -39,6 +39,7 @@ from jax._src.interpreters import xla
from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension as xe
from jax._src.lib import xla_extension_version
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding,
@ -55,7 +56,10 @@ PRNGKeyArray = Any # TODO(jakevdp): fix cycles and import this.
def _get_device(a: ArrayImpl) -> Device:
devices = a.sharding._internal_device_list # pytype: disable=attribute-error
assert len(devices) == 1
if len(devices) != 1:
raise ValueError(
"When making an array from single-device arrays the input arrays must "
f"have one shard each. An argument array had {len(devices)} shard(s).")
return devices[0]
@ -195,54 +199,102 @@ class ArrayImpl(basearray.Array):
self.aval = aval
self._sharding = sharding
self._arrays = [a._arrays[0] for a in arrays]
self._committed = committed
self._npy_value = None
arrays = [a._arrays[0] for a in arrays]
# Don't rearrange if skip_checks is enabled because this assumes that the
# input buffers are already arranged properly. This usually happens when
# Array's are created as output of a JAX transformation
# (like pjit, etc).
if not _skip_checks or config.enable_checks.value:
self._check_and_rearrange()
arrays = self._check_and_rearrange(arrays, self._sharding, self.aval)
self._arrays = arrays # type: ignore
def _check_and_rearrange(self):
device_id_to_buffer = {_get_device(db).id: db for db in self._arrays}
if xla_extension_version >= 308:
def _check_and_rearrange(self, arrays, sharding, aval):
device_id_to_buffer = {_get_device(db).id: db for db in arrays}
addressable_dev = self.sharding.addressable_devices
if len(self._arrays) != len(addressable_dev):
raise ValueError(
f"Expected {len(addressable_dev)} per-device arrays "
"(this is how many devices are addressable by the sharding), but "
f"got {len(self._arrays)}")
addressable_dev = sharding.addressable_devices
if len(arrays) != len(addressable_dev):
raise ValueError(
f"Expected {len(addressable_dev)} per-device arrays "
"(this is how many devices are addressable by the sharding), but "
f"got {len(arrays)}")
array_device_ids = set(device_id_to_buffer.keys())
addressable_device_ids = {d.id for d in addressable_dev}
# Calculate a symmetric difference because the device ids between sharding
# and _arrays should match.
diff = array_device_ids ^ addressable_device_ids
if diff:
dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
err_msg = (
"Addressable devices and per-device arrays devices do not match.")
if dev_in_sharding_not_in_arrays:
err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
"that are not present in per-device arrays.")
if dev_in_arrays_not_in_sharding:
err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
"that are not present in the sharding.")
raise ValueError(err_msg)
array_device_ids = set(device_id_to_buffer.keys())
addressable_device_ids = {d.id for d in addressable_dev}
if len(array_device_ids) != len(arrays):
buffer_device_ids = [_get_device(db).id for db in arrays]
raise ValueError(
"When making an array from single-device arrays, the input arrays"
" must be from distinct devices, but got device IDs"
f" {buffer_device_ids}")
_validate_shape_and_dtype_for_per_device_arrays(
self._arrays,
sharding=self.sharding,
aval=self.aval,
expected_shape=self.sharding.shard_shape(self.shape),
)
# Rearrange arrays based on the device assignment.
addressable_da = self.sharding._addressable_device_assignment
self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]
# Calculate a symmetric difference because the device ids between sharding
# and _arrays should match.
diff = array_device_ids ^ addressable_device_ids
if diff:
dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
err_msg = (
"Addressable devices and per-device arrays devices do not match.")
if dev_in_sharding_not_in_arrays:
err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
"that are not present in per-device arrays.")
if dev_in_arrays_not_in_sharding:
err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
"that are not present in the sharding.")
raise ValueError(err_msg)
_validate_shape_and_dtype_for_per_device_arrays(
arrays,
sharding=sharding,
aval=aval,
expected_shape=sharding.shard_shape(aval.shape),
)
# Rearrange arrays based on the device assignment.
addressable_da = sharding._addressable_device_assignment
return [device_id_to_buffer[device.id] for device in addressable_da]
else:
def _check_and_rearrange(self): # type: ignore
device_id_to_buffer = {_get_device(db).id: db for db in self._arrays}
addressable_dev = self.sharding.addressable_devices
if len(self._arrays) != len(addressable_dev):
raise ValueError(
f"Expected {len(addressable_dev)} per-device arrays "
"(this is how many devices are addressable by the sharding), but "
f"got {len(self._arrays)}")
array_device_ids = set(device_id_to_buffer.keys())
addressable_device_ids = {d.id for d in addressable_dev}
# Calculate a symmetric difference because the device ids between sharding
# and _arrays should match.
diff = array_device_ids ^ addressable_device_ids
if diff:
dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
err_msg = (
"Addressable devices and per-device arrays devices do not match.")
if dev_in_sharding_not_in_arrays:
err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
"that are not present in per-device arrays.")
if dev_in_arrays_not_in_sharding:
err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
"that are not present in the sharding.")
raise ValueError(err_msg)
_validate_shape_and_dtype_for_per_device_arrays(
self._arrays,
sharding=self.sharding,
aval=self.aval,
expected_shape=self.sharding.shard_shape(self.shape),
)
# Rearrange arrays based on the device assignment.
addressable_da = self.sharding._addressable_device_assignment
self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]
@property
def shape(self) -> Shape:

View File

@ -374,7 +374,7 @@ class JaxArrayTest(jtu.JaxTestCase):
# Sharding device ids = {0, 1}
s = jax.sharding.NamedSharding(mesh, P('x'))
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
# _arrays device ids = {0, 2}
# _arrays device ids = {0, 0}
bufs = [jax.device_put(inp_data, jax.devices()[0]) for _ in range(2)]
with self.assertRaisesRegex(
ValueError,