Make JAX tests that check for errors from dict key comparators in pytrees more relaxed, in preparation for https://github.com/openxla/xla/pull/9529.

PiperOrigin-RevId: 610819296
This commit is contained in:
Peter Hawkins 2024-02-27 11:29:31 -08:00 committed by jax authors
parent e00149c39f
commit fdbee314d3
2 changed files with 6 additions and 2 deletions

View File

@ -1329,7 +1329,9 @@ class JitTest(jtu.BufferDonationTestCase):
def f(d) -> float:
return d[E.A]
with self.assertRaisesRegex(TypeError, "'<' not supported.*"):
with self.assertRaisesRegex(
(TypeError, ValueError),
"('<' not supported|Comparator raised exception).*"):
f({E.A: 1.0, E.B: 2.0})
def test_jit_static_argnums_requires_type_equality(self):

View File

@ -582,7 +582,9 @@ class TreeTest(jtu.JaxTestCase):
def testDictKeysSortable(self):
d = {"a": 1, 2: "b"}
with self.assertRaisesRegex(TypeError, "'<' not supported"):
with self.assertRaisesRegex(
(TypeError, ValueError),
"('<' not supported|Comparator raised exception).*"):
_, _ = tree_util.tree_flatten(d)
def testFlattenDictKeyOrder(self):