Merge pull request #10294 from jakevdp:fix-gpu-test

PiperOrigin-RevId: 441855902
This commit is contained in:
jax authors 2022-04-14 14:31:07 -07:00
commit b290b6eaa3

View File

@ -442,6 +442,7 @@ class cuSparseTest(jtu.JaxTestCase):
self.assertArraysEqual(mat.todense(), mat_resorted.todense())
@unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
@unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
@jtu.skip_on_devices("rocm") # TODO(rocm): see SWDEV-328107
def test_coo_sorted_indices_gpu_lowerings(self):
dtype = jnp.float32