Merge pull request #19485 from ROCmSoftwarePlatform:rocm-enable_tridiagonal_solve

PiperOrigin-RevId: 601613417
This commit is contained in:
jax authors 2024-01-25 17:19:00 -08:00
commit 70ea84d67f

View File

@ -2008,7 +2008,6 @@ class LaxLinalgTest(jtu.JaxTestCase):
eigvals_all[first:(last + 1)], eigvals_index, atol=atol)
@jtu.sample_product(dtype=[np.float32, np.float64])
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def test_tridiagonal_solve(self, dtype):
dl = np.array([0.0, 2.0, 3.0], dtype=dtype)
d = np.ones(3, dtype=dtype)