mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Drop support in tests for NumPy < 1.16 and jaxlib < 0.1.60
These are now the minimum versions.
This commit is contained in:
parent
016cc83162
commit
dfdf582ea0
@ -75,8 +75,6 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
for dtype in dlpack_dtypes
|
||||
for take_ownership in [False, True]))
|
||||
def testJaxRoundTrip(self, shape, dtype, take_ownership):
|
||||
if jax.lib.version < (0, 1, 57) and not take_ownership:
|
||||
raise unittest.SkipTest("Requires jaxlib >= 0.1.57");
|
||||
rng = jtu.rand_default(self.rng())
|
||||
np = rng(shape, dtype)
|
||||
x = jnp.array(np)
|
||||
@ -120,8 +118,6 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
for dtype in dlpack_dtypes))
|
||||
@unittest.skipIf(not tf, "Test requires TensorFlow")
|
||||
def testJaxToTensorFlow(self, shape, dtype):
|
||||
if jax.lib.version < (0, 1, 57):
|
||||
raise unittest.SkipTest("Requires jaxlib >= 0.1.57");
|
||||
if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.uint64,
|
||||
jnp.float64]:
|
||||
self.skipTest("x64 types are disabled by jax_enable_x64")
|
||||
@ -164,8 +160,6 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
for dtype in torch_dtypes))
|
||||
@unittest.skipIf(not torch, "Test requires PyTorch")
|
||||
def testJaxToTorch(self, shape, dtype):
|
||||
if jax.lib.version < (0, 1, 57):
|
||||
raise unittest.SkipTest("Requires jaxlib >= 0.1.57");
|
||||
if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.float64]:
|
||||
self.skipTest("x64 types are disabled by jax_enable_x64")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -202,10 +196,8 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
|
||||
|
||||
class Bfloat16Test(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skipIf((not tf or tf_version < (2, 5, 0) or
|
||||
jax.lib.version < (0, 1, 58)),
|
||||
"Test requires TensorFlow 2.5.0 or newer and jaxlib 0.1.58 "
|
||||
"or newer")
|
||||
@unittest.skipIf((not tf or tf_version < (2, 5, 0)),
|
||||
"Test requires TensorFlow 2.5.0 or newer")
|
||||
def testJaxAndTfHaveTheSameBfloat16Type(self):
|
||||
self.assertEqual(np.dtype(jnp.bfloat16).num,
|
||||
np.dtype(tf.dtypes.bfloat16.as_numpy_dtype).num)
|
||||
|
@ -302,7 +302,11 @@ JAX_COMPOUND_OP_RECORDS = [
|
||||
# numpy.unwrap always returns float64
|
||||
check_dtypes=False,
|
||||
# numpy cumsum is inaccurate, see issue #3517
|
||||
tolerance={dtypes.bfloat16: 1e-1, np.float16: 1e-1})
|
||||
tolerance={dtypes.bfloat16: 1e-1, np.float16: 1e-1}),
|
||||
op_record("isclose", 2, [t for t in all_dtypes if t != jnp.bfloat16],
|
||||
all_shapes, jtu.rand_small_positive, []),
|
||||
op_record("gcd", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []),
|
||||
op_record("lcm", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []),
|
||||
]
|
||||
|
||||
JAX_BITWISE_OP_RECORDS = [
|
||||
@ -351,6 +355,7 @@ JAX_REDUCER_NO_DTYPE_RECORDS = [
|
||||
[], inexact=True),
|
||||
op_record("nanstd", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan,
|
||||
[], inexact=True),
|
||||
op_record("ptp", 1, number_dtypes, nonempty_shapes, jtu.rand_default, []),
|
||||
]
|
||||
|
||||
JAX_ARGMINMAX_RECORDS = [
|
||||
@ -427,18 +432,6 @@ for rec in JAX_OPERATOR_OVERLOADS + JAX_RIGHT_OPERATOR_OVERLOADS:
|
||||
setattr(_OverrideNothing, rec.name, lambda self, other: NotImplemented)
|
||||
|
||||
|
||||
if numpy_version >= (1, 15):
|
||||
JAX_COMPOUND_OP_RECORDS += [
|
||||
op_record("isclose", 2, [t for t in all_dtypes if t != jnp.bfloat16],
|
||||
all_shapes, jtu.rand_small_positive, []),
|
||||
op_record("gcd", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []),
|
||||
op_record("lcm", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []),
|
||||
]
|
||||
JAX_REDUCER_NO_DTYPE_RECORDS += [
|
||||
op_record("ptp", 1, number_dtypes, nonempty_shapes, jtu.rand_default, []),
|
||||
]
|
||||
|
||||
|
||||
def _dtypes_are_compatible_for_bitwise_ops(args):
|
||||
if len(args) <= 1:
|
||||
return True
|
||||
@ -3740,8 +3733,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
'midpoint']))
|
||||
def testQuantile(self, op, a_rng, q_rng, a_shape, a_dtype, q_shape, q_dtype,
|
||||
axis, keepdims, interpolation):
|
||||
if "quantile" in op and numpy_version < (1, 15):
|
||||
raise SkipTest("Numpy < 1.15 does not have np.quantile")
|
||||
a_rng = a_rng(self.rng())
|
||||
q_rng = q_rng(self.rng())
|
||||
if "median" in op:
|
||||
|
@ -15,11 +15,9 @@
|
||||
|
||||
import collections
|
||||
|
||||
from unittest import skipIf
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax import test_util as jtu
|
||||
from jax import tree_util
|
||||
|
||||
@ -202,7 +200,6 @@ class TreeTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out, (((1, [3]), (2, None)),
|
||||
(([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])))))
|
||||
|
||||
@skipIf(jax.lib.version < (0, 1, 58), "test requires Jaxlib >= 0.1.58")
|
||||
def testFlattenIsLeaf(self):
|
||||
x = [(1, 2), (3, 4), (5, 6)]
|
||||
leaves, _ = tree_util.tree_flatten(x, is_leaf=lambda t: False)
|
||||
@ -220,7 +217,6 @@ class TreeTest(jtu.JaxTestCase):
|
||||
y, is_leaf=lambda t: isinstance(t, tuple))
|
||||
self.assertEqual(leaves, [(1,), (2,), (3,)])
|
||||
|
||||
@skipIf(jax.lib.version < (0, 1, 58), "test requires Jaxlib >= 0.1.58")
|
||||
@parameterized.parameters(*TREES)
|
||||
def testRoundtripIsLeaf(self, tree):
|
||||
xs, treedef = tree_util.tree_flatten(
|
||||
|
@ -88,8 +88,6 @@ def with_mesh_from_kwargs(f):
|
||||
|
||||
class XMapTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
if jax.lib.version < (0, 1, 58):
|
||||
raise SkipTest("xmap requires jaxlib version >= 0.1.58")
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("xmap requires omnistaging")
|
||||
|
||||
@ -374,8 +372,6 @@ class XMapTestSPMD(XMapTest):
|
||||
|
||||
class NamedNumPyTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
if jax.lib.version < (0, 1, 58):
|
||||
raise SkipTest("xmap requires jaxlib version >= 0.1.58")
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("xmap requires omnistaging")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user