Reverts 86643a1b3e0516e1a2ddbdabbb714cf8c0301f18

PiperOrigin-RevId: 721776251
This commit is contained in:
Vladimir Belitskiy 2025-01-31 08:05:04 -08:00 committed by jax authors
parent e6e7621f0b
commit 1bfdd504ed

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import unittest
from absl.testing import absltest
import numpy as np
import jax
@ -180,7 +179,6 @@ class RnnTest(jtu.JaxTestCase):
y_padded = y_ref[i, seq_lengths[i]:]
np.testing.assert_allclose(y_padded, jnp.zeros_like(y_padded))
@unittest.skip('https://github.com/jax-ml/jax/issues/25825')
@jtu.run_on_devices("cuda")
def test_struct_encoding_determinism(self):
def f(k1, k2, k3, k4):