diff --git a/tests/api_test.py b/tests/api_test.py index 5aef53b27..23c365006 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1329,7 +1329,9 @@ class JitTest(jtu.BufferDonationTestCase): def f(d) -> float: return d[E.A] - with self.assertRaisesRegex(TypeError, "'<' not supported.*"): + with self.assertRaisesRegex( + (TypeError, ValueError), + "('<' not supported|Comparator raised exception).*"): f({E.A: 1.0, E.B: 2.0}) def test_jit_static_argnums_requires_type_equality(self): diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 360bfdeed..f6204ff3b 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -582,7 +582,9 @@ class TreeTest(jtu.JaxTestCase): def testDictKeysSortable(self): d = {"a": 1, 2: "b"} - with self.assertRaisesRegex(TypeError, "'<' not supported"): + with self.assertRaisesRegex( + (TypeError, ValueError), + "('<' not supported|Comparator raised exception).*"): _, _ = tree_util.tree_flatten(d) def testFlattenDictKeyOrder(self):