Improve the error message for device mismatch. Print the platform and the device ids rather than the entire device which is not readable.

PiperOrigin-RevId: 480685550
This commit is contained in:
Yash Katariya 2022-10-12 12:11:47 -07:00 committed by jax authors
parent 19e217bdf4
commit 4ea9d2b8df
2 changed files with 23 additions and 9 deletions

View File

@ -2700,15 +2700,25 @@ def _get_and_check_device_assignment(
arr_device_assignment = list(i._device_assignment) # type: ignore
if not devices:
if first_device_assignment != arr_device_assignment:
raise ValueError("Devices of all `Array` inputs and outputs should be "
"the same. "
f"Got array devices: {first_device_assignment},\n "
f"another array devices: {arr_device_assignment}")
p1 = first_device_assignment[0].platform.upper()
fda_ids = [d.id for d in first_device_assignment]
a_ids = [d.id for d in arr_device_assignment]
p2 = arr_device_assignment[0].platform.upper()
raise ValueError(
"Devices of all `Array` inputs and outputs should be "
"the same. "
f"Got array device ids {fda_ids} on platform {p1} and "
f"another array's device ids {a_ids} on platform {p2}")
else:
if devices != arr_device_assignment:
raise ValueError("Pjit's devices and Array's devices should be equal. "
f"Got Pjit devices: {devices},\n "
f"Array devices: {arr_device_assignment}")
p1 = devices[0].platform.upper()
dev_ids = [d.id for d in devices]
a_ids = [d.id for d in arr_device_assignment]
p2 = arr_device_assignment[0].platform.upper()
raise ValueError(
"Pjit's devices and Array's devices should be equal. "
f"Got Pjit's device ids {dev_ids} on platform {p1} and "
f"Array's device ids {a_ids} on platform {p2}")
if first_device_assignment is None and devices:
final_device_assignment = devices
elif first_device_assignment is None:

View File

@ -2278,7 +2278,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1])
with self.assertRaisesRegex(
ValueError,
"Devices of all `Array` inputs and outputs should be the same"):
"Devices of all `Array` inputs and outputs should be the same. "
r"Got array device ids \[0\] on platform.*and "
r"another array's device ids \[1\] on platform"):
pjit(lambda x, y: (x, y))(a, b)
@jax_array(True)
@ -2409,7 +2411,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
"Pjit's devices and Array's devices should be equal."):
"Pjit's devices and Array's devices should be equal. "
r"Got Pjit's device ids \[0\] on platform.*and "
r"Array's device ids \[0, 1, 2, 3\] on platform"):
sharded_zeros((4096, 3072), P('x', 'y'))
@jax_array(True)