Disable the RunnTest.test_lstm1 test since it is fixed for cudnn >= 8.8

PiperOrigin-RevId: 536693061
This commit is contained in:
Yash Katariya 2023-05-31 06:20:26 -07:00 committed by jax authors
parent 758d68df13
commit 6d6ba70c78

View File

@ -42,6 +42,12 @@ class RnnTest(jtu.JaxTestCase):
# TODO(sharadmv, zhangqiaorjc): remove this when minimum jaxlib version is
# bumped
self.skipTest("Need latest jaxlib for this test to pass.")
# TODO(phawkins): Partially disable this on cudnn version per b/281071013
if (batch_size == 1 and seq_len == 4 and input_size == 1 and
hidden_size == 6 and num_layers == 4 and bidirectional == False):
self.skipTest("Test requires cudnn >= 8.8")
num_directions = 2 if bidirectional else 1
seq_length_key, root_key = jax.random.split(jax.random.PRNGKey(0))