Disabled tests known to fail on Mac, and optionally skip slow tests.

Issue: #2166

Added a JAX_SKIP_SLOW_TESTS environment variable to skip tests known to be slow.
Commit b18a4d8583 (parent d01210e9e3)
@@ -21,6 +21,7 @@ These are the release notes for JAX.
   and Numba.
 * JAX CPU device buffers now implement the Python buffer protocol, which allows
   zero-copy buffer sharing between JAX and NumPy.
+* Added a JAX_SKIP_SLOW_TESTS environment variable to skip tests known to be slow.
 
 ## jaxlib 0.1.38 (January 29, 2020)
 
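As an aside (not part of the commit): a small sketch of what the buffer-protocol bullet above enables. Whether `numpy.asarray` actually avoids a copy depends on backend and version; this assumes a CPU device buffer.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(4.0)   # backed by a JAX CPU device buffer
y = np.asarray(x)     # NumPy array over the same data via the buffer protocol
```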
@@ -124,6 +124,9 @@ file directly to see more detailed information about the cases being run:
 
 python tests/lax_numpy_test.py --num_generated_cases=5
 
+You can skip tests known to be slow by setting the environment variable
+JAX_SKIP_SLOW_TESTS=1.
+
 The Colab notebooks are tested for errors as part of the documentation build.
 
 Update documentation
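As a usage illustration (not part of the diff), the new environment variable composes with the invocation already shown in the docs above:

```
JAX_SKIP_SLOW_TESTS=1 python tests/lax_numpy_test.py --num_generated_cases=5
```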
@@ -14,6 +14,7 @@
 
 
 from contextlib import contextmanager
+from distutils.util import strtobool
 import functools
 import re
 import itertools as it
@@ -49,6 +50,13 @@ flags.DEFINE_integer(
     int(os.getenv('JAX_NUM_GENERATED_CASES', 10)),
     help='Number of generated cases to test')
 
+flags.DEFINE_bool(
+    'jax_skip_slow_tests',
+    strtobool(os.getenv('JAX_SKIP_SLOW_TESTS', '0')),
+    help=
+    'Skip tests marked as slow (> 5 sec).'
+)
+
 EPS = 1e-4
 
 def _dtype(x):
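The test diffs below all apply `@jtu.skip_on_flag("jax_skip_slow_tests", True)`. Its definition is not shown in this commit; a minimal sketch of such a decorator (assuming absl flags and unittest, which test_util already uses) could look like:

```python
import functools
import unittest

from absl import flags

FLAGS = flags.FLAGS

def skip_on_flag(flag_name, skip_value):
  """Skips the decorated test method when the named flag equals skip_value."""
  def skip(test_method):
    @functools.wraps(test_method)
    def test_method_wrapper(self, *args, **kwargs):
      if getattr(FLAGS, flag_name) == skip_value:
        raise unittest.SkipTest(
            f"skipped because flag {flag_name!r} == {skip_value!r}")
      return test_method(self, *args, **kwargs)
    return test_method_wrapper
  return skip
```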
@@ -957,6 +957,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
        "jit_scan": jit_scan, "jit_f": jit_f}
       for jit_scan in [False, True]
       for jit_f in [False, True])
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testScanGrad(self, jit_scan, jit_f):
     rng = onp.random.RandomState(0)
@@ -987,6 +988,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["rev"],
                     atol=1e-3, rtol=2e-3)
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testScanRnn(self):
     r = npr.RandomState(0)
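For orientation (not from the diff): the scan tests above differentiate through `lax.scan` with `jtu.check_grads`. A minimal self-contained version of that pattern, using a cumulative-sum scan as a stand-in for the tests' functions:

```python
import jax
import jax.numpy as jnp
from jax import lax

def cumsum_scan(xs):
  # carry the running total; emit it at each step
  def step(carry, x):
    total = carry + x
    return total, total
  _, ys = lax.scan(step, 0.0, xs)
  return ys

xs = jnp.arange(4.0)
# element i appears in (n - i) cumulative sums, so the gradient is [4, 3, 2, 1]
print(jax.grad(lambda xs: cumsum_scan(xs).sum())(xs))
```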
@@ -1444,6 +1446,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     self.assertAllClose(results, 5.0 ** 1.5, check_dtypes=False,
                         rtol={onp.float64: 1e-7})
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def test_custom_root_vector_with_solve_closure(self):
 
     def vector_solve(f, y):
@@ -1513,6 +1516,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
       {"testcase_name": "nonsymmetric", "symmetric": False},
       {"testcase_name": "symmetric", "symmetric": True},
   )
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def test_custom_linear_solve(self, symmetric):
 
     def explicit_jacobian_solve(matvec, b):
@@ -1542,6 +1546,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     actual = api.vmap(linear_solve, (None, 1), 1)(a, c)
     self.assertAllClose(expected, actual, check_dtypes=True)
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def test_custom_linear_solve_zeros(self):
     def explicit_jacobian_solve(matvec, b):
       return lax.stop_gradient(np.linalg.solve(api.jacobian(matvec)(b), b))
@@ -1561,6 +1566,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     jtu.check_grads(lambda x: linear_solve(a, x), (b,), order=2,
                     rtol={onp.float32: 5e-3})
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def test_custom_linear_solve_iterative(self):
 
     def richardson_iteration(matvec, b, omega=0.1, tolerance=1e-6):
@@ -1622,6 +1628,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
         lambda x, y: positive_definite_solve(high_precision_dot(x, x.T), y),
         (a, b), order=2, rtol=1e-2)
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def test_custom_linear_solve_lu(self):
 
     # TODO(b/143528110): re-enable when underlying XLA TPU issue is fixed
@@ -1652,6 +1659,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     jtu.check_grads(api.jit(linear_solve), (a, b), order=2,
                     rtol={onp.float32: 2e-3})
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def test_custom_linear_solve_without_transpose_solve(self):
 
     def explicit_jacobian_solve(matvec, b):
@@ -1674,6 +1682,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(TypeError, "transpose_solve required"):
       api.grad(loss)(a, b)
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def test_custom_linear_solve_pytree(self):
     """Test custom linear solve with inputs and outputs that are pytrees."""
 
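An aside, not part of the diff: the tests above exercise `lax.custom_linear_solve`, which pairs a matrix-vector product with a user-supplied solver so JAX can differentiate through the solve. A small sketch of the API (shapes and values are illustrative):

```python
import numpy as onp
import jax.numpy as jnp
from jax import lax

rng = onp.random.RandomState(0)
a = rng.randn(3, 3)
a = a @ a.T + 3.0 * onp.eye(3)   # symmetric positive definite
b = jnp.asarray(rng.randn(3))

matvec = lambda x: jnp.dot(a, x)
# solve(matvec, b) must return x such that matvec(x) == b; here, a dense solver
solve = lambda mv, b: jnp.linalg.solve(a, b)

x = lax.custom_linear_solve(matvec, b, solve, symmetric=True)
print(jnp.allclose(jnp.dot(a, x), b, atol=1e-4))  # True
```

With `symmetric=True` no separate `transpose_solve` is needed, which is the distinction the `nonsymmetric`/`symmetric` test cases above probe.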
@@ -564,6 +564,7 @@ class LaxTest(jtu.JaxTestCase):
       for dspec in [('NHWC', 'HWIO', 'NHWC'),]
       for rhs_dilation in [None, (2, 2)]
       for rng_factory in [jtu.rand_small]))
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testConvTranspose2DT(self, lhs_shape, rhs_shape, dtype, strides,
                            padding, dspec, rhs_dilation, rng_factory):
     rng = rng_factory()
@@ -602,6 +603,7 @@ class LaxTest(jtu.JaxTestCase):
       for dspec in [('NHWC', 'HWIO', 'NHWC'),]
       for rhs_dilation in [None, (2, 2)]
       for rng_factory in [jtu.rand_small]))
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testConvTranspose2D(self, lhs_shape, rhs_shape, dtype, strides,
                           padding, dspec, rhs_dilation, rng_factory):
     rng = rng_factory()
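Not from the diff: the two conv-transpose tests above use the `('NHWC', 'HWIO', 'NHWC')` dimension spec seen in their parameter lists. A compact illustration of `lax.conv_transpose` with that spec (shapes are illustrative):

```python
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((1, 8, 8, 3))   # a batch of NHWC images
rhs = jnp.ones((3, 3, 3, 4))   # HWIO kernel: 3x3 window, 3 in-channels, 4 out
out = lax.conv_transpose(lhs, rhs, strides=(2, 2), padding='SAME',
                         dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
print(out.shape)  # (1, 16, 16, 4): the strides upsample the spatial dims
```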
@@ -2281,6 +2283,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
       for dtype in dtypes
       for padding in ["VALID", "SAME"]
       for rng_factory in [jtu.rand_default]))
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory):
     rng = rng_factory()
     tol = {onp.float16: 1e-1, onp.float32: 1e-3}
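Not from the diff: testReduceWindowGrad differentiates through `lax.reduce_window`. A compact stand-alone version of that pattern, using a 2x2 max pool and its gradient:

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(16.0).reshape(4, 4)

def max_pool_sum(x):
  # 2x2 max pooling with stride 2, then a scalar reduction for grad
  pooled = lax.reduce_window(x, -jnp.inf, lax.max, (2, 2), (2, 2), 'VALID')
  return pooled.sum()

# the gradient is 1.0 at each window's argmax and 0.0 elsewhere
print(jax.grad(max_pool_sum)(x))
```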
@@ -2970,6 +2973,7 @@ class LaxVmapTest(jtu.JaxTestCase):
       for dtype in float_dtypes
       for padding in ["VALID", "SAME"]
       for rng_factory in [jtu.rand_small]))
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testSelectAndGatherAdd(self, dtype, padding, rng_factory):
     if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
       raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu")
@@ -17,6 +17,7 @@
 from functools import partial
 import itertools
 import unittest
+import sys
 
 import numpy as onp
 import scipy as osp
@@ -51,6 +52,10 @@ def _skip_if_unsupported_type(dtype):
       dtype in (onp.dtype('float64'), onp.dtype('complex128'))):
     raise unittest.SkipTest("--jax_enable_x64 is not set")
 
+# TODO(phawkins): bug https://github.com/google/jax/issues/2166
+def _skip_on_mac_xla_bug():
+  if sys.platform == "darwin" and osp.version.version > "1.0.0":
+    raise unittest.SkipTest("Test fails on Mac with new scipy (issue #2166)")
 
 class NumpyLinalgTest(jtu.JaxTestCase):
 
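One editorial caveat, not a change in the commit: `osp.version.version > "1.0.0"` compares version strings lexicographically, which happens to work for the SciPy releases of the day but would misorder versions like "1.10.0". A more robust variant (a sketch only; the commit keeps the string comparison) parses the versions first:

```python
from distutils.version import LooseVersion
import sys
import unittest

import scipy as osp

def _skip_on_mac_xla_bug():
  # LooseVersion orders "1.10.0" above "1.2.0", unlike raw string comparison
  if (sys.platform == "darwin" and
      LooseVersion(osp.version.version) > LooseVersion("1.0.0")):
    raise unittest.SkipTest("Test fails on Mac with new scipy (issue #2166)")
```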
@@ -134,6 +139,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
       for dtype in float_types
       for rng_factory in [jtu.rand_default]))
   @jtu.skip_on_devices("tpu")
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testSlogdetGrad(self, shape, dtype, rng_factory):
     rng = rng_factory()
     _skip_if_unsupported_type(dtype)
@@ -188,6 +194,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   def testEigvals(self, shape, dtype, rng_factory):
     rng = rng_factory()
     _skip_if_unsupported_type(dtype)
+    if shape == (50, 50) and dtype == onp.complex64:
+      _skip_on_mac_xla_bug()
     n = shape[-1]
     args_maker = lambda: [rng(shape, dtype)]
     a, = args_maker()
@@ -565,6 +573,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   def testInv(self, shape, dtype, rng_factory):
     rng = rng_factory()
     _skip_if_unsupported_type(dtype)
+    if shape == (200, 200) and dtype == onp.float32:
+      _skip_on_mac_xla_bug()
     if jtu.device_under_test() == "gpu" and shape == (200, 200):
       raise unittest.SkipTest("Test is flaky on GPU")
 
@@ -594,6 +604,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   def testPinv(self, shape, dtype, rng_factory):
     rng = rng_factory()
     _skip_if_unsupported_type(dtype)
+    if shape == (7, 10000) and dtype in [onp.complex64, onp.float32]:
+      _skip_on_mac_xla_bug()
     args_maker = lambda: [rng(shape, dtype)]
 
     self._CheckAgainstNumpy(onp.linalg.pinv, np.linalg.pinv, args_maker,
@@ -650,6 +662,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     xc = onp.eye(3, dtype=onp.complex)
     self.assertAllClose(xc, grad_test_jc(xc), check_dtypes=True)
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testIssue1151(self):
     A = np.array(onp.random.randn(100, 3, 3), dtype=np.float32)
     b = np.array(onp.random.randn(100, 3), dtype=np.float32)
@@ -661,6 +674,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     jac0 = jax.jacobian(np.linalg.solve, argnums=0)(A[0], b[0])
     jac1 = jax.jacobian(np.linalg.solve, argnums=1)(A[0], b[0])
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testIssue1383(self):
     seed = jax.random.PRNGKey(0)
     tmp = jax.random.uniform(seed, (2,2))
@@ -726,6 +740,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
       for dtype in float_types + complex_types
       for rng_factory in [jtu.rand_default]))
   @jtu.skip_on_devices("tpu")  # TODO(phawkins): precision problems on TPU.
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testLuGrad(self, shape, dtype, rng_factory):
     rng = rng_factory()
     _skip_if_unsupported_type(dtype)
@@ -764,6 +779,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
   def testLuFactor(self, n, dtype, rng_factory):
     rng = rng_factory()
     _skip_if_unsupported_type(dtype)
+    if n == 200 and dtype == onp.complex64:
+      _skip_on_mac_xla_bug()
     args_maker = lambda: [rng((n, n), dtype)]
 
     x, = args_maker()
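Not part of the diff: testLuFactor targets the LU routines under jax.scipy.linalg. A small illustration of the factor/solve round trip this kind of test checks (assuming `lu_factor`/`lu_solve` are available in your jax.scipy.linalg version; values are illustrative):

```python
import numpy as onp
import jax.numpy as jnp
import jax.scipy.linalg as jsp_linalg

rng = onp.random.RandomState(0)
a = jnp.asarray(rng.randn(4, 4))
b = jnp.asarray(rng.randn(4))

lu, piv = jsp_linalg.lu_factor(a)      # factor once...
x = jsp_linalg.lu_solve((lu, piv), b)  # ...then solve cheaply per right-hand side
print(jnp.allclose(a @ x, b, atol=1e-4))  # True
```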
@@ -985,6 +1002,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
   def testExpm(self, n, dtype, rng_factory):
     rng = rng_factory()
     _skip_if_unsupported_type(dtype)
+    if n == 50 and dtype in [onp.complex64, onp.float32]:
+      _skip_on_mac_xla_bug()
     args_maker = lambda: [rng((n, n), dtype)]
 
     osp_fun = lambda a: osp.linalg.expm(a)
@@ -34,6 +34,7 @@ config.parse_flags_with_absl()
 
 class NNFunctionsTest(jtu.JaxTestCase):
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testSoftplusGrad(self):
     check_grads(nn.softplus, (1e-8,), 4,
                 rtol=1e-2 if jtu.device_under_test() == "tpu" else None)
@@ -42,6 +43,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
     val = nn.softplus(89.)
     self.assertAllClose(val, 89., check_dtypes=False)
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testEluGrad(self):
     check_grads(nn.elu, (1e4,), 4, eps=1.)
 
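A closing note, not from the diff: `check_grads` compares autodiff derivatives against numerical differencing up to the given order, which is what makes these tests slow enough to gate behind the new flag. The calls above amount to:

```python
import jax.nn as nn
from jax.test_util import check_grads

# check forward- and reverse-mode derivatives of softplus near zero input,
# up to 4th order, against finite differences
check_grads(nn.softplus, (1e-8,), order=4)

# for elu at a large input, use a correspondingly large finite-difference step
check_grads(nn.elu, (1e4,), order=4, eps=1.)
```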