mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Remove checks for jaxlib v0.4.33 in tests
This commit is contained in:
parent
e9c7ff0b7d
commit
5ed2f4ef1c
@ -26,7 +26,6 @@ from jax import numpy as jnp
|
||||
from jax._src import config
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src.numpy.util import promote_dtypes_complex
|
||||
from jax._src.numpy.fft import _fft_norm
|
||||
|
||||
@ -482,8 +481,6 @@ class FftTest(jtu.JaxTestCase):
|
||||
# reported in https://github.com/jax-ml/jax/issues/23827
|
||||
if not config.enable_x64.value:
|
||||
raise self.skipTest("requires jax_enable_x64=true")
|
||||
if jaxlib_version <= (0, 4, 33):
|
||||
raise self.skipTest("requires jaxlib version > 0.4.33")
|
||||
n = 31
|
||||
a = np.ones((n, 15), dtype="complex128")
|
||||
self.assertArraysAllClose(
|
||||
|
@ -46,7 +46,6 @@ from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.internal_test_util import lax_test_util
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src.util import NumpyComplexWarning, safe_zip
|
||||
from jax._src.tree_util import tree_map
|
||||
|
||||
@ -1081,9 +1080,6 @@ class LaxTest(jtu.JaxTestCase):
|
||||
if xla_bridge.using_pjrt_c_api():
|
||||
raise SkipTest(
|
||||
"The dot algorithm attribute is not supported by PJRT C API.")
|
||||
if jaxlib_version <= (0, 4, 33):
|
||||
raise SkipTest(
|
||||
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
|
||||
if jtu.test_device_matches(["cpu"]):
|
||||
if algorithm not in {
|
||||
lax.DotAlgorithmPreset.DEFAULT,
|
||||
@ -1127,9 +1123,6 @@ class LaxTest(jtu.JaxTestCase):
|
||||
if xla_bridge.using_pjrt_c_api():
|
||||
raise SkipTest(
|
||||
"The dot algorithm attribute is not supported by PJRT C API.")
|
||||
if jaxlib_version <= (0, 4, 33):
|
||||
raise SkipTest(
|
||||
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
|
||||
if jtu.test_device_matches(["cpu"]):
|
||||
raise SkipTest("Not supported on CPU.")
|
||||
lhs_shape = (3, 4)
|
||||
@ -1143,9 +1136,6 @@ class LaxTest(jtu.JaxTestCase):
|
||||
if xla_bridge.using_pjrt_c_api():
|
||||
raise SkipTest(
|
||||
"The dot algorithm attribute is not supported by PJRT C API.")
|
||||
if jaxlib_version <= (0, 4, 33):
|
||||
raise SkipTest(
|
||||
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
|
||||
def fun(lhs, rhs):
|
||||
return lax.dot(lhs, rhs, precision="F32_F32_F32")
|
||||
lhs_shape = (3, 4)
|
||||
|
Loading…
x
Reference in New Issue
Block a user