mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Add deprecation warning to JaxTestCase and JaxTestLoader
This commit is contained in:
parent
e545daa1e5
commit
da3aaa1960
29
CHANGELOG.md
29
CHANGELOG.md
@ -11,17 +11,24 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
## jax 0.3.1 (Unreleased)
|
||||
* [GitHub
|
||||
commits](https://github.com/google/jax/compare/jax-v0.3.0...main).
|
||||
* `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by
|
||||
default. To recover the previous behavior, use the `jax.test_util.with_config`
|
||||
decorator:
|
||||
```python
|
||||
@jtu.with_config(jax_numpy_rank_promotion='allow')
|
||||
class MyTestCase(jtu.JaxTestCase):
|
||||
...
|
||||
```
|
||||
* Added ``jax.scipy.linalg.schur``, ``jax.scipy.linalg.sqrtm``,
|
||||
``jax.scipy.signal.csd``, ``jax.scipy.signal.stft``,
|
||||
``jax.scipy.signal.welch``.
|
||||
|
||||
* Changes:
|
||||
* `jax.test_util.JaxTestCase` and `jax.test_util.JaxTestLoader` are now deprecated.
|
||||
The suggested replacement is to use `parametrized.TestCase` directly. For tests that
|
||||
rely on custom asserts such as `JaxTestCase.assertAllClose()`, the suggested replacement
|
||||
is to use standard numpy testing utilities such as {func}`numpy.testing.assert_allclose()`,
|
||||
which work directly with JAX arrays ({jax-issue}`#9620`).
|
||||
* `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by default
|
||||
({jax-issue}`#9562`). To recover the previous behavior, use the new
|
||||
`jax.test_util.with_config` decorator:
|
||||
```python
|
||||
@jtu.with_config(jax_numpy_rank_promotion='allow')
|
||||
class MyTestCase(jtu.JaxTestCase):
|
||||
...
|
||||
```
|
||||
* Added {func}`jax.scipy.linalg.schur`, {func}`jax.scipy.linalg.sqrtm`,
|
||||
{func}`jax.scipy.signal.csd`, {func}`jax.scipy.signal.stft`,
|
||||
{func}`jax.scipy.signal.welch`.
|
||||
|
||||
## jaxlib 0.3.1 (Unreleased)
|
||||
* Changes
|
||||
|
@ -15,8 +15,8 @@
|
||||
# flake8: noqa: F401
|
||||
# TODO(phawkins): remove all exports except check_grads/check_jvp/check_vjp.
|
||||
from jax._src.test_util import (
|
||||
JaxTestCase,
|
||||
JaxTestLoader,
|
||||
JaxTestCase as _PrivateJaxTestCase,
|
||||
JaxTestLoader as _PrivateJaxTestLoader,
|
||||
cases_from_list,
|
||||
check_close,
|
||||
check_eq,
|
||||
@ -31,3 +31,26 @@ from jax._src.test_util import (
|
||||
xla_bridge,
|
||||
_default_tolerance
|
||||
)
|
||||
|
||||
class JaxTestCase(_PrivateJaxTestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
import warnings
|
||||
import textwrap
|
||||
warnings.warn(textwrap.dedent("""\
|
||||
jax.test_util.JaxTestCase is deprecated as of jax version 0.3.1:
|
||||
The suggested replacement is to use parametrized.TestCase directly.
|
||||
For tests that rely on custom asserts such as JaxTestCase.assertAllClose(),
|
||||
the suggested replacement is to use standard numpy testing utilities such
|
||||
as np.testing.assert_allclose(), which work directly with JAX arrays."""),
|
||||
category=DeprecationWarning)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
class JaxTestLoader(_PrivateJaxTestLoader):
|
||||
def __init__(self, *args, **kwargs):
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"jax.test_util.JaxTestLoader is deprecated as of jax version 0.3.1. Use absltest.TestLoader directly.",
|
||||
category=DeprecationWarning)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
del _PrivateJaxTestCase, _PrivateJaxTestLoader
|
@ -22,9 +22,9 @@ from typing import Sequence
|
||||
from absl import logging
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
from jax import test_util
|
||||
from jax.experimental import mesh_utils
|
||||
from jax.experimental.maps import Mesh
|
||||
from jax._src import test_util
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -16,12 +16,12 @@
|
||||
import functools
|
||||
|
||||
import jax
|
||||
from jax import test_util as jtu
|
||||
from jax.config import config
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import scipy.linalg as osp_linalg
|
||||
from jax._src.lax import svd
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
Loading…
x
Reference in New Issue
Block a user