Skip gather and reduce scatter grad tests on GPU

Recent changes in XLA:GPU seem to be causing deadlocks.

PiperOrigin-RevId: 499832080
This commit is contained in:
Adam Paszke 2023-01-05 05:19:32 -08:00 committed by jax authors
parent 904cd4b98d
commit 6655f2ba8d

View File

@ -974,6 +974,8 @@ class PythonPmapTest(jtu.JaxTestCase):
))
@ignore_slow_all_to_all_warning()
def testGradOf(self, prim, tiled, use_axis_index_groups):
if jtu.device_under_test() == "gpu":
raise SkipTest("XLA:GPU with ReduceScatter deadlocks") # b/264516146
axis_index_groups = None
devices = jax.devices()