From da3aaa1960703073621f17195164244558a33260 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 17 Feb 2022 14:58:58 -0800 Subject: [PATCH] Add deprecation warning to JaxTestCase and JaxTestLoader --- CHANGELOG.md | 29 ++++++++++++++++++----------- jax/test_util.py | 27 +++++++++++++++++++++++++-- tests/mesh_utils_test.py | 2 +- tests/svd_test.py | 2 +- 4 files changed, 45 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 04304216b..301dd270d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,17 +11,24 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. ## jax 0.3.1 (Unreleased) * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.0...main). -* `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by - default. To recover the previous behavior, use the `jax.test_util.with_config` - decorator: - ```python - @jtu.with_config(jax_numpy_rank_promotion='allow') - class MyTestCase(jtu.JaxTestCase): - ... - ``` -* Added ``jax.scipy.linalg.schur``, ``jax.scipy.linalg.sqrtm``, - ``jax.scipy.signal.csd``, ``jax.scipy.signal.stft``, - ``jax.scipy.signal.welch``. + +* Changes: + * `jax.test_util.JaxTestCase` and `jax.test_util.JaxTestLoader` are now deprecated. + The suggested replacement is to use `parametrized.TestCase` directly. For tests that + rely on custom asserts such as `JaxTestCase.assertAllClose()`, the suggested replacement + is to use standard numpy testing utilities such as {func}`numpy.testing.assert_allclose()`, + which work directly with JAX arrays ({jax-issue}`#9620`). + * `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by default + ({jax-issue}`#9562`). To recover the previous behavior, use the new + `jax.test_util.with_config` decorator: + ```python + @jtu.with_config(jax_numpy_rank_promotion='allow') + class MyTestCase(jtu.JaxTestCase): + ... + ``` + * Added {func}`jax.scipy.linalg.schur`, {func}`jax.scipy.linalg.sqrtm`, + {func}`jax.scipy.signal.csd`, {func}`jax.scipy.signal.stft`, + {func}`jax.scipy.signal.welch`. ## jaxlib 0.3.1 (Unreleased) * Changes diff --git a/jax/test_util.py b/jax/test_util.py index bfa74e095..8df6ea923 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -15,8 +15,8 @@ # flake8: noqa: F401 # TODO(phawkins): remove all exports except check_grads/check_jvp/check_vjp. from jax._src.test_util import ( - JaxTestCase, - JaxTestLoader, + JaxTestCase as _PrivateJaxTestCase, + JaxTestLoader as _PrivateJaxTestLoader, cases_from_list, check_close, check_eq, @@ -31,3 +31,26 @@ from jax._src.test_util import ( xla_bridge, _default_tolerance ) + +class JaxTestCase(_PrivateJaxTestCase): + def __init__(self, *args, **kwargs): + import warnings + import textwrap + warnings.warn(textwrap.dedent("""\ + jax.test_util.JaxTestCase is deprecated as of jax version 0.3.1: + The suggested replacement is to use parametrized.TestCase directly. + For tests that rely on custom asserts such as JaxTestCase.assertAllClose(), + the suggested replacement is to use standard numpy testing utilities such + as np.testing.assert_allclose(), which work directly with JAX arrays."""), + category=DeprecationWarning) + super().__init__(*args, **kwargs) + +class JaxTestLoader(_PrivateJaxTestLoader): + def __init__(self, *args, **kwargs): + import warnings + warnings.warn( + "jax.test_util.JaxTestLoader is deprecated as of jax version 0.3.1. Use absltest.TestLoader directly.", + category=DeprecationWarning) + super().__init__(*args, **kwargs) + +del _PrivateJaxTestCase, _PrivateJaxTestLoader \ No newline at end of file diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index cf9d93dcf..5a4d9868d 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -22,9 +22,9 @@ from typing import Sequence from absl import logging from absl.testing import absltest from absl.testing import parameterized -from jax import test_util from jax.experimental import mesh_utils from jax.experimental.maps import Mesh +from jax._src import test_util @dataclasses.dataclass diff --git a/tests/svd_test.py b/tests/svd_test.py index 28cec1849..aa50804c3 100644 --- a/tests/svd_test.py +++ b/tests/svd_test.py @@ -16,12 +16,12 @@ import functools import jax -from jax import test_util as jtu from jax.config import config import jax.numpy as jnp import numpy as np import scipy.linalg as osp_linalg from jax._src.lax import svd +from jax._src import test_util as jtu from absl.testing import absltest from absl.testing import parameterized