Drop support in tests for NumPy < 1.16 and jaxlib < 0.1.60

These are now the minimum versions.
This commit is contained in:
Peter Hawkins 2021-02-03 20:57:50 -05:00
parent 016cc83162
commit dfdf582ea0
4 changed files with 8 additions and 33 deletions

View File

@ -75,8 +75,6 @@ class DLPackTest(jtu.JaxTestCase):
for dtype in dlpack_dtypes
for take_ownership in [False, True]))
def testJaxRoundTrip(self, shape, dtype, take_ownership):
if jax.lib.version < (0, 1, 57) and not take_ownership:
raise unittest.SkipTest("Requires jaxlib >= 0.1.57");
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
x = jnp.array(np)
@ -120,8 +118,6 @@ class DLPackTest(jtu.JaxTestCase):
for dtype in dlpack_dtypes))
@unittest.skipIf(not tf, "Test requires TensorFlow")
def testJaxToTensorFlow(self, shape, dtype):
if jax.lib.version < (0, 1, 57):
raise unittest.SkipTest("Requires jaxlib >= 0.1.57");
if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.uint64,
jnp.float64]:
self.skipTest("x64 types are disabled by jax_enable_x64")
@ -164,8 +160,6 @@ class DLPackTest(jtu.JaxTestCase):
for dtype in torch_dtypes))
@unittest.skipIf(not torch, "Test requires PyTorch")
def testJaxToTorch(self, shape, dtype):
if jax.lib.version < (0, 1, 57):
raise unittest.SkipTest("Requires jaxlib >= 0.1.57");
if not FLAGS.jax_enable_x64 and dtype in [jnp.int64, jnp.float64]:
self.skipTest("x64 types are disabled by jax_enable_x64")
rng = jtu.rand_default(self.rng())
@ -202,10 +196,8 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
class Bfloat16Test(jtu.JaxTestCase):
@unittest.skipIf((not tf or tf_version < (2, 5, 0) or
jax.lib.version < (0, 1, 58)),
"Test requires TensorFlow 2.5.0 or newer and jaxlib 0.1.58 "
"or newer")
@unittest.skipIf((not tf or tf_version < (2, 5, 0)),
"Test requires TensorFlow 2.5.0 or newer")
def testJaxAndTfHaveTheSameBfloat16Type(self):
self.assertEqual(np.dtype(jnp.bfloat16).num,
np.dtype(tf.dtypes.bfloat16.as_numpy_dtype).num)

View File

@ -302,7 +302,11 @@ JAX_COMPOUND_OP_RECORDS = [
# numpy.unwrap always returns float64
check_dtypes=False,
# numpy cumsum is inaccurate, see issue #3517
tolerance={dtypes.bfloat16: 1e-1, np.float16: 1e-1})
tolerance={dtypes.bfloat16: 1e-1, np.float16: 1e-1}),
op_record("isclose", 2, [t for t in all_dtypes if t != jnp.bfloat16],
all_shapes, jtu.rand_small_positive, []),
op_record("gcd", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []),
op_record("lcm", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []),
]
JAX_BITWISE_OP_RECORDS = [
@ -351,6 +355,7 @@ JAX_REDUCER_NO_DTYPE_RECORDS = [
[], inexact=True),
op_record("nanstd", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan,
[], inexact=True),
op_record("ptp", 1, number_dtypes, nonempty_shapes, jtu.rand_default, []),
]
JAX_ARGMINMAX_RECORDS = [
@ -427,18 +432,6 @@ for rec in JAX_OPERATOR_OVERLOADS + JAX_RIGHT_OPERATOR_OVERLOADS:
setattr(_OverrideNothing, rec.name, lambda self, other: NotImplemented)
if numpy_version >= (1, 15):
JAX_COMPOUND_OP_RECORDS += [
op_record("isclose", 2, [t for t in all_dtypes if t != jnp.bfloat16],
all_shapes, jtu.rand_small_positive, []),
op_record("gcd", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []),
op_record("lcm", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []),
]
JAX_REDUCER_NO_DTYPE_RECORDS += [
op_record("ptp", 1, number_dtypes, nonempty_shapes, jtu.rand_default, []),
]
def _dtypes_are_compatible_for_bitwise_ops(args):
if len(args) <= 1:
return True
@ -3740,8 +3733,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
'midpoint']))
def testQuantile(self, op, a_rng, q_rng, a_shape, a_dtype, q_shape, q_dtype,
axis, keepdims, interpolation):
if "quantile" in op and numpy_version < (1, 15):
raise SkipTest("Numpy < 1.15 does not have np.quantile")
a_rng = a_rng(self.rng())
q_rng = q_rng(self.rng())
if "median" in op:

View File

@ -15,11 +15,9 @@
import collections
from unittest import skipIf
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import test_util as jtu
from jax import tree_util
@ -202,7 +200,6 @@ class TreeTest(jtu.JaxTestCase):
self.assertEqual(out, (((1, [3]), (2, None)),
(([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])))))
@skipIf(jax.lib.version < (0, 1, 58), "test requires Jaxlib >= 0.1.58")
def testFlattenIsLeaf(self):
x = [(1, 2), (3, 4), (5, 6)]
leaves, _ = tree_util.tree_flatten(x, is_leaf=lambda t: False)
@ -220,7 +217,6 @@ class TreeTest(jtu.JaxTestCase):
y, is_leaf=lambda t: isinstance(t, tuple))
self.assertEqual(leaves, [(1,), (2,), (3,)])
@skipIf(jax.lib.version < (0, 1, 58), "test requires Jaxlib >= 0.1.58")
@parameterized.parameters(*TREES)
def testRoundtripIsLeaf(self, tree):
xs, treedef = tree_util.tree_flatten(

View File

@ -88,8 +88,6 @@ def with_mesh_from_kwargs(f):
class XMapTest(jtu.JaxTestCase):
def setUp(self):
if jax.lib.version < (0, 1, 58):
raise SkipTest("xmap requires jaxlib version >= 0.1.58")
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")
@ -374,8 +372,6 @@ class XMapTestSPMD(XMapTest):
class NamedNumPyTest(jtu.JaxTestCase):
def setUp(self):
if jax.lib.version < (0, 1, 58):
raise SkipTest("xmap requires jaxlib version >= 0.1.58")
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")