mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Avoid IndexError when constructing a ValueError for a DeviceAssignmentMismatchError.
_get_arg_names was throwing IndexError when handling functions with variadic args. PiperOrigin-RevId: 537308439
This commit is contained in:
parent
2e12add64a
commit
cfabad5886
@ -120,7 +120,10 @@ def _get_arg_names(fun, in_tree, args_flat):
|
||||
ak, *rem_keys = arg_key
|
||||
if sig is not None:
|
||||
loc = ''.join(str(k) for k in rem_keys)
|
||||
arg_name = f'{list(sig.arguments.keys())[ak.idx]}{loc}'
|
||||
try:
|
||||
arg_name = f'{list(sig.arguments.keys())[ak.idx]}{loc}'
|
||||
except IndexError:
|
||||
arg_name = '' # E.g. variadic positional argument.
|
||||
else:
|
||||
arg_name = ''
|
||||
arg_names.append(arg_name)
|
||||
|
@ -2035,6 +2035,18 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
r"argument y of.*\<lambda\> with shape int.*\[3\] and device ids \[1\].*"):
|
||||
pjit(lambda x, y: (x, y))(a, b)
|
||||
|
||||
def test_pjit_committed_array_different_devices_variadic_args(self):
|
||||
if jax.device_count() < 2:
|
||||
self.skipTest('Test requires >= 2 devices')
|
||||
a = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])
|
||||
b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Received incompatible devices for pjitted computation. Got argument "
|
||||
r"x of.*\<lambda\> with shape int.*\[3\] and device ids \[0\].*and "
|
||||
r"argument of.*\<lambda\> with shape int.*\[3\] and device ids \[1\].*"):
|
||||
pjit(lambda *x: x)(a, b)
|
||||
|
||||
def test_pjit_pytree_inp_device_assignment_mismatch(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
a = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])
|
||||
|
Loading…
x
Reference in New Issue
Block a user