25 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
8b7cfcb33c Fix integer overflow in workspace size computations for experimental.rnn.*.
PiperOrigin-RevId: 736139471
2025-03-12 08:22:04 -07:00
Dan Foreman-Mackey
c6e83903de Update RNN kernels to use FFI.
PiperOrigin-RevId: 724151647
2025-02-06 18:27:58 -08:00
Vladimir Belitskiy
1bfdd504ed Reverts 86643a1b3e0516e1a2ddbdabbb714cf8c0301f18
PiperOrigin-RevId: 721776251
2025-01-31 08:05:46 -08:00
jax authors
41993fdb24 Merge pull request #25755 from ROCm:ci_rnn_final-upstream
PiperOrigin-RevId: 715856939
2025-01-15 10:40:54 -08:00
Ruturaj4
fe68eb8b25 [ROCm] Implement RNN support 2025-01-14 19:04:49 -06:00
Vladimir Belitskiy
86643a1b3e Skip RnnTest.test_struct_encoding_determinism.
PiperOrigin-RevId: 714027519
2025-01-10 06:20:01 -08:00
Ilia Sergachev
f0e1c3cf36 Fix struct string encoding non-determinism in the RNN descriptor.
Boolean fields in the descriptor struct led to padding, which let random
bytes in the string representation of the struct and variance in HLO
from run to run.
2025-01-09 12:57:09 +00:00
Sergei Lebedev
1079304259 MAINT Do not import the config object in JAX internals
The longer term goal here is to move away from having the config object as
part of the public API and migrate towards module-level functions instead.

Note that we can preserve the dynamic attribute lookup behavior of the
config object via a module-level `__getattr__`
2023-10-18 10:55:13 +01:00
Peter Hawkins
f52926e832 Fix test breakage in RNN test with old jaxlibs.
Remove some outdated version guards.
2023-09-20 11:50:04 -04:00
Andrey Portnoy
fc1c31d958 Run LSTM test using FP32 math (as opposed to TF32)
1. Add (limited) precision specifier handling to LSTM

Enables differentiating between TF32 and FP32 math. TF32 math had insufficient
precision to reliably pass LSTM correctness tests on A100 and H100.

2. Run the test using FP32

TF32 precision is not sufficient for the test to pass reliably on Ampere+ GPUs
such as A100 and H100.
2023-09-19 14:45:14 -04:00
Jake Hall
f59a4163fa Test changes for out-of-tree backend. 2023-09-14 12:18:37 +01:00
Jake VanderPlas
2f878a7168 Tests: set jax_legacy_prng_key='error' 2023-08-28 10:56:09 -07:00
Yash Katariya
6d6ba70c78 Disable the RunnTest.test_lstm1 test since it is fixed for cudnn >= 8.8
PiperOrigin-RevId: 536693061
2023-05-31 06:21:01 -07:00
Jake VanderPlas
8c8f50f688 Fix tolerance and shard_count for experimental_rnn_test
This should fix the current GPU test timeout.

PiperOrigin-RevId: 522167894
2023-04-05 15:19:19 -07:00
jax authors
b361f4cd0c Merge pull request #15169 from cgarciae:fix-lstm
PiperOrigin-RevId: 521616002
2023-04-03 18:13:19 -07:00
Cristian Garcia
aa12e3597b handle seq_lengths in lstm_ref 2023-04-03 22:22:54 +00:00
Sharad Vikram
10dc941d8d Add jaxlib version guard for rnn test
PiperOrigin-RevId: 519833650
2023-03-27 14:43:46 -07:00
Sharad Vikram
3c3fa042e3 Copy seq_lengths before creating descriptor
PiperOrigin-RevId: 519771897
2023-03-27 10:59:44 -07:00
Peter Hawkins
33bed1e520 Opt into higher matmul precision for A100 and TPU tests.
PiperOrigin-RevId: 509598465
2023-02-14 12:03:12 -08:00
Rahul Batra
3391a5e385 [ROCm]: Disable some tests on ROCm platform 2022-12-19 21:33:13 +00:00
Peter Hawkins
2c6c30d458 Bump the minimum jaxlib version to 0.4.1.
Jaxlib 0.4.1 has XLA client version 109 and MLIR API version 39.
2022-12-19 17:49:24 +00:00
Qiao Zhang
55d6daacfa Skip test_lstm on CPU and TPU for jax OSS build.
PiperOrigin-RevId: 492722650
2022-12-03 14:16:07 -08:00
Qiao Zhang
4d1c4bc761 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
jax authors
d1fbdbc1cf Rollback of "Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module."
PiperOrigin-RevId: 490499003
2022-11-23 07:48:05 -08:00
Qiao Zhang
78963b6020 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00