Cleanup: make skip_if_unsupported_type more robust (#3912)

Jake Vanderplas 2020-07-30 11:07:56 -07:00 committed by GitHub
parent 9d3d09198c
commit 0cbb4279ee
2 changed files with 40 additions and 45 deletions

View File

@@ -366,9 +366,10 @@ def supported_dtypes():
return types
def skip_if_unsupported_type(dtype):
-if dtype not in supported_dtypes():
+dtype = np.dtype(dtype)
+if dtype.type not in supported_dtypes():
raise unittest.SkipTest(
-f"Type {dtype} not supported on {device_under_test()}")
+f"Type {dtype.name} not supported on {device_under_test()}")
def skip_on_devices(*disabled_devices):
"""A decorator for test methods to skip the test on certain devices."""

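For reference, the effect of the hunk above can be sketched as follows. The helper body matches the updated diff; supported_dtypes() and device_under_test() are stubbed here (assumptions, not the real device-aware implementations in JAX's test utilities) so the snippet runs standalone and shows that a dtype string, a Python/NumPy scalar type, or an np.dtype object are all accepted now:

import unittest

import numpy as np


def skip_if_unsupported_type(dtype):
  # Mirrors the updated helper in this commit: canonicalize first, then check.
  dtype = np.dtype(dtype)
  if dtype.type not in supported_dtypes():
    raise unittest.SkipTest(
        f"Type {dtype.name} not supported on {device_under_test()}")


def supported_dtypes():
  return {np.float32, np.int32}  # stand-in for the real, device-aware set


def device_under_test():
  return "cpu"  # stand-in for the real backend query


class DtypeSkipExample(unittest.TestCase):
  """Illustrative only; not part of this commit."""

  def test_float64_is_skipped(self):
    # All three spellings canonicalize to float64, which the stub reports as
    # unsupported, so each call raises unittest.SkipTest.
    for spec in ("float64", np.float64, np.dtype("float64")):
      with self.assertRaises(unittest.SkipTest):
        skip_if_unsupported_type(spec)


if __name__ == "__main__":
  unittest.main()

In the real test suite the call simply raises SkipTest at the top of a test method, so the test is reported as skipped rather than failed on devices that lack the dtype.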
View File

@@ -42,12 +42,6 @@ T = lambda x: np.swapaxes(x, -1, -2)
float_types = [np.float32, np.float64]
complex_types = [np.complex64, np.complex128]
-def _skip_if_unsupported_type(dtype):
-dtype = np.dtype(dtype)
-if (not FLAGS.jax_enable_x64 and
-dtype in (np.dtype('float64'), np.dtype('complex128'))):
-raise unittest.SkipTest("--jax_enable_x64 is not set")
class NumpyLinalgTest(jtu.JaxTestCase):
@@ -66,7 +60,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testCholesky(self, shape, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
def args_maker():
factor_shape = shape[:-1] + (2 * shape[-1],)
a = rng(factor_shape, dtype)
@@ -99,7 +93,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testDet(self, n, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
args_maker = lambda: [rng((n, n), dtype)]
self._CheckAgainstNumpy(np.linalg.det, jnp.linalg.det, args_maker, tol=1e-3)
@@ -121,7 +115,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testDetGrad(self, shape, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
a = rng(shape, dtype)
jtu.check_grads(jnp.linalg.det, (a,), 2, atol=1e-1, rtol=1e-1)
# make sure there are no NaNs when a matrix is zero
@@ -162,7 +156,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testTensorsolve(self, m, nq, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
# According to numpy docs the shapes are as follows:
# Coefficient tensor (a), of shape b.shape + Q.
@@ -198,7 +192,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu")
def testSlogdet(self, shape, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker,
@@ -216,7 +210,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testSlogdetGrad(self, shape, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
a = rng(shape, dtype)
jtu.check_grads(jnp.linalg.slogdet, (a,), 2, atol=1e-1, rtol=1e-1)
@@ -239,7 +233,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
@jtu.skip_on_devices("gpu", "tpu")
def testEig(self, shape, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
n = shape[-1]
args_maker = lambda: [rng(shape, dtype)]
@@ -267,7 +261,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
@jtu.skip_on_devices("gpu", "tpu")
def testEigvals(self, shape, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
args_maker = lambda: [rng(shape, dtype)]
a, = args_maker()
w1, _ = jnp.linalg.eig(a)
@@ -290,7 +284,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
@jtu.skip_on_devices("gpu", "tpu")
def testEigBatching(self, shape, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
shape = (10,) + shape
args = rng(shape, dtype)
ws, vs = vmap(jnp.linalg.eig)(args)
@@ -307,7 +301,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testEigh(self, n, dtype, lower, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
tol = 30
if jtu.device_under_test() == "tpu":
if jnp.issubdtype(dtype, np.complexfloating):
@@ -342,7 +336,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testEigvalsh(self, shape, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
if jtu.device_under_test() == "tpu":
if jnp.issubdtype(dtype, jnp.complexfloating):
raise unittest.SkipTest("No complex eigh on TPU")
@@ -395,7 +389,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu")
def testEighGradVectorComplex(self, shape, dtype, rng_factory, lower, eps):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
# Special case to test for complex eigenvector grad correctness.
# Exact eigenvector coordinate gradients are hard to test numerically for complex
# eigensystem solvers given the extra degrees of per-eigenvector phase freedom.
@@ -441,7 +435,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testEighBatching(self, shape, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
if (jtu.device_under_test() == "tpu" and
jnp.issubdtype(dtype, np.complexfloating)):
raise unittest.SkipTest("No complex eigh on TPU")
@@ -474,7 +468,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default])) # type: ignore
def testNorm(self, shape, dtype, ord, axis, keepdims, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
if (ord in ('nuc', 2, -2) and (
jtu.device_under_test() != "cpu" or
(isinstance(axis, tuple) and len(axis) == 2))):
@@ -505,7 +499,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
jtu.device_under_test() == "tpu"):
raise unittest.SkipTest("No complex SVD implementation")
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
args_maker = lambda: [rng(b + (m, n), dtype)]
# Norm, adjusted for dimension and type.
@@ -558,7 +552,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testQr(self, shape, dtype, full_matrices, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
if (jnp.issubdtype(dtype, np.complexfloating) and
jtu.device_under_test() == "tpu"):
raise unittest.SkipTest("No complex QR implementation")
@@ -630,7 +624,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for dtype in float_types + complex_types))
@jtu.skip_on_devices("gpu") # TODO(#2203): numerical errors
def testCond(self, shape, pnorm, dtype):
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
if (jnp.issubdtype(dtype, np.complexfloating) and
jtu.device_under_test() == "tpu"):
raise unittest.SkipTest("No complex SVD implementation")
@@ -665,7 +659,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for dtype in float_types
for rng_factory in [jtu.rand_default]))
def testTensorinv(self, shape, dtype, rng_factory):
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
rng = rng_factory(self.rng())
def tensor_maker():
@@ -703,7 +697,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testSolve(self, lhs_shape, rhs_shape, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker,
@@ -719,7 +713,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testInv(self, shape, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
if jtu.device_under_test() == "gpu" and shape == (200, 200):
raise unittest.SkipTest("Test is flaky on GPU")
@@ -750,7 +744,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
jtu.device_under_test() == "tpu"):
raise unittest.SkipTest("No complex SVD implementation")
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np.linalg.pinv, jnp.linalg.pinv, args_maker,
@@ -785,7 +779,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu") # TODO(b/149870255): Bug in XLA:TPU?.
def testMatrixPower(self, shape, dtype, n, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
args_maker = lambda: [rng(shape, dtype)]
tol = 1e-1 if jtu.device_under_test() == "tpu" else 1e-3
self._CheckAgainstNumpy(partial(np.linalg.matrix_power, n=n),
@@ -806,7 +800,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
jtu.device_under_test() == "tpu"):
raise unittest.SkipTest("No complex SVD implementation")
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
args_maker = lambda: [rng(shape, dtype)]
a, = args_maker()
self._CheckAgainstNumpy(np.linalg.matrix_rank, jnp.linalg.matrix_rank,
@@ -827,7 +821,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testMultiDot(self, shapes, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
args_maker = lambda: [[rng(shape, dtype) for shape in shapes]]
np_fun = np.linalg.multi_dot
@@ -861,7 +855,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
@jtu.skip_on_devices("cpu", "gpu") # TODO(jakevdp) Test fails numerically
def testLstsq(self, lhs_shape, rhs_shape, dtype, lowrank, rcond, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
np_fun = partial(np.linalg.lstsq, rcond=rcond)
jnp_fun = partial(jnp.linalg.lstsq, rcond=rcond)
jnp_fun_numpy_resid = partial(jnp.linalg.lstsq, rcond=rcond, numpy_resid=True)
@@ -949,7 +943,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testLu(self, shape, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
args_maker = lambda: [rng(shape, dtype)]
x, = args_maker()
p, l, u = jsp.linalg.lu(x)
@@ -974,7 +968,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testLuGrad(self, shape, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
a = rng(shape, dtype)
lu = vmap(jsp.linalg.lu) if len(shape) > 2 else jsp.linalg.lu
jtu.check_grads(lu, (a,), 2, atol=5e-2, rtol=3e-1)
@@ -988,7 +982,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testLuBatching(self, shape, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
args = [rng(shape, jnp.float32) for _ in range(10)]
expected = list(osp.linalg.lu(x) for x in args)
ps = np.stack([out[0] for out in expected])
@@ -1009,7 +1003,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testLuFactor(self, n, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
args_maker = lambda: [rng((n, n), dtype)]
x, = args_maker()
@@ -1041,7 +1035,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
@jtu.skip_on_devices("cpu") # TODO(frostig): Test fails on CPU sometimes
def testLuSolve(self, lhs_shape, rhs_shape, dtype, trans, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
osp_fun = lambda lu, piv, rhs: osp.linalg.lu_solve((lu, piv), rhs, trans=trans)
jsp_fun = lambda lu, piv, rhs: jsp.linalg.lu_solve((lu, piv), rhs, trans=trans)
@@ -1075,7 +1069,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
if (sym_pos and jnp.issubdtype(dtype, np.complexfloating) and
jtu.device_under_test() == "tpu"):
raise unittest.SkipTest(
@@ -1114,7 +1108,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testSolveTriangular(self, lower, transpose_a, unit_diagonal, lhs_shape,
rhs_shape, dtype, rng_factory):
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
rng = rng_factory(self.rng())
k = rng(lhs_shape, dtype)
l = np.linalg.cholesky(np.matmul(k, T(k))
@@ -1173,7 +1167,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
def testTriangularSolveGrad(
self, lower, transpose_a, conjugate_a, unit_diagonal, left_side, a_shape,
b_shape, dtype, rng_factory):
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
rng = rng_factory(self.rng())
# Test lax_linalg.triangular_solve instead of scipy.linalg.solve_triangular
# because it exposes more options.
@@ -1231,7 +1225,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_small]))
def testExpm(self, n, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
args_maker = lambda: [rng((n, n), dtype)]
osp_fun = lambda a: osp.linalg.expm(a)
@@ -1275,7 +1269,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testChoSolve(self, lhs_shape, rhs_shape, dtype, lower, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
def args_maker():
b = rng(rhs_shape, dtype)
if lower:
@@ -1297,7 +1291,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_small]))
def testExpmFrechet(self, n, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
args_maker = lambda: [rng((n, n), dtype), rng((n, n), dtype),]
#compute_expm is True
@@ -1322,7 +1316,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_small]))
def testExpmGrad(self, n, dtype, rng_factory):
rng = rng_factory(self.rng())
-_skip_if_unsupported_type(dtype)
+jtu.skip_if_unsupported_type(dtype)
a = rng((n, n), dtype)
jtu.check_grads(jsp.linalg.expm, (a,), modes=["fwd"], order=1)
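
Taken together, the two files now share a single skip helper: the test module's local _skip_if_unsupported_type, which only checked the jax_enable_x64 flag for float64 and complex128, is deleted, and every test calls jtu.skip_if_unsupported_type instead. The x64-disabled case is presumably still covered because supported_dtypes() reflects what the current configuration and device under test actually support, which is what makes the shared helper the "more robust" choice the commit title describes.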