mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix tolerance and shard_count for experimental_rnn_test
This should fix the current GPU test timeout. PiperOrigin-RevId: 522167894
This commit is contained in:
parent
f6da71c807
commit
8c8f50f688
@ -1055,6 +1055,7 @@ jax_test(
|
||||
disable_configs = [
|
||||
"gpu_a100", # Numerical precision problems.
|
||||
],
|
||||
shard_count = 8,
|
||||
deps = [
|
||||
"//jax:rnn",
|
||||
],
|
||||
|
@ -76,7 +76,7 @@ class RnnTest(jtu.JaxTestCase):
|
||||
loss = jnp.sum(jnp.where(seq_length_mask[..., None], y, 0.))
|
||||
return loss, (y, h, c)
|
||||
|
||||
jtu.check_grads(f, (weights, x, h_0, c_0), modes=["rev"], order=1)
|
||||
jtu.check_grads(f, (weights, x, h_0, c_0), modes=["rev"], order=1, atol=5E-3, rtol=5E-3)
|
||||
|
||||
(loss, (y, h_n, c_n)), weights_grad = jax.value_and_grad(f, has_aux=True)(
|
||||
weights, x, h_0, c_0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user