update pjit test

This commit is contained in:
Matthew Johnson 2023-01-20 11:40:22 -08:00
parent cea2b6b6f8
commit 358775f901

View File

@ -3405,12 +3405,9 @@ class PJitErrorTest(jtu.JaxTestCase):
# pjit(lambda x: x, p, p)([x, x, x]) # Error, but make sure we hint at singleton tuple
error = re.escape(
"pytree structure error: different numbers of pytree children at "
"pytree structure error: different lengths of list at "
"key path\n"
" pjit out_axis_resources tree root\n"
"At that key path, the prefix pytree pjit out_axis_resources has a "
"subtree of type\n"
" <class 'list'>\n")
" pjit out_axis_resources tree root\n")
with self.assertRaisesRegex(ValueError, error):
pjit(lambda x: x, (p,), [p, None])([x, x, x]) # Error, we raise a generic tree mismatch message