Add deprecation warning to JaxTestCase and JaxTestLoader

This commit is contained in:
Jake VanderPlas 2022-02-17 14:58:58 -08:00
parent e545daa1e5
commit da3aaa1960
4 changed files with 45 additions and 15 deletions

View File

@ -11,17 +11,24 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.3.1 (Unreleased)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.0...main).
* `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by
default. To recover the previous behavior, use the `jax.test_util.with_config`
decorator:
```python
@jtu.with_config(jax_numpy_rank_promotion='allow')
class MyTestCase(jtu.JaxTestCase):
...
```
* Added ``jax.scipy.linalg.schur``, ``jax.scipy.linalg.sqrtm``,
``jax.scipy.signal.csd``, ``jax.scipy.signal.stft``,
``jax.scipy.signal.welch``.
* Changes:
* `jax.test_util.JaxTestCase` and `jax.test_util.JaxTestLoader` are now deprecated.
The suggested replacement is to use `parametrized.TestCase` directly. For tests that
rely on custom asserts such as `JaxTestCase.assertAllClose()`, the suggested replacement
is to use standard numpy testing utilities such as {func}`numpy.testing.assert_allclose()`,
which work directly with JAX arrays ({jax-issue}`#9620`).
* `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by default
({jax-issue}`#9562`). To recover the previous behavior, use the new
`jax.test_util.with_config` decorator:
```python
@jtu.with_config(jax_numpy_rank_promotion='allow')
class MyTestCase(jtu.JaxTestCase):
...
```
* Added {func}`jax.scipy.linalg.schur`, {func}`jax.scipy.linalg.sqrtm`,
{func}`jax.scipy.signal.csd`, {func}`jax.scipy.signal.stft`,
{func}`jax.scipy.signal.welch`.
## jaxlib 0.3.1 (Unreleased)
* Changes

View File

@ -15,8 +15,8 @@
# flake8: noqa: F401
# TODO(phawkins): remove all exports except check_grads/check_jvp/check_vjp.
from jax._src.test_util import (
JaxTestCase,
JaxTestLoader,
JaxTestCase as _PrivateJaxTestCase,
JaxTestLoader as _PrivateJaxTestLoader,
cases_from_list,
check_close,
check_eq,
@ -31,3 +31,26 @@ from jax._src.test_util import (
xla_bridge,
_default_tolerance
)
class JaxTestCase(_PrivateJaxTestCase):
def __init__(self, *args, **kwargs):
import warnings
import textwrap
warnings.warn(textwrap.dedent("""\
jax.test_util.JaxTestCase is deprecated as of jax version 0.3.1:
The suggested replacement is to use parametrized.TestCase directly.
For tests that rely on custom asserts such as JaxTestCase.assertAllClose(),
the suggested replacement is to use standard numpy testing utilities such
as np.testing.assert_allclose(), which work directly with JAX arrays."""),
category=DeprecationWarning)
super().__init__(*args, **kwargs)
class JaxTestLoader(_PrivateJaxTestLoader):
def __init__(self, *args, **kwargs):
import warnings
warnings.warn(
"jax.test_util.JaxTestLoader is deprecated as of jax version 0.3.1. Use absltest.TestLoader directly.",
category=DeprecationWarning)
super().__init__(*args, **kwargs)
del _PrivateJaxTestCase, _PrivateJaxTestLoader

View File

@ -22,9 +22,9 @@ from typing import Sequence
from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
from jax import test_util
from jax.experimental import mesh_utils
from jax.experimental.maps import Mesh
from jax._src import test_util
@dataclasses.dataclass

View File

@ -16,12 +16,12 @@
import functools
import jax
from jax import test_util as jtu
from jax.config import config
import jax.numpy as jnp
import numpy as np
import scipy.linalg as osp_linalg
from jax._src.lax import svd
from jax._src import test_util as jtu
from absl.testing import absltest
from absl.testing import parameterized