From 33bed1e52099ede4d486b49565040f3a4d94d669 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Tue, 14 Feb 2023 12:01:35 -0800
Subject: [PATCH] Opt into higher matmul precision for A100 and TPU tests.

PiperOrigin-RevId: 509598465
---
 tests/BUILD                    | 3 +++
 tests/api_test.py              | 8 ++++----
 tests/experimental_rnn_test.py | 2 +-
 tests/lax_numpy_test.py        | 3 +--
 tests/scipy_stats_test.py      | 1 +
 tests/sparse_test.py           | 1 +
 6 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/tests/BUILD b/tests/BUILD
index 92dd12df4..0d8494df4 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -1040,6 +1040,9 @@ jax_test(
         "tpu",
         "cpu",
     ],
+    disable_configs = [
+        "gpu_a100",  # Numerical precision problems.
+    ],
     deps = [
         "//jax:rnn",
     ],
diff --git a/tests/api_test.py b/tests/api_test.py
index 66ba5f39a..8da3f0fed 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -1684,7 +1684,7 @@ class APITest(jtu.JaxTestCase):
     else:
       self.assertEqual(x.device_buffer.device(), cpu_device)
 
-  @jtu.skip_on_devices("tpu")
+  @jax.default_matmul_precision("float32")
   def test_jacobian(self):
     R = self.rng().randn
     A = R(4, 3)
@@ -1697,7 +1697,7 @@ class APITest(jtu.JaxTestCase):
     f = lambda x: jnp.tanh(jnp.dot(A, x))
     assert np.allclose(jacfwd(f)(x), jacrev(f)(x))
 
-  @jtu.skip_on_devices("tpu")
+  @jax.default_matmul_precision("float32")
   def test_hessian(self):
     R = self.rng().randn
     A = R(4, 4)
@@ -1706,7 +1706,7 @@ class APITest(jtu.JaxTestCase):
     f = lambda x: jnp.dot(x, jnp.dot(A, x))
     assert np.allclose(hessian(f)(x), A + A.T)
 
-  @jtu.skip_on_devices("tpu")
+  @jax.default_matmul_precision("float32")
   def test_hessian_holomorphic(self):
     R = self.rng().randn
     A = R(4, 4)
@@ -1715,7 +1715,7 @@ class APITest(jtu.JaxTestCase):
     f = lambda x: jnp.dot(x, jnp.dot(A.astype(x.dtype), x))
     assert np.allclose(hessian(f, holomorphic=True)(x), A + A.T)
 
-  @jtu.skip_on_devices("tpu")
+  @jax.default_matmul_precision("float32")
   def test_hessian_aux(self):
     R = self.rng().randn
     A = R(4, 4)
diff --git a/tests/experimental_rnn_test.py b/tests/experimental_rnn_test.py
index bacb0efbd..dc18c1f97 100644
--- a/tests/experimental_rnn_test.py
+++ b/tests/experimental_rnn_test.py
@@ -33,7 +33,7 @@ class RnnTest(jtu.JaxTestCase):
       num_layers=[1, 4],
       bidirectional=[True, False],
   )
-  @jtu.skip_on_devices("cpu", "tpu","rocm")
+  @jtu.skip_on_devices("cpu", "tpu", "rocm")
   def test_lstm(self, batch_size: int, seq_len: int, input_size: int,
                 hidden_size: int, num_layers: int, bidirectional: bool):
     batch_size = 6
diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py
index 52c61f75f..7b074affa 100644
--- a/tests/lax_numpy_test.py
+++ b/tests/lax_numpy_test.py
@@ -737,6 +737,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
           or len(jtu._dims_of_shape(rhs_shape)) == 0
           or lhs_shape[-1] == rhs_shape[-1]],
   )
+  @jax.default_matmul_precision("float32")
   def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
@@ -748,8 +749,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
     jnp_fun = lambda lhs, rhs: jnp.inner(lhs, rhs)
     tol_spec = {np.float16: 1e-2, np.float32: 1e-5, np.float64: 1e-13,
                 np.complex64: 1e-5}
-    if jtu.device_under_test() == "tpu":
-      tol_spec[np.float32] = tol_spec[np.complex64] = 2e-1
     tol = max(jtu.tolerance(lhs_dtype, tol_spec),
               jtu.tolerance(rhs_dtype, tol_spec))
     # TODO(phawkins): there are float32/float64 disagreements for some inputs.
diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py
index 8982ae035..0b5984d58 100644
--- a/tests/scipy_stats_test.py
+++ b/tests/scipy_stats_test.py
@@ -883,6 +883,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
       method=[None, "scott", "silverman", 1.5, "callable"],
       func=[None, "evaluate", "logpdf", "pdf"],
   )
+  @jax.default_matmul_precision("float32")
   def testKde(self, inshape, dtype, outsize, weights, method, func):
     if method == "callable":
       method = lambda kde: kde.neff ** -1./(kde.d+4)
diff --git a/tests/sparse_test.py b/tests/sparse_test.py
index 460cb7d95..62eedd654 100644
--- a/tests/sparse_test.py
+++ b/tests/sparse_test.py
@@ -2289,6 +2289,7 @@ class SparseGradTest(sptu.SparseTestCase):
       has_aux=[True, False],
       transform=['jacrev', 'jacfwd', 'jacobian']
   )
+  @jax.default_matmul_precision("float32")
   def test_sparse_jacobian(self, has_aux, transform):
     jac_dense = getattr(jax, transform)
     jac_sparse = getattr(sparse, transform)
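
Usage note (illustrative, not part of the patch): jax.default_matmul_precision, the decorator this patch adds to the tests above, can also be used as a context manager. A minimal sketch of both forms, assuming a jax/jaxlib version in which the option is available; the function name dot below is a hypothetical example, not from the patch:

    # Sketch only: shows the API these tests opt into, not code from the patch.
    import jax
    import jax.numpy as jnp
    import numpy as np

    A = np.random.randn(4, 3).astype(np.float32)
    x = np.random.randn(3).astype(np.float32)

    # Context-manager form: matmuls inside the block run in float32 even on
    # backends (TPU, A100) whose default matmul precision is lower.
    with jax.default_matmul_precision("float32"):
        y = jnp.dot(A, x)

    # Decorator form, as used in the tests above: the precision setting is
    # active for the duration of each call to the decorated function.
    @jax.default_matmul_precision("float32")
    def dot(a, b):
        return jnp.dot(a, b)

    z = dot(A, x)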