mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Return arrays from ArrayImpl._check_and_rearrange
.
This is in preparation for a larger change, so that input buffers can be checked before Array creation in XLA and the user gets more helpful JAX error messages instead of XLA errors. PiperOrigin-RevId: 721412760
This commit is contained in:
parent
d8f3b33ae4
commit
bb951136e9
@ -39,6 +39,7 @@ from jax._src.interpreters import xla
|
||||
from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension as xe
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import (
|
||||
PmapSharding, SingleDeviceSharding,
|
||||
@ -55,7 +56,10 @@ PRNGKeyArray = Any # TODO(jakevdp): fix cycles and import this.
|
||||
|
||||
def _get_device(a: ArrayImpl) -> Device:
|
||||
devices = a.sharding._internal_device_list # pytype: disable=attribute-error
|
||||
assert len(devices) == 1
|
||||
if len(devices) != 1:
|
||||
raise ValueError(
|
||||
"When making an array from single-device arrays the input arrays must "
|
||||
f"have one shard each. An argument array had {len(devices)} shard(s).")
|
||||
return devices[0]
|
||||
|
||||
|
||||
@ -195,54 +199,102 @@ class ArrayImpl(basearray.Array):
|
||||
|
||||
self.aval = aval
|
||||
self._sharding = sharding
|
||||
self._arrays = [a._arrays[0] for a in arrays]
|
||||
self._committed = committed
|
||||
self._npy_value = None
|
||||
arrays = [a._arrays[0] for a in arrays]
|
||||
|
||||
# Don't rearrange if skip_checks is enabled because this assumes that the
|
||||
# input buffers are already arranged properly. This usually happens when
|
||||
# Array's are created as output of a JAX transformation
|
||||
# (like pjit, etc).
|
||||
if not _skip_checks or config.enable_checks.value:
|
||||
self._check_and_rearrange()
|
||||
arrays = self._check_and_rearrange(arrays, self._sharding, self.aval)
|
||||
self._arrays = arrays # type: ignore
|
||||
|
||||
def _check_and_rearrange(self):
|
||||
device_id_to_buffer = {_get_device(db).id: db for db in self._arrays}
|
||||
if xla_extension_version >= 308:
|
||||
def _check_and_rearrange(self, arrays, sharding, aval):
|
||||
device_id_to_buffer = {_get_device(db).id: db for db in arrays}
|
||||
|
||||
addressable_dev = self.sharding.addressable_devices
|
||||
if len(self._arrays) != len(addressable_dev):
|
||||
raise ValueError(
|
||||
f"Expected {len(addressable_dev)} per-device arrays "
|
||||
"(this is how many devices are addressable by the sharding), but "
|
||||
f"got {len(self._arrays)}")
|
||||
addressable_dev = sharding.addressable_devices
|
||||
if len(arrays) != len(addressable_dev):
|
||||
raise ValueError(
|
||||
f"Expected {len(addressable_dev)} per-device arrays "
|
||||
"(this is how many devices are addressable by the sharding), but "
|
||||
f"got {len(arrays)}")
|
||||
|
||||
array_device_ids = set(device_id_to_buffer.keys())
|
||||
addressable_device_ids = {d.id for d in addressable_dev}
|
||||
# Calculate a symmetric difference because the device ids between sharding
|
||||
# and _arrays should match.
|
||||
diff = array_device_ids ^ addressable_device_ids
|
||||
if diff:
|
||||
dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
|
||||
dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
|
||||
err_msg = (
|
||||
"Addressable devices and per-device arrays devices do not match.")
|
||||
if dev_in_sharding_not_in_arrays:
|
||||
err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
|
||||
"that are not present in per-device arrays.")
|
||||
if dev_in_arrays_not_in_sharding:
|
||||
err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
|
||||
"that are not present in the sharding.")
|
||||
raise ValueError(err_msg)
|
||||
array_device_ids = set(device_id_to_buffer.keys())
|
||||
addressable_device_ids = {d.id for d in addressable_dev}
|
||||
if len(array_device_ids) != len(arrays):
|
||||
buffer_device_ids = [_get_device(db).id for db in arrays]
|
||||
raise ValueError(
|
||||
"When making an array from single-device arrays, the input arrays"
|
||||
" must be from distinct devices, but got device IDs"
|
||||
f" {buffer_device_ids}")
|
||||
|
||||
_validate_shape_and_dtype_for_per_device_arrays(
|
||||
self._arrays,
|
||||
sharding=self.sharding,
|
||||
aval=self.aval,
|
||||
expected_shape=self.sharding.shard_shape(self.shape),
|
||||
)
|
||||
# Rearrange arrays based on the device assignment.
|
||||
addressable_da = self.sharding._addressable_device_assignment
|
||||
self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]
|
||||
# Calculate a symmetric difference because the device ids between sharding
|
||||
# and _arrays should match.
|
||||
diff = array_device_ids ^ addressable_device_ids
|
||||
if diff:
|
||||
dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
|
||||
dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
|
||||
err_msg = (
|
||||
"Addressable devices and per-device arrays devices do not match.")
|
||||
if dev_in_sharding_not_in_arrays:
|
||||
err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
|
||||
"that are not present in per-device arrays.")
|
||||
if dev_in_arrays_not_in_sharding:
|
||||
err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
|
||||
"that are not present in the sharding.")
|
||||
raise ValueError(err_msg)
|
||||
|
||||
_validate_shape_and_dtype_for_per_device_arrays(
|
||||
arrays,
|
||||
sharding=sharding,
|
||||
aval=aval,
|
||||
expected_shape=sharding.shard_shape(aval.shape),
|
||||
)
|
||||
|
||||
# Rearrange arrays based on the device assignment.
|
||||
addressable_da = sharding._addressable_device_assignment
|
||||
return [device_id_to_buffer[device.id] for device in addressable_da]
|
||||
else:
|
||||
def _check_and_rearrange(self): # type: ignore
|
||||
device_id_to_buffer = {_get_device(db).id: db for db in self._arrays}
|
||||
|
||||
addressable_dev = self.sharding.addressable_devices
|
||||
if len(self._arrays) != len(addressable_dev):
|
||||
raise ValueError(
|
||||
f"Expected {len(addressable_dev)} per-device arrays "
|
||||
"(this is how many devices are addressable by the sharding), but "
|
||||
f"got {len(self._arrays)}")
|
||||
|
||||
array_device_ids = set(device_id_to_buffer.keys())
|
||||
addressable_device_ids = {d.id for d in addressable_dev}
|
||||
# Calculate a symmetric difference because the device ids between sharding
|
||||
# and _arrays should match.
|
||||
diff = array_device_ids ^ addressable_device_ids
|
||||
if diff:
|
||||
dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
|
||||
dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
|
||||
err_msg = (
|
||||
"Addressable devices and per-device arrays devices do not match.")
|
||||
if dev_in_sharding_not_in_arrays:
|
||||
err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
|
||||
"that are not present in per-device arrays.")
|
||||
if dev_in_arrays_not_in_sharding:
|
||||
err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
|
||||
"that are not present in the sharding.")
|
||||
raise ValueError(err_msg)
|
||||
|
||||
_validate_shape_and_dtype_for_per_device_arrays(
|
||||
self._arrays,
|
||||
sharding=self.sharding,
|
||||
aval=self.aval,
|
||||
expected_shape=self.sharding.shard_shape(self.shape),
|
||||
)
|
||||
# Rearrange arrays based on the device assignment.
|
||||
addressable_da = self.sharding._addressable_device_assignment
|
||||
self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]
|
||||
|
||||
@property
|
||||
def shape(self) -> Shape:
|
||||
|
@ -374,7 +374,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
# Sharding device ids = {0, 1}
|
||||
s = jax.sharding.NamedSharding(mesh, P('x'))
|
||||
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||
# _arrays device ids = {0, 2}
|
||||
# _arrays device ids = {0, 0}
|
||||
bufs = [jax.device_put(inp_data, jax.devices()[0]) for _ in range(2)]
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
|
Loading…
x
Reference in New Issue
Block a user