From 8c8f50f6886a4a3158030d3d57aa74aa878dfd06 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 5 Apr 2023 15:18:43 -0700 Subject: [PATCH] Fix tolerance and shard_count for experimental_rnn_test This should fix the current GPU test timeout. PiperOrigin-RevId: 522167894 --- tests/BUILD | 1 + tests/experimental_rnn_test.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/BUILD b/tests/BUILD index a062f8f49..c3e5a447e 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1055,6 +1055,7 @@ jax_test( disable_configs = [ "gpu_a100", # Numerical precision problems. ], + shard_count = 8, deps = [ "//jax:rnn", ], diff --git a/tests/experimental_rnn_test.py b/tests/experimental_rnn_test.py index b04c25960..657f8b3ef 100644 --- a/tests/experimental_rnn_test.py +++ b/tests/experimental_rnn_test.py @@ -76,7 +76,7 @@ class RnnTest(jtu.JaxTestCase): loss = jnp.sum(jnp.where(seq_length_mask[..., None], y, 0.)) return loss, (y, h, c) - jtu.check_grads(f, (weights, x, h_0, c_0), modes=["rev"], order=1) + jtu.check_grads(f, (weights, x, h_0, c_0), modes=["rev"], order=1, atol=5E-3, rtol=5E-3) (loss, (y, h_n, c_n)), weights_grad = jax.value_and_grad(f, has_aux=True)( weights, x, h_0, c_0)