diff --git a/tests/lax_scipy_sparse_test.py b/tests/lax_scipy_sparse_test.py index 0216be48f..9114592c7 100644 --- a/tests/lax_scipy_sparse_test.py +++ b/tests/lax_scipy_sparse_test.py @@ -253,6 +253,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): for dtype in float_types + complex_types for preconditioner in [None, 'identity', 'exact'] )) + @jtu.skip_on_devices("gpu") def test_bicgstab_on_identity_system(self, shape, dtype, preconditioner): A = jnp.eye(shape[1], dtype=dtype) solution = jnp.ones(shape[1], dtype=dtype) @@ -277,6 +278,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): for dtype in float_types + complex_types for preconditioner in [None, 'identity', 'exact'] )) + @jtu.skip_on_devices("gpu") def test_bicgstab_on_random_system(self, shape, dtype, preconditioner): rng = jtu.rand_default(self.rng()) A = rng(shape, dtype) @@ -365,6 +367,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): for preconditioner in [None, 'identity', 'exact'] for solve_method in ['batched', 'incremental'] )) + @jtu.skip_on_devices("gpu") def test_gmres_on_identity_system(self, shape, dtype, preconditioner, solve_method): A = jnp.eye(shape[1], dtype=dtype) @@ -395,6 +398,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): for preconditioner in [None, 'identity', 'exact'] for solve_method in ['incremental', 'batched'] )) + @jtu.skip_on_devices("gpu") def test_gmres_on_random_system(self, shape, dtype, preconditioner, solve_method): rng = jtu.rand_default(self.rng())