Fix tolerance and shard_count for experimental_rnn_test

This should fix the current GPU test timeout.

PiperOrigin-RevId: 522167894
This commit is contained in:
Jake VanderPlas 2023-04-05 15:18:43 -07:00 committed by jax authors
parent f6da71c807
commit 8c8f50f688
2 changed files with 2 additions and 1 deletions

View File

@ -1055,6 +1055,7 @@ jax_test(
disable_configs = [
"gpu_a100", # Numerical precision problems.
],
shard_count = 8,
deps = [
"//jax:rnn",
],

View File

@ -76,7 +76,7 @@ class RnnTest(jtu.JaxTestCase):
loss = jnp.sum(jnp.where(seq_length_mask[..., None], y, 0.))
return loss, (y, h, c)
jtu.check_grads(f, (weights, x, h_0, c_0), modes=["rev"], order=1)
jtu.check_grads(f, (weights, x, h_0, c_0), modes=["rev"], order=1, atol=5E-3, rtol=5E-3)
(loss, (y, h_n, c_n)), weights_grad = jax.value_and_grad(f, has_aux=True)(
weights, x, h_0, c_0)