Add a missing jaxlib version check in ragged_collective_test

PiperOrigin-RevId: 725186144
This commit is contained in:
Adam Paszke 2025-02-10 06:11:55 -08:00 committed by jax authors
parent 91c9bbfa98
commit 26d8e112e3

View File

@ -37,6 +37,8 @@ class RaggedCollectiveTest(jtu.JaxTestCase):
super().setUp()
if jtu.test_device_matches(['cpu']):
self.skipTest('ragged-all-to-all is not supported on CPU')
if jtu.jaxlib_version() < (0, 5, 1):
self.skipTest('ragged-all-to-all is not supported on jaxlib version < 0.5.1')
@parameterized.named_parameters(
dict(