mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Skip gather and reduce scatter grad tests on GPU
Recent changes in XLA:GPU seem to be causing deadlocks. PiperOrigin-RevId: 499832080
This commit is contained in:
parent
904cd4b98d
commit
6655f2ba8d
@ -974,6 +974,8 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
))
|
||||
@ignore_slow_all_to_all_warning()
|
||||
def testGradOf(self, prim, tiled, use_axis_index_groups):
|
||||
if jtu.device_under_test() == "gpu":
|
||||
raise SkipTest("XLA:GPU with ReduceScatter deadlocks") # b/264516146
|
||||
axis_index_groups = None
|
||||
devices = jax.devices()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user