mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Disable the RunnTest.test_lstm1 test since it is fixed for cudnn >= 8.8
PiperOrigin-RevId: 536693061
This commit is contained in:
parent
758d68df13
commit
6d6ba70c78
@ -42,6 +42,12 @@ class RnnTest(jtu.JaxTestCase):
|
||||
# TODO(sharadmv, zhangqiaorjc): remove this when minimum jaxlib version is
|
||||
# bumped
|
||||
self.skipTest("Need latest jaxlib for this test to pass.")
|
||||
|
||||
# TODO(phawkins): Partially disable this on cudnn version per b/281071013
|
||||
if (batch_size == 1 and seq_len == 4 and input_size == 1 and
|
||||
hidden_size == 6 and num_layers == 4 and bidirectional == False):
|
||||
self.skipTest("Test requires cudnn >= 8.8")
|
||||
|
||||
num_directions = 2 if bidirectional else 1
|
||||
seq_length_key, root_key = jax.random.split(jax.random.PRNGKey(0))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user