mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
update pjit test
This commit is contained in:
parent
cea2b6b6f8
commit
358775f901
@ -3405,12 +3405,9 @@ class PJitErrorTest(jtu.JaxTestCase):
|
||||
# pjit(lambda x: x, p, p)([x, x, x]) # Error, but make sure we hint at singleton tuple
|
||||
|
||||
error = re.escape(
|
||||
"pytree structure error: different numbers of pytree children at "
|
||||
"pytree structure error: different lengths of list at "
|
||||
"key path\n"
|
||||
" pjit out_axis_resources tree root\n"
|
||||
"At that key path, the prefix pytree pjit out_axis_resources has a "
|
||||
"subtree of type\n"
|
||||
" <class 'list'>\n")
|
||||
" pjit out_axis_resources tree root\n")
|
||||
with self.assertRaisesRegex(ValueError, error):
|
||||
pjit(lambda x: x, (p,), [p, None])([x, x, x]) # Error, we raise a generic tree mismatch message
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user