mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 20:06:05 +00:00
Make JAX tests that check for errors from dict key comparators in pytrees more relaxed, in preparation for https://github.com/openxla/xla/pull/9529.
PiperOrigin-RevId: 610819296
This commit is contained in:
parent
e00149c39f
commit
fdbee314d3
@ -1329,7 +1329,9 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
def f(d) -> float:
|
||||
return d[E.A]
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "'<' not supported.*"):
|
||||
with self.assertRaisesRegex(
|
||||
(TypeError, ValueError),
|
||||
"('<' not supported|Comparator raised exception).*"):
|
||||
f({E.A: 1.0, E.B: 2.0})
|
||||
|
||||
def test_jit_static_argnums_requires_type_equality(self):
|
||||
|
@ -582,7 +582,9 @@ class TreeTest(jtu.JaxTestCase):
|
||||
|
||||
def testDictKeysSortable(self):
|
||||
d = {"a": 1, 2: "b"}
|
||||
with self.assertRaisesRegex(TypeError, "'<' not supported"):
|
||||
with self.assertRaisesRegex(
|
||||
(TypeError, ValueError),
|
||||
"('<' not supported|Comparator raised exception).*"):
|
||||
_, _ = tree_util.tree_flatten(d)
|
||||
|
||||
def testFlattenDictKeyOrder(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user