Avoid IndexError when constructing a ValueError for a DeviceAssignmentMismatchError.

_get_arg_names was throwing IndexError when handling functions with variadic args.

PiperOrigin-RevId: 537308439
This commit is contained in:
André Susano Pinto 2023-06-02 07:43:21 -07:00 committed by jax authors
parent 2e12add64a
commit cfabad5886
2 changed files with 16 additions and 1 deletions

View File

@ -120,7 +120,10 @@ def _get_arg_names(fun, in_tree, args_flat):
ak, *rem_keys = arg_key
if sig is not None:
loc = ''.join(str(k) for k in rem_keys)
arg_name = f'{list(sig.arguments.keys())[ak.idx]}{loc}'
try:
arg_name = f'{list(sig.arguments.keys())[ak.idx]}{loc}'
except IndexError:
arg_name = '' # E.g. variadic positional argument.
else:
arg_name = ''
arg_names.append(arg_name)

View File

@ -2035,6 +2035,18 @@ class ArrayPjitTest(jtu.JaxTestCase):
r"argument y of.*\<lambda\> with shape int.*\[3\] and device ids \[1\].*"):
pjit(lambda x, y: (x, y))(a, b)
def test_pjit_committed_array_different_devices_variadic_args(self):
if jax.device_count() < 2:
self.skipTest('Test requires >= 2 devices')
a = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])
b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1])
with self.assertRaisesRegex(
ValueError,
"Received incompatible devices for pjitted computation. Got argument "
r"x of.*\<lambda\> with shape int.*\[3\] and device ids \[0\].*and "
r"argument of.*\<lambda\> with shape int.*\[3\] and device ids \[1\].*"):
pjit(lambda *x: x)(a, b)
def test_pjit_pytree_inp_device_assignment_mismatch(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
a = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])