Remove checks for jaxlib v0.4.33 in tests

This commit is contained in:
Dan Foreman-Mackey 2024-10-11 15:39:24 -04:00
parent e9c7ff0b7d
commit 5ed2f4ef1c
2 changed files with 0 additions and 13 deletions

View File

@ -26,7 +26,6 @@ from jax import numpy as jnp
from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lib import version as jaxlib_version
from jax._src.numpy.util import promote_dtypes_complex
from jax._src.numpy.fft import _fft_norm
@ -482,8 +481,6 @@ class FftTest(jtu.JaxTestCase):
# reported in https://github.com/jax-ml/jax/issues/23827
if not config.enable_x64.value:
raise self.skipTest("requires jax_enable_x64=true")
if jaxlib_version <= (0, 4, 33):
raise self.skipTest("requires jaxlib version > 0.4.33")
n = 31
a = np.ones((n, 15), dtype="complex128")
self.assertArraysAllClose(

View File

@ -46,7 +46,6 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.internal_test_util import lax_test_util
from jax._src.lax import lax as lax_internal
from jax._src.lib import version as jaxlib_version
from jax._src.util import NumpyComplexWarning, safe_zip
from jax._src.tree_util import tree_map
@ -1081,9 +1080,6 @@ class LaxTest(jtu.JaxTestCase):
if xla_bridge.using_pjrt_c_api():
raise SkipTest(
"The dot algorithm attribute is not supported by PJRT C API.")
if jaxlib_version <= (0, 4, 33):
raise SkipTest(
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
if jtu.test_device_matches(["cpu"]):
if algorithm not in {
lax.DotAlgorithmPreset.DEFAULT,
@ -1127,9 +1123,6 @@ class LaxTest(jtu.JaxTestCase):
if xla_bridge.using_pjrt_c_api():
raise SkipTest(
"The dot algorithm attribute is not supported by PJRT C API.")
if jaxlib_version <= (0, 4, 33):
raise SkipTest(
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
if jtu.test_device_matches(["cpu"]):
raise SkipTest("Not supported on CPU.")
lhs_shape = (3, 4)
@ -1143,9 +1136,6 @@ class LaxTest(jtu.JaxTestCase):
if xla_bridge.using_pjrt_c_api():
raise SkipTest(
"The dot algorithm attribute is not supported by PJRT C API.")
if jaxlib_version <= (0, 4, 33):
raise SkipTest(
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
def fun(lhs, rhs):
return lax.dot(lhs, rhs, precision="F32_F32_F32")
lhs_shape = (3, 4)