mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
jax.test_util: remove deprecated test classes.
JaxTestCase and JaxTestLoader were deprecated in jax v0.3.1, released Feb 2022.
This commit is contained in:
parent
997beb3ce0
commit
887abbc3b9
@ -11,6 +11,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
## jax 0.3.15 (Unreleased)
|
## jax 0.3.15 (Unreleased)
|
||||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.14...main).
|
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.14...main).
|
||||||
* Changes
|
* Changes
|
||||||
|
* `JaxTestCase` and `JaxTestLoader` have been removed from `jax.test_util`. These
|
||||||
|
classes have been deprecated since v0.3.1 ({jax-issue}`#11248`).
|
||||||
|
|
||||||
## jaxlib 0.3.15 (Unreleased)
|
## jaxlib 0.3.15 (Unreleased)
|
||||||
|
|
||||||
|
@ -1026,38 +1026,6 @@ class _LazyDtypes:
|
|||||||
dtypes = _LazyDtypes()
|
dtypes = _LazyDtypes()
|
||||||
|
|
||||||
|
|
||||||
class DeprecatedJaxTestCase(JaxTestCase):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
warnings.warn(textwrap.dedent("""\
|
|
||||||
jax.test_util.JaxTestCase is deprecated as of jax version 0.3.1:
|
|
||||||
The suggested replacement is to use parametrized.TestCase directly.
|
|
||||||
For tests that rely on custom asserts such as JaxTestCase.assertAllClose(),
|
|
||||||
the suggested replacement is to use standard numpy testing utilities such
|
|
||||||
as np.testing.assert_allclose(), which work directly with JAX arrays."""),
|
|
||||||
category=DeprecationWarning)
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class DeprecatedJaxTestLoader(JaxTestLoader):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
warnings.warn(
|
|
||||||
"jax.test_util.JaxTestLoader is deprecated as of jax version 0.3.1. Use absltest.TestLoader directly.",
|
|
||||||
category=DeprecationWarning)
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class DeprecatedBufferDonationTestCase(BufferDonationTestCase):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
warnings.warn(textwrap.dedent("""\
|
|
||||||
jax.test_util.JaxTestCase is deprecated as of jax version 0.3.1:
|
|
||||||
The suggested replacement is to use parametrized.TestCase directly.
|
|
||||||
For tests that rely on custom asserts such as JaxTestCase.assertAllClose(),
|
|
||||||
the suggested replacement is to use standard numpy testing utilities such
|
|
||||||
as np.testing.assert_allclose(), which work directly with JAX arrays."""),
|
|
||||||
category=DeprecationWarning)
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def strict_promotion_if_dtypes_match(dtypes):
|
def strict_promotion_if_dtypes_match(dtypes):
|
||||||
"""
|
"""
|
||||||
Context manager to enable strict promotion if all dtypes match,
|
Context manager to enable strict promotion if all dtypes match,
|
||||||
|
@ -43,9 +43,5 @@ def __getattr__(attr):
|
|||||||
import warnings
|
import warnings
|
||||||
warnings.warn(f"jax.test_util.{attr} is deprecated and will soon be removed.", FutureWarning)
|
warnings.warn(f"jax.test_util.{attr} is deprecated and will soon be removed.", FutureWarning)
|
||||||
return getattr(test_util, attr)
|
return getattr(test_util, attr)
|
||||||
elif attr in ['JaxTestCase', 'JaxTestLoader', 'BufferDonationTestCase']:
|
|
||||||
# Do the TestCase imports separately, since they were previously deprecated via a different
|
|
||||||
# mechanism & we don't want to annoy projects who may have temporarily filtered a specific warning.
|
|
||||||
return getattr(test_util, 'Deprecated' + attr)
|
|
||||||
else:
|
else:
|
||||||
raise AttributeError(f"module {__name__} has no attribute {attr}")
|
raise AttributeError(f"module {__name__} has no attribute {attr}")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user