diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index ce2e54ce3..380f9b30a 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -44,6 +44,7 @@ from jax import lax from jax._src import api from jax._src import config from jax._src import core +from jax._src import deprecations from jax._src import dispatch from jax._src import dtypes as _dtypes from jax._src import lib as _jaxlib @@ -1310,6 +1311,16 @@ class JaxTestCase(parameterized.TestCase): def rng(self): return self._rng + def assertDeprecationWarnsOrRaises(self, deprecation_id: str, message: str): + """Assert warning or error, depending on deprecation state. + + For use with functions that call :func:`jax._src.deprecations.warn`. + """ + if deprecations.is_accelerated(deprecation_id): + return self.assertRaisesRegex(ValueError, message) + else: + return self.assertWarnsRegex(DeprecationWarning, message) + def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='', allow_object_dtype=False, verbose=True): """Assert that x and y arrays are exactly equal.""" diff --git a/tests/deprecation_test.py b/tests/deprecation_test.py index f9313449a..c2bd599d8 100644 --- a/tests/deprecation_test.py +++ b/tests/deprecation_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib + from absl.testing import absltest from jax._src import deprecations from jax._src import test_util as jtu @@ -20,6 +22,14 @@ from jax._src.internal_test_util import deprecation_module as m class DeprecationTest(absltest.TestCase): + @contextlib.contextmanager + def deprecation_context(self, deprecation_id): + deprecations.register(deprecation_id) + try: + yield + finally: + deprecations.unregister(deprecation_id) + def testModuleDeprecation(self): with test_warning_util.raise_on_warnings(): self.assertEqual(m.x, 42) @@ -36,13 +46,10 @@ class DeprecationTest(absltest.TestCase): def testNamedDeprecation(self): some_unique_id = "some-unique-id" - try: - deprecations.register(some_unique_id) + with self.deprecation_context(some_unique_id): self.assertFalse(deprecations.is_accelerated(some_unique_id)) deprecations.accelerate(some_unique_id) self.assertTrue(deprecations.is_accelerated(some_unique_id)) - finally: - deprecations.unregister(some_unique_id) msg = f"deprecation_id={some_unique_id!r} not registered" with self.assertRaisesRegex(ValueError, msg): @@ -52,6 +59,19 @@ class DeprecationTest(absltest.TestCase): with self.assertRaisesRegex(ValueError, msg): deprecations.unregister(some_unique_id) + def testNamedDeprecationWarns(self): + deprecation_id = "some-unique-id" + deprecation_message = "This API is deprecated." + with self.deprecation_context(deprecation_id): + self.assertFalse(deprecations.is_accelerated(deprecation_id)) + with self.assertWarnsRegex(DeprecationWarning, deprecation_message): + deprecations.warn(deprecation_id, deprecation_message, stacklevel=1) + + deprecations.accelerate(deprecation_id) + self.assertTrue(deprecations.is_accelerated(deprecation_id)) + with self.assertRaisesRegex(ValueError, deprecation_message): + deprecations.warn(deprecation_id, deprecation_message, stacklevel=1) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/ffi_test.py b/tests/ffi_test.py index b36c461d4..40dff4b9f 100644 --- a/tests/ffi_test.py +++ b/tests/ffi_test.py @@ -28,7 +28,6 @@ import jax.sharding as shd from jax._src import config from jax._src import core -from jax._src import deprecations from jax._src import test_util as jtu from jax._src.interpreters import mlir from jax._src.layout import DeviceLocalLayout @@ -210,12 +209,8 @@ class FfiTest(jtu.JaxTestCase): def fun(x): return jax.ffi.ffi_call("test_ffi", x, x, param=0.5) msg = "Calling ffi_call directly with input arguments is deprecated" - if deprecations.is_accelerated("jax-ffi-call-args"): - with self.assertRaisesRegex(ValueError, msg): - jax.jit(fun).lower(jnp.ones(5)) - else: - with self.assertWarnsRegex(DeprecationWarning, msg): - jax.jit(fun).lower(jnp.ones(5)) + with self.assertDeprecationWarnsOrRaises("jax-ffi-call-args", msg): + jax.jit(fun).lower(jnp.ones(5)) def test_input_output_aliases(self): def fun(x): diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index f4667df85..027eac86f 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -27,7 +27,6 @@ import jax from jax import numpy as jnp from jax._src import config -from jax._src import deprecations from jax._src import dtypes from jax._src import test_util as jtu from jax._src.util import NumpyComplexWarning @@ -454,12 +453,8 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): x = jnp.zeros((10,), dtype) where = jnp.ones(10, dtype=int) func = getattr(jnp, rec.name) - def assert_warns_or_errors(msg): - if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"): - return self.assertRaisesRegex(ValueError, msg) - else: - return self.assertWarnsRegex(DeprecationWarning, msg) - with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"): + with self.assertDeprecationWarnsOrRaises("jax-numpy-reduction-non-boolean-where", + f"jnp.{rec.name}: where must be None or a boolean array"): func(x, where=where, initial=jnp.array(0, dtype=dtype)) @jtu.sample_product(rec=JAX_REDUCER_WHERE_NO_INITIAL_RECORDS) @@ -468,12 +463,8 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): x = jnp.zeros((10,), dtype) where = jnp.ones(10, dtype=int) func = getattr(jnp, rec.name) - def assert_warns_or_errors(msg): - if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"): - return self.assertRaisesRegex(ValueError, msg) - else: - return self.assertWarnsRegex(DeprecationWarning, msg) - with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"): + with self.assertDeprecationWarnsOrRaises("jax-numpy-reduction-non-boolean-where", + f"jnp.{rec.name}: where must be None or a boolean array"): func(x, where=where) @parameterized.parameters(itertools.chain.from_iterable( @@ -756,13 +747,8 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): ) def testQuantileDeprecatedArgs(self, op): func = getattr(jnp, op) - msg = f"The interpolation= argument to '{op}' is deprecated. " - def assert_warns_or_errors(msg=msg): - if deprecations.is_accelerated("jax-numpy-quantile-interpolation"): - return self.assertRaisesRegex(ValueError, msg) - else: - return self.assertWarnsRegex(DeprecationWarning, msg) - with assert_warns_or_errors(msg): + with self.assertDeprecationWarnsOrRaises("jax-numpy-quantile-interpolation", + f"The interpolation= argument to '{op}' is deprecated. "): func(jnp.arange(4), 0.5, interpolation='linear') @unittest.skipIf(not config.enable_x64.value, "test requires X64") diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index c9b9779f0..b1328016b 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -47,7 +47,6 @@ from jax.test_util import check_grads from jax._src import array from jax._src import config from jax._src import core -from jax._src import deprecations from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal @@ -1061,13 +1060,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): jnp.clip(x, max=jnp.array([-1+5j])) def testClipDeprecatedArgs(self): - msg = "Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is deprecated" - def assert_warns_or_errors(msg=msg): - if deprecations.is_accelerated("jax-numpy-clip-args"): - return self.assertRaisesRegex(ValueError, msg) - else: - return self.assertWarnsRegex(DeprecationWarning, msg) - with assert_warns_or_errors(msg): + with self.assertDeprecationWarnsOrRaises("jax-numpy-clip-args", + "Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is deprecated"): jnp.clip(jnp.arange(4), a_min=2, a_max=3) def testHypotComplexInputError(self): @@ -4186,13 +4180,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def testAstypeComplexDowncast(self): x = jnp.array(2.0+1.5j, dtype='complex64') - msg = "Casting from complex to real dtypes.*" - def assert_warns_or_errors(msg=msg): - if deprecations.is_accelerated("jax-numpy-astype-complex-to-real"): - return self.assertRaisesRegex(ValueError, msg) - else: - return self.assertWarnsRegex(DeprecationWarning, msg) - with assert_warns_or_errors(): + with self.assertDeprecationWarnsOrRaises("jax-numpy-astype-complex-to-real", + "Casting from complex to real dtypes.*"): x.astype('float32') @parameterized.parameters('int2', 'int4') diff --git a/tests/linalg_test.py b/tests/linalg_test.py index cbbab9dae..8327e8da4 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -32,7 +32,6 @@ from jax import lax from jax import numpy as jnp from jax import scipy as jsp from jax._src import config -from jax._src import deprecations from jax._src.lax import linalg as lax_linalg from jax._src import test_util as jtu from jax._src import xla_bridge @@ -1177,14 +1176,9 @@ class NumpyLinalgTest(jtu.JaxTestCase): jtu.check_grads(jnp_fn, args_maker(), 1, rtol=6e-2, atol=1e-3) def testPinvDeprecatedArgs(self): - msg = "The rcond argument for linalg.pinv is deprecated." - def assert_warns_or_errors(msg=msg): - if deprecations.is_accelerated("jax-numpy-linalg-pinv-rcond"): - return self.assertRaisesRegex(ValueError, msg) - else: - return self.assertWarnsRegex(DeprecationWarning, msg) x = jnp.ones((3, 3)) - with assert_warns_or_errors(msg): + with self.assertDeprecationWarnsOrRaises("jax-numpy-linalg-pinv-rcond", + "The rcond argument for linalg.pinv is deprecated."): jnp.linalg.pinv(x, rcond=1E-2) def testPinvGradIssue2792(self): @@ -1230,14 +1224,9 @@ class NumpyLinalgTest(jtu.JaxTestCase): check_dtypes=False, rtol=1e-3) def testMatrixRankDeprecatedArgs(self): - msg = "The tol argument for linalg.matrix_rank is deprecated." - def assert_warns_or_errors(msg=msg): - if deprecations.is_accelerated("jax-numpy-linalg-matrix_rank-tol"): - return self.assertRaisesRegex(ValueError, msg) - else: - return self.assertWarnsRegex(DeprecationWarning, msg) x = jnp.ones((3, 3)) - with assert_warns_or_errors(msg): + with self.assertDeprecationWarnsOrRaises("jax-numpy-linalg-matrix_rank-tol", + "The tol argument for linalg.matrix_rank is deprecated."): jnp.linalg.matrix_rank(x, tol=1E-2) @jtu.sample_product( diff --git a/tests/nn_test.py b/tests/nn_test.py index 7f0f80bb9..09390912d 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -27,7 +27,6 @@ import scipy.stats from jax._src import ad_checkpoint from jax._src import config from jax._src import core -from jax._src import deprecations from jax._src import test_util as jtu from jax._src.lib import cuda_versions from jax.test_util import check_grads @@ -531,12 +530,8 @@ class NNFunctionsTest(jtu.JaxTestCase): self.assertAllClose(actual, expected, check_dtypes=False) def testOneHotNonInteger(self): - def assert_warns_or_errors(msg): - if deprecations.is_accelerated("jax-nn-one-hot-float-input"): - return self.assertRaisesRegex(ValueError, msg) - else: - return self.assertWarnsRegex(DeprecationWarning, msg) - with assert_warns_or_errors("jax.nn.one_hot input should be integer-typed"): + with self.assertDeprecationWarnsOrRaises("jax-nn-one-hot-float-input", + "jax.nn.one_hot input should be integer-typed"): nn.one_hot(jnp.array([1.0]), 3) def testTanhExists(self):