mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Cleanup: make skip_if_unsupported_type more robust (#3912)
This commit is contained in:
parent
9d3d09198c
commit
0cbb4279ee
@ -366,9 +366,10 @@ def supported_dtypes():
|
||||
return types
|
||||
|
||||
def skip_if_unsupported_type(dtype):
|
||||
if dtype not in supported_dtypes():
|
||||
dtype = np.dtype(dtype)
|
||||
if dtype.type not in supported_dtypes():
|
||||
raise unittest.SkipTest(
|
||||
f"Type {dtype} not supported on {device_under_test()}")
|
||||
f"Type {dtype.name} not supported on {device_under_test()}")
|
||||
|
||||
def skip_on_devices(*disabled_devices):
|
||||
"""A decorator for test methods to skip the test on certain devices."""
|
||||
|
@ -42,12 +42,6 @@ T = lambda x: np.swapaxes(x, -1, -2)
|
||||
float_types = [np.float32, np.float64]
|
||||
complex_types = [np.complex64, np.complex128]
|
||||
|
||||
def _skip_if_unsupported_type(dtype):
|
||||
dtype = np.dtype(dtype)
|
||||
if (not FLAGS.jax_enable_x64 and
|
||||
dtype in (np.dtype('float64'), np.dtype('complex128'))):
|
||||
raise unittest.SkipTest("--jax_enable_x64 is not set")
|
||||
|
||||
|
||||
class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
@ -66,7 +60,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testCholesky(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
def args_maker():
|
||||
factor_shape = shape[:-1] + (2 * shape[-1],)
|
||||
a = rng(factor_shape, dtype)
|
||||
@ -99,7 +93,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testDet(self, n, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
args_maker = lambda: [rng((n, n), dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(np.linalg.det, jnp.linalg.det, args_maker, tol=1e-3)
|
||||
@ -121,7 +115,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
||||
def testDetGrad(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
a = rng(shape, dtype)
|
||||
jtu.check_grads(jnp.linalg.det, (a,), 2, atol=1e-1, rtol=1e-1)
|
||||
# make sure there are no NaNs when a matrix is zero
|
||||
@ -162,7 +156,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testTensorsolve(self, m, nq, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
|
||||
# According to numpy docs the shapes are as follows:
|
||||
# Coefficient tensor (a), of shape b.shape + Q.
|
||||
@ -198,7 +192,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def testSlogdet(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker,
|
||||
@ -216,7 +210,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
||||
def testSlogdetGrad(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
a = rng(shape, dtype)
|
||||
jtu.check_grads(jnp.linalg.slogdet, (a,), 2, atol=1e-1, rtol=1e-1)
|
||||
|
||||
@ -239,7 +233,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("gpu", "tpu")
|
||||
def testEig(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
n = shape[-1]
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
@ -267,7 +261,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("gpu", "tpu")
|
||||
def testEigvals(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
a, = args_maker()
|
||||
w1, _ = jnp.linalg.eig(a)
|
||||
@ -290,7 +284,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("gpu", "tpu")
|
||||
def testEigBatching(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
shape = (10,) + shape
|
||||
args = rng(shape, dtype)
|
||||
ws, vs = vmap(jnp.linalg.eig)(args)
|
||||
@ -307,7 +301,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testEigh(self, n, dtype, lower, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
tol = 30
|
||||
if jtu.device_under_test() == "tpu":
|
||||
if jnp.issubdtype(dtype, np.complexfloating):
|
||||
@ -342,7 +336,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testEigvalsh(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
if jtu.device_under_test() == "tpu":
|
||||
if jnp.issubdtype(dtype, jnp.complexfloating):
|
||||
raise unittest.SkipTest("No complex eigh on TPU")
|
||||
@ -395,7 +389,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def testEighGradVectorComplex(self, shape, dtype, rng_factory, lower, eps):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
# Special case to test for complex eigenvector grad correctness.
|
||||
# Exact eigenvector coordinate gradients are hard to test numerically for complex
|
||||
# eigensystem solvers given the extra degrees of per-eigenvector phase freedom.
|
||||
@ -441,7 +435,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testEighBatching(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
if (jtu.device_under_test() == "tpu" and
|
||||
jnp.issubdtype(dtype, np.complexfloating)):
|
||||
raise unittest.SkipTest("No complex eigh on TPU")
|
||||
@ -474,7 +468,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_default])) # type: ignore
|
||||
def testNorm(self, shape, dtype, ord, axis, keepdims, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
if (ord in ('nuc', 2, -2) and (
|
||||
jtu.device_under_test() != "cpu" or
|
||||
(isinstance(axis, tuple) and len(axis) == 2))):
|
||||
@ -505,7 +499,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
jtu.device_under_test() == "tpu"):
|
||||
raise unittest.SkipTest("No complex SVD implementation")
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
args_maker = lambda: [rng(b + (m, n), dtype)]
|
||||
|
||||
# Norm, adjusted for dimension and type.
|
||||
@ -558,7 +552,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testQr(self, shape, dtype, full_matrices, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
if (jnp.issubdtype(dtype, np.complexfloating) and
|
||||
jtu.device_under_test() == "tpu"):
|
||||
raise unittest.SkipTest("No complex QR implementation")
|
||||
@ -630,7 +624,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
for dtype in float_types + complex_types))
|
||||
@jtu.skip_on_devices("gpu") # TODO(#2203): numerical errors
|
||||
def testCond(self, shape, pnorm, dtype):
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
if (jnp.issubdtype(dtype, np.complexfloating) and
|
||||
jtu.device_under_test() == "tpu"):
|
||||
raise unittest.SkipTest("No complex SVD implementation")
|
||||
@ -665,7 +659,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
for dtype in float_types
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testTensorinv(self, shape, dtype, rng_factory):
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
rng = rng_factory(self.rng())
|
||||
|
||||
def tensor_maker():
|
||||
@ -703,7 +697,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testSolve(self, lhs_shape, rhs_shape, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker,
|
||||
@ -719,7 +713,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testInv(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
if jtu.device_under_test() == "gpu" and shape == (200, 200):
|
||||
raise unittest.SkipTest("Test is flaky on GPU")
|
||||
|
||||
@ -750,7 +744,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
jtu.device_under_test() == "tpu"):
|
||||
raise unittest.SkipTest("No complex SVD implementation")
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(np.linalg.pinv, jnp.linalg.pinv, args_maker,
|
||||
@ -785,7 +779,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("tpu") # TODO(b/149870255): Bug in XLA:TPU?.
|
||||
def testMatrixPower(self, shape, dtype, n, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else 1e-3
|
||||
self._CheckAgainstNumpy(partial(np.linalg.matrix_power, n=n),
|
||||
@ -806,7 +800,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
jtu.device_under_test() == "tpu"):
|
||||
raise unittest.SkipTest("No complex SVD implementation")
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
a, = args_maker()
|
||||
self._CheckAgainstNumpy(np.linalg.matrix_rank, jnp.linalg.matrix_rank,
|
||||
@ -827,7 +821,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testMultiDot(self, shapes, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
args_maker = lambda: [[rng(shape, dtype) for shape in shapes]]
|
||||
|
||||
np_fun = np.linalg.multi_dot
|
||||
@ -861,7 +855,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("cpu", "gpu") # TODO(jakevdp) Test fails numerically
|
||||
def testLstsq(self, lhs_shape, rhs_shape, dtype, lowrank, rcond, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
np_fun = partial(np.linalg.lstsq, rcond=rcond)
|
||||
jnp_fun = partial(jnp.linalg.lstsq, rcond=rcond)
|
||||
jnp_fun_numpy_resid = partial(jnp.linalg.lstsq, rcond=rcond, numpy_resid=True)
|
||||
@ -949,7 +943,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testLu(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
x, = args_maker()
|
||||
p, l, u = jsp.linalg.lu(x)
|
||||
@ -974,7 +968,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
||||
def testLuGrad(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
a = rng(shape, dtype)
|
||||
lu = vmap(jsp.linalg.lu) if len(shape) > 2 else jsp.linalg.lu
|
||||
jtu.check_grads(lu, (a,), 2, atol=5e-2, rtol=3e-1)
|
||||
@ -988,7 +982,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testLuBatching(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
args = [rng(shape, jnp.float32) for _ in range(10)]
|
||||
expected = list(osp.linalg.lu(x) for x in args)
|
||||
ps = np.stack([out[0] for out in expected])
|
||||
@ -1009,7 +1003,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testLuFactor(self, n, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
args_maker = lambda: [rng((n, n), dtype)]
|
||||
|
||||
x, = args_maker()
|
||||
@ -1041,7 +1035,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("cpu") # TODO(frostig): Test fails on CPU sometimes
|
||||
def testLuSolve(self, lhs_shape, rhs_shape, dtype, trans, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
osp_fun = lambda lu, piv, rhs: osp.linalg.lu_solve((lu, piv), rhs, trans=trans)
|
||||
jsp_fun = lambda lu, piv, rhs: jsp.linalg.lu_solve((lu, piv), rhs, trans=trans)
|
||||
|
||||
@ -1075,7 +1069,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
if (sym_pos and jnp.issubdtype(dtype, np.complexfloating) and
|
||||
jtu.device_under_test() == "tpu"):
|
||||
raise unittest.SkipTest(
|
||||
@ -1114,7 +1108,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testSolveTriangular(self, lower, transpose_a, unit_diagonal, lhs_shape,
|
||||
rhs_shape, dtype, rng_factory):
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
rng = rng_factory(self.rng())
|
||||
k = rng(lhs_shape, dtype)
|
||||
l = np.linalg.cholesky(np.matmul(k, T(k))
|
||||
@ -1173,7 +1167,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
def testTriangularSolveGrad(
|
||||
self, lower, transpose_a, conjugate_a, unit_diagonal, left_side, a_shape,
|
||||
b_shape, dtype, rng_factory):
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
rng = rng_factory(self.rng())
|
||||
# Test lax_linalg.triangular_solve instead of scipy.linalg.solve_triangular
|
||||
# because it exposes more options.
|
||||
@ -1231,7 +1225,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_small]))
|
||||
def testExpm(self, n, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
args_maker = lambda: [rng((n, n), dtype)]
|
||||
|
||||
osp_fun = lambda a: osp.linalg.expm(a)
|
||||
@ -1275,7 +1269,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testChoSolve(self, lhs_shape, rhs_shape, dtype, lower, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
def args_maker():
|
||||
b = rng(rhs_shape, dtype)
|
||||
if lower:
|
||||
@ -1297,7 +1291,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_small]))
|
||||
def testExpmFrechet(self, n, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
args_maker = lambda: [rng((n, n), dtype), rng((n, n), dtype),]
|
||||
|
||||
#compute_expm is True
|
||||
@ -1322,7 +1316,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_small]))
|
||||
def testExpmGrad(self, n, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
jtu.skip_if_unsupported_type(dtype)
|
||||
a = rng((n, n), dtype)
|
||||
jtu.check_grads(jsp.linalg.expm, (a,), modes=["fwd"], order=1)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user