Opt into higher matmul precision for A100 and TPU tests.

PiperOrigin-RevId: 509598465
Author: Peter Hawkins, 2023-02-14 12:01:35 -08:00 (committed by jax authors)
commit 33bed1e520
parent aa98c99d3a
6 changed files with 11 additions and 7 deletions
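
For context: jax.default_matmul_precision can be used either as a decorator or as a context manager, and the test changes below use the decorator form to request full float32 matmul precision instead of the lower-precision default used on TPU and A100 tensor cores. A minimal sketch of both forms (the function f and its inputs are illustrative, not part of this commit):

import jax
import jax.numpy as jnp
import numpy as np

# Request full float32 matmul precision inside f, overriding the
# lower-precision default used on TPU / A100 tensor cores.
@jax.default_matmul_precision("float32")
def f(a, b):
  return jnp.dot(a, b)

a = np.random.randn(16, 8).astype(np.float32)
b = np.random.randn(8, 4).astype(np.float32)
print(f(a, b).shape)  # (16, 4)

# The same option is available as a context manager:
with jax.default_matmul_precision("float32"):
  print(jnp.dot(a, b).shape)  # (16, 4)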

@@ -1040,6 +1040,9 @@ jax_test(
         "tpu",
         "cpu",
     ],
+    disable_configs = [
+        "gpu_a100",  # Numerical precision problems.
+    ],
     deps = [
         "//jax:rnn",
     ],

@@ -1684,7 +1684,7 @@ class APITest(jtu.JaxTestCase):
     else:
       self.assertEqual(x.device_buffer.device(), cpu_device)
 
-  @jtu.skip_on_devices("tpu")
+  @jax.default_matmul_precision("float32")
   def test_jacobian(self):
     R = self.rng().randn
     A = R(4, 3)
@@ -1697,7 +1697,7 @@ class APITest(jtu.JaxTestCase):
     f = lambda x: jnp.tanh(jnp.dot(A, x))
     assert np.allclose(jacfwd(f)(x), jacrev(f)(x))
 
-  @jtu.skip_on_devices("tpu")
+  @jax.default_matmul_precision("float32")
   def test_hessian(self):
     R = self.rng().randn
     A = R(4, 4)
@@ -1706,7 +1706,7 @@ class APITest(jtu.JaxTestCase):
     f = lambda x: jnp.dot(x, jnp.dot(A, x))
     assert np.allclose(hessian(f)(x), A + A.T)
 
-  @jtu.skip_on_devices("tpu")
+  @jax.default_matmul_precision("float32")
   def test_hessian_holomorphic(self):
     R = self.rng().randn
     A = R(4, 4)
@@ -1715,7 +1715,7 @@ class APITest(jtu.JaxTestCase):
     f = lambda x: jnp.dot(x, jnp.dot(A.astype(x.dtype), x))
     assert np.allclose(hessian(f, holomorphic=True)(x), A + A.T)
 
-  @jtu.skip_on_devices("tpu")
+  @jax.default_matmul_precision("float32")
   def test_hessian_aux(self):
     R = self.rng().randn
     A = R(4, 4)

@@ -33,7 +33,7 @@ class RnnTest(jtu.JaxTestCase):
       num_layers=[1, 4],
       bidirectional=[True, False],
   )
-  @jtu.skip_on_devices("cpu", "tpu","rocm")
+  @jtu.skip_on_devices("cpu", "tpu", "rocm")
   def test_lstm(self, batch_size: int, seq_len: int, input_size: int,
                 hidden_size: int, num_layers: int, bidirectional: bool):
     batch_size = 6

@@ -737,6 +737,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
         or len(jtu._dims_of_shape(rhs_shape)) == 0
         or lhs_shape[-1] == rhs_shape[-1]],
   )
+  @jax.default_matmul_precision("float32")
   def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
@@ -748,8 +749,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
     jnp_fun = lambda lhs, rhs: jnp.inner(lhs, rhs)
     tol_spec = {np.float16: 1e-2, np.float32: 1e-5, np.float64: 1e-13,
                 np.complex64: 1e-5}
-    if jtu.device_under_test() == "tpu":
-      tol_spec[np.float32] = tol_spec[np.complex64] = 2e-1
     tol = max(jtu.tolerance(lhs_dtype, tol_spec),
               jtu.tolerance(rhs_dtype, tol_spec))
     # TODO(phawkins): there are float32/float64 disagreements for some inputs.

@@ -883,6 +883,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
       method=[None, "scott", "silverman", 1.5, "callable"],
       func=[None, "evaluate", "logpdf", "pdf"],
   )
+  @jax.default_matmul_precision("float32")
   def testKde(self, inshape, dtype, outsize, weights, method, func):
     if method == "callable":
       method = lambda kde: kde.neff ** -1./(kde.d+4)

@@ -2289,6 +2289,7 @@ class SparseGradTest(sptu.SparseTestCase):
       has_aux=[True, False],
       transform=['jacrev', 'jacfwd', 'jacobian']
   )
+  @jax.default_matmul_precision("float32")
   def test_sparse_jacobian(self, has_aux, transform):
     jac_dense = getattr(jax, transform)
     jac_sparse = getattr(sparse, transform)