mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improve the error message for device mismatch. Print the platform and the device ids rather than the entire device which is not readable.
PiperOrigin-RevId: 480685550
This commit is contained in:
parent
19e217bdf4
commit
4ea9d2b8df
@ -2700,15 +2700,25 @@ def _get_and_check_device_assignment(
|
||||
arr_device_assignment = list(i._device_assignment) # type: ignore
|
||||
if not devices:
|
||||
if first_device_assignment != arr_device_assignment:
|
||||
raise ValueError("Devices of all `Array` inputs and outputs should be "
|
||||
"the same. "
|
||||
f"Got array devices: {first_device_assignment},\n "
|
||||
f"another array devices: {arr_device_assignment}")
|
||||
p1 = first_device_assignment[0].platform.upper()
|
||||
fda_ids = [d.id for d in first_device_assignment]
|
||||
a_ids = [d.id for d in arr_device_assignment]
|
||||
p2 = arr_device_assignment[0].platform.upper()
|
||||
raise ValueError(
|
||||
"Devices of all `Array` inputs and outputs should be "
|
||||
"the same. "
|
||||
f"Got array device ids {fda_ids} on platform {p1} and "
|
||||
f"another array's device ids {a_ids} on platform {p2}")
|
||||
else:
|
||||
if devices != arr_device_assignment:
|
||||
raise ValueError("Pjit's devices and Array's devices should be equal. "
|
||||
f"Got Pjit devices: {devices},\n "
|
||||
f"Array devices: {arr_device_assignment}")
|
||||
p1 = devices[0].platform.upper()
|
||||
dev_ids = [d.id for d in devices]
|
||||
a_ids = [d.id for d in arr_device_assignment]
|
||||
p2 = arr_device_assignment[0].platform.upper()
|
||||
raise ValueError(
|
||||
"Pjit's devices and Array's devices should be equal. "
|
||||
f"Got Pjit's device ids {dev_ids} on platform {p1} and "
|
||||
f"Array's device ids {a_ids} on platform {p2}")
|
||||
if first_device_assignment is None and devices:
|
||||
final_device_assignment = devices
|
||||
elif first_device_assignment is None:
|
||||
|
@ -2278,7 +2278,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Devices of all `Array` inputs and outputs should be the same"):
|
||||
"Devices of all `Array` inputs and outputs should be the same. "
|
||||
r"Got array device ids \[0\] on platform.*and "
|
||||
r"another array's device ids \[1\] on platform"):
|
||||
pjit(lambda x, y: (x, y))(a, b)
|
||||
|
||||
@jax_array(True)
|
||||
@ -2409,7 +2411,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Pjit's devices and Array's devices should be equal."):
|
||||
"Pjit's devices and Array's devices should be equal. "
|
||||
r"Got Pjit's device ids \[0\] on platform.*and "
|
||||
r"Array's device ids \[0, 1, 2, 3\] on platform"):
|
||||
sharded_zeros((4096, 3072), P('x', 'y'))
|
||||
|
||||
@jax_array(True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user