mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Reverts 86643a1b3e0516e1a2ddbdabbb714cf8c0301f18
PiperOrigin-RevId: 721776251
This commit is contained in:
parent
e6e7621f0b
commit
1bfdd504ed
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from functools import partial
|
||||
import unittest
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
import jax
|
||||
@ -180,7 +179,6 @@ class RnnTest(jtu.JaxTestCase):
|
||||
y_padded = y_ref[i, seq_lengths[i]:]
|
||||
np.testing.assert_allclose(y_padded, jnp.zeros_like(y_padded))
|
||||
|
||||
@unittest.skip('https://github.com/jax-ml/jax/issues/25825')
|
||||
@jtu.run_on_devices("cuda")
|
||||
def test_struct_encoding_determinism(self):
|
||||
def f(k1, k2, k3, k4):
|
||||
|
Loading…
x
Reference in New Issue
Block a user