Clean up some device opt-in/opt-outs in test suite.

Use allowlists rather than denylists in a few places.

PiperOrigin-RevId: 568968749
Peter Hawkins, 2023-09-27 14:55:21 -07:00, committed by jax authors
parent 53845615ff
commit 6be860bda8
7 changed files with 31 additions and 39 deletions
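For readers skimming the diff: the change replaces denylist-style skips (`@jtu.skip_on_devices(...)` and `jax.default_backend()` checks) with allowlist-style opt-ins (`@jtu.run_on_devices(...)`, `jtu.test_device_matches(...)`, `@jtu.device_supports_buffer_donation()`). The sketch below is illustrative only and is not code from this commit; `ExampleTest` and its test names are hypothetical, and it assumes the absltest runner conventions used by the JAX test suite.

# Illustrative sketch only -- not part of this commit. Shows the allowlist-style
# helpers from jax._src.test_util that the diff below adopts; ExampleTest and
# its test names are hypothetical.
from absl.testing import absltest

import jax
import jax.numpy as jnp
from jax._src import test_util as jtu


class ExampleTest(jtu.JaxTestCase):

  @jtu.skip_on_devices("gpu", "tpu")  # denylist: enumerate backends to avoid
  def test_denylist_style(self):
    self.assertEqual(int(jnp.add(1, 1)), 2)

  @jtu.run_on_devices("cpu")  # allowlist: name the backends that work
  def test_allowlist_style(self):
    self.assertEqual(int(jnp.add(1, 1)), 2)

  @jtu.device_supports_buffer_donation()  # capability check, not a backend list
  def test_buffer_donation_style(self):
    f = jax.jit(lambda x: x + 1, donate_argnums=0)
    x = jax.device_put(jnp.arange(4.0))
    self.assertArraysEqual(f(x), jnp.arange(1.0, 5.0))

  def test_runtime_check_style(self):
    # Runtime query against the device actually under test.
    if not jtu.test_device_matches(["cpu"]):
      self.skipTest("example only runs on CPU")
    self.assertEqual(jax.devices()[0].platform, "cpu")


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())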


@@ -74,7 +74,7 @@ class DLPackTest(jtu.JaxTestCase):
   def testJaxRoundTrip(self, shape, dtype, take_ownership, gpu):
     rng = jtu.rand_default(self.rng())
     np = rng(shape, dtype)
-    if gpu and jax.default_backend() == "cpu":
+    if gpu and jtu.test_device_matches(["cpu"]):
       raise unittest.SkipTest("Skipping GPU test case on CPU")
     device = jax.devices("gpu" if gpu else "cpu")[0]
     x = jax.device_put(np, device)
@@ -180,7 +180,7 @@ class DLPackTest(jtu.JaxTestCase):
       dtype=numpy_dtypes,
   )
   @unittest.skipIf(numpy_version < (1, 23, 0), "Requires numpy 1.23 or newer")
-  @jtu.skip_on_devices("gpu") #NumPy only accepts cpu DLPacks
+  @jtu.run_on_devices("cpu") # NumPy only accepts cpu DLPacks
   def testJaxToNumpy(self, shape, dtype):
     rng = jtu.rand_default(self.rng())
     x_jax = jnp.array(rng(shape, dtype))
@@ -192,7 +192,7 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
   def setUp(self):
     super().setUp()
-    if not jtu.test_device_matches(["gpu"]):
+    if not jtu.test_device_matches(["cuda"]):
       self.skipTest("__cuda_array_interface__ is only supported on GPU")

   @jtu.sample_product(


@@ -98,10 +98,8 @@ class DebugPrintTest(jtu.JaxTestCase):
     jax.effects_barrier()
     self.assertEqual(output(), "x: 2\n")

+  @jtu.device_supports_buffer_donation()
   def test_can_stage_out_debug_print_with_donate_argnums(self):
-    if jax.default_backend() not in {"gpu", "tpu"}:
-      raise unittest.SkipTest("Donate argnums not supported.")
     def f(x, y):
       debug_print('x: {x}', x=x)
       return x + y
@@ -120,10 +118,8 @@ class DebugPrintTest(jtu.JaxTestCase):
     jax.effects_barrier()
     self.assertEqual(output(), "x: 2\n")

+  @jtu.device_supports_buffer_donation()
   def test_can_stage_out_ordered_print_with_donate_argnums(self):
-    if jax.default_backend() not in {"gpu", "tpu"}:
-      raise unittest.SkipTest("Donate argnums not supported.")
     def f(x, y):
       debug_print('x: {x}', x=x, ordered=True)
       return x + y
@@ -133,10 +129,8 @@ class DebugPrintTest(jtu.JaxTestCase):
     jax.effects_barrier()
     self.assertEqual(output(), "x: 2\n")

+  @jtu.device_supports_buffer_donation()
   def test_can_stage_out_prints_with_donate_argnums(self):
-    if jax.default_backend() not in {"gpu", "tpu"}:
-      raise unittest.SkipTest("Donate argnums not supported.")
     def f(x, y):
       debug_print('x: {x}', x=x, ordered=True)
       debug_print('x: {x}', x=x)


@@ -1275,7 +1275,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
   def testPoly(self, a_shape, dtype, rank):
     if dtype in (np.float16, jnp.bfloat16, np.int16):
       self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.")
-    elif rank == 2 and jtu.test_device_matches(["tpu", "gpu"]):
+    elif rank == 2 and not jtu.test_device_matches(["cpu"]):
       self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU backend.")
     rng = jtu.rand_default(self.rng())
     tol = { np.int8: 2e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 }
@@ -1914,7 +1914,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
       xshape=one_dim_array_shapes,
       yshape=one_dim_array_shapes,
   )
-  @jtu.skip_on_devices("gpu", "tpu", "rocm") # backends don't support all dtypes.
+  @jtu.skip_on_devices("cuda", "tpu", "rocm") # backends don't support all dtypes.
   def testConvolutionsPreferredElementType(self, xshape, yshape, dtype, mode, op):
     jnp_op = getattr(jnp, op)
     np_op = getattr(np, op)


@@ -210,7 +210,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   )
   # TODO(phawkins): enable when there is an eigendecomposition implementation
   # for GPU/TPU.
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.run_on_devices("cpu")
   def testEig(self, shape, dtype, compute_left_eigenvectors,
               compute_right_eigenvectors):
     rng = jtu.rand_default(self.rng())
@@ -252,7 +252,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   )
   # TODO(phawkins): enable when there is an eigendecomposition implementation
   # for GPU/TPU.
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.run_on_devices("cpu")
   def testEigvalsGrad(self, shape, dtype):
     # This test sometimes fails for large matrices. I (@j-towns) suspect, but
     # haven't checked, that might be because of perturbations causing the
@@ -271,7 +271,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   )
   # TODO: enable when there is an eigendecomposition implementation
   # for GPU/TPU.
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.run_on_devices("cpu")
   def testEigvals(self, shape, dtype):
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
@@ -280,7 +280,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     w2 = jnp.linalg.eigvals(a)
     self.assertAllClose(w1, w2, rtol={np.complex64: 1e-5, np.complex128: 1e-14})

-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.run_on_devices("cpu")
   def testEigvalsInf(self):
     # https://github.com/google/jax/issues/2661
     x = jnp.array([[jnp.inf]])
@@ -290,7 +290,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
       shape=[(1, 1), (4, 4), (5, 5)],
       dtype=float_types + complex_types,
   )
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.run_on_devices("cpu")
   def testEigBatching(self, shape, dtype):
     rng = jtu.rand_default(self.rng())
     shape = (10,) + shape
@@ -688,7 +688,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   )
   @jax.default_matmul_precision("float32")
   def testQr(self, shape, dtype, full_matrices):
-    if (jtu.test_device_matches(["gpu"]) and
+    if (jtu.test_device_matches(["cuda"]) and
         _is_required_cuda_version_satisfied(12000)):
       self.skipTest("Triggers a bug in cuda-12 b/287345077")
     rng = jtu.rand_default(self.rng())
@@ -1287,7 +1287,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
       dtype=int_types + float_types + complex_types
   )
   def testExpm(self, n, batch_size, dtype):
-    if (jtu.test_device_matches(["gpu"]) and
+    if (jtu.test_device_matches(["cuda"]) and
         _is_required_cuda_version_satisfied(12000)):
       self.skipTest("Triggers a bug in cuda-12 b/287345077")
@@ -1357,7 +1357,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
       dtype=float_types + complex_types,
       calc_q=[False, True],
   )
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.run_on_devices("cpu")
   def testHessenberg(self, shape, dtype, calc_q):
     rng = jtu.rand_default(self.rng())
     jsp_func = partial(jax.scipy.linalg.hessenberg, calc_q=calc_q)
@@ -1514,7 +1514,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
       shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
       dtype=float_types + complex_types,
   )
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.run_on_devices("cpu")
   def testSchur(self, shape, dtype):
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
@@ -1526,7 +1526,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
       shape=[(1, 1), (4, 4), (15, 15), (50, 50), (100, 100)],
       dtype=float_types + complex_types,
   )
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.run_on_devices("cpu")
   def testRsf2csf(self, shape, dtype):
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)]
@@ -1542,7 +1542,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
   )
   # funm uses jax.scipy.linalg.schur which is implemented for a CPU
   # backend only, so tests on GPU and TPU backends are skipped here
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.run_on_devices("cpu")
   def testFunm(self, shape, dtype, disp):
     def func(x):
       return x**-2.718
@@ -1558,7 +1558,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
       shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
       dtype=float_types + complex_types,
   )
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.run_on_devices("cpu")
   def testSqrtmPSDMatrix(self, shape, dtype):
     # Checks against scipy.linalg.sqrtm when the principal square root
     # is guaranteed to be unique (i.e no negative real eigenvalue)
@@ -1581,7 +1581,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
       shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
       dtype=float_types + complex_types,
   )
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.run_on_devices("cpu")
   def testSqrtmGenMatrix(self, shape, dtype):
     rng = jtu.rand_default(self.rng())
     arg = rng(shape, dtype)
@@ -1600,7 +1600,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
       ],
       dtype=float_types + complex_types,
   )
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.run_on_devices("cpu")
   def testSqrtmEdgeCase(self, diag, expected, dtype):
     """
     Tests the zero numerator condition
@@ -1773,7 +1773,7 @@ class LaxLinalgTest(jtu.JaxTestCase):
       shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
       dtype=float_types + complex_types,
   )
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.run_on_devices("cpu")
   def testSchur(self, shape, dtype):
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
@@ -1785,7 +1785,7 @@ class LaxLinalgTest(jtu.JaxTestCase):
       shape=[(2, 2), (4, 4), (15, 15), (50, 50), (100, 100)],
       dtype=float_types + complex_types,
   )
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.run_on_devices("cpu")
   def testSchurBatching(self, shape, dtype):
     rng = jtu.rand_default(self.rng())
     batch_size = 10


@@ -72,7 +72,7 @@ class TestPolynomial(jtu.JaxTestCase):
       trailing=[0, 2],
   )
   # TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.run_on_devices("cpu")
   def testRoots(self, dtype, length, leading, trailing):
     rng = jtu.rand_some_zero(self.rng())
@@ -98,7 +98,7 @@ class TestPolynomial(jtu.JaxTestCase):
       trailing=[0, 2],
   )
   # TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.run_on_devices("cpu")
   def testRootsNoStrip(self, dtype, length, leading, trailing):
     rng = jtu.rand_some_zero(self.rng())


@@ -1958,7 +1958,7 @@ class BCOOTest(sptu.SparseTestCase):
   @jax.default_matmul_precision("float32")
   @jtu.ignore_warning(category=sparse.CuSparseEfficiencyWarning)
   def test_bcoo_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
-    if (jtu.test_device_matches(["gpu"]) and
+    if (jtu.test_device_matches(["cuda"]) and
         _is_required_cuda_version_satisfied(12000)):
       raise unittest.SkipTest("Triggers a bug in cuda-12 b/287344632")
@@ -2777,8 +2777,7 @@ class SparseSolverTest(sptu.SparseTestCase):
       reorder=[0, 1, 2, 3],
       dtype=jtu.dtypes.floating + jtu.dtypes.complex,
   )
-  @jtu.run_on_devices("cpu", "gpu")
-  @jtu.skip_on_devices("rocm") # test n gpu requires cusolver
+  @jtu.run_on_devices("cpu", "cuda")
   def test_sparse_qr_linear_solver(self, size, reorder, dtype):
     if jtu.test_device_matches(["cuda"]) and not GPU_LOWERING_ENABLED:
       raise unittest.SkipTest('test requires cusparse/cusolver')
@@ -2805,8 +2804,7 @@ class SparseSolverTest(sptu.SparseTestCase):
       size=[10, 20, 50],
       dtype=jtu.dtypes.floating,
   )
-  @jtu.run_on_devices("cpu", "gpu")
-  @jtu.skip_on_devices("rocm") # test requires cusolver
+  @jtu.run_on_devices("cpu", "cuda")
   def test_sparse_qr_linear_solver_grads(self, size, dtype):
     if jtu.test_device_matches(["cuda"]) and not GPU_LOWERING_ENABLED:
       raise unittest.SkipTest('test requires cusparse/cusolver')


@@ -459,7 +459,7 @@ class XMapTest(XMapTestCase):
     self.assertAllClose(f_mapped(x, x), expected)

   @jtu.with_and_without_mesh
-  @jtu.run_on_devices("gpu", "tpu") # In/out aliasing not supported on CPU.
+  @jtu.device_supports_buffer_donation() # In/out aliasing not supported on CPU
   def testBufferDonation(self, mesh, axis_resources):
     shard = lambda x: x
     if axis_resources:
@@ -476,7 +476,7 @@ class XMapTest(XMapTestCase):
     self.assertNotDeleted(y)
     self.assertDeleted(x)

-  @jtu.run_on_devices("gpu", "tpu") # In/out aliasing not supported on CPU.
+  @jtu.device_supports_buffer_donation() # In/out aliasing not supported on CPU
   @jtu.with_mesh([('x', 2)])
   @jtu.ignore_warning(category=UserWarning, # SPMD test generates warning.
                       message="Some donated buffers were not usable*")