jax.test_util: remove deprecated test classes.

JaxTestCase and JaxTestLoader were deprecated in jax v0.3.1, released Feb 2022.
This commit is contained in:
Jake VanderPlas 2022-06-27 11:04:50 -07:00
parent 997beb3ce0
commit 887abbc3b9
3 changed files with 2 additions and 36 deletions

View File

@ -11,6 +11,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.3.15 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.14...main).
* Changes
* `JaxTestCase` and `JaxTestLoader` have been removed from `jax.test_util`. These
classes have been deprecated since v0.3.1 ({jax-issue}`#11248`).
## jaxlib 0.3.15 (Unreleased)

View File

@ -1026,38 +1026,6 @@ class _LazyDtypes:
dtypes = _LazyDtypes()
class DeprecatedJaxTestCase(JaxTestCase):
def __init__(self, *args, **kwargs):
warnings.warn(textwrap.dedent("""\
jax.test_util.JaxTestCase is deprecated as of jax version 0.3.1:
The suggested replacement is to use parametrized.TestCase directly.
For tests that rely on custom asserts such as JaxTestCase.assertAllClose(),
the suggested replacement is to use standard numpy testing utilities such
as np.testing.assert_allclose(), which work directly with JAX arrays."""),
category=DeprecationWarning)
super().__init__(*args, **kwargs)
class DeprecatedJaxTestLoader(JaxTestLoader):
def __init__(self, *args, **kwargs):
warnings.warn(
"jax.test_util.JaxTestLoader is deprecated as of jax version 0.3.1. Use absltest.TestLoader directly.",
category=DeprecationWarning)
super().__init__(*args, **kwargs)
class DeprecatedBufferDonationTestCase(BufferDonationTestCase):
def __init__(self, *args, **kwargs):
warnings.warn(textwrap.dedent("""\
jax.test_util.JaxTestCase is deprecated as of jax version 0.3.1:
The suggested replacement is to use parametrized.TestCase directly.
For tests that rely on custom asserts such as JaxTestCase.assertAllClose(),
the suggested replacement is to use standard numpy testing utilities such
as np.testing.assert_allclose(), which work directly with JAX arrays."""),
category=DeprecationWarning)
super().__init__(*args, **kwargs)
def strict_promotion_if_dtypes_match(dtypes):
"""
Context manager to enable strict promotion if all dtypes match,

View File

@ -43,9 +43,5 @@ def __getattr__(attr):
import warnings
warnings.warn(f"jax.test_util.{attr} is deprecated and will soon be removed.", FutureWarning)
return getattr(test_util, attr)
elif attr in ['JaxTestCase', 'JaxTestLoader', 'BufferDonationTestCase']:
# Do the TestCase imports separately, since they were previously deprecated via a different
# mechanism & we don't want to annoy projects who may have temporarily filtered a specific warning.
return getattr(test_util, 'Deprecated' + attr)
else:
raise AttributeError(f"module {__name__} has no attribute {attr}")