diff --git a/CHANGELOG.md b/CHANGELOG.md index ee6dad30e..9897b7c8d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,14 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. are not of an integer type, matching the behavior of {func}`numpy.take_along_axis`. Previously non-integer indices were silently cast to integers. +* Deprecations + * Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a + warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`, + `format_shape_dtype_string`, `rand_uniform`, `skip_on_devices`, `with_config`, `xla_bridge`, and + `_default_tolerance` ({jax-issue}`#10389`). These, along with previously-deprecated `JaxTestCase`, + `JaxTestLoader`, and `BufferDonationTestCase`, will be removed in a future JAX release. + Many of these utilities still exist in `jax._src.test_util`, but these are not public APIs and + as such may be changed or removed without notice. ## jaxlib 0.3.8 (Unreleased) * [GitHub @@ -56,7 +64,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. pod. Fixes [#10218](https://github.com/google/jax/issues/10218). * Deprecations: * {mod}`jax.experimental.loops` is being deprecated. See {jax-issue}`#10278` - for an alternative API. + for an alternative API. ## jax 0.3.5 (April 7, 2022) * [GitHub diff --git a/jax/test_util.py b/jax/test_util.py index 21e4194f1..e41f8d850 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -12,30 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. -# flake8: noqa: F401 -from jax._src.public_test_util import ( +from jax._src.public_test_util import ( # noqa: F401 check_grads as check_grads, check_jvp as check_jvp, check_vjp as check_vjp, ) -# Conditional imports of private test utilities; these require their own BUILD target. -# TODO(jakevdp): remove these imports once downstream dependencies are cleaned. +# TODO(jakevdp): remove everything below once downstream callers are fixed. + +# Unconditionally import private test_util because it contains flag definitions. +# In bazel, jax._src.test_util requires its own BUILD target so it may not be present. +# pytype: disable=import-error try: - from jax._src.test_util import ( # pytype: disable=import-error - cases_from_list, - check_close, - check_eq, - device_under_test, - format_shape_dtype_string, - rand_uniform, - skip_on_devices, - with_config as with_config, - xla_bridge, - _default_tolerance, - DeprecatedJaxTestCase as JaxTestCase, - DeprecatedJaxTestLoader as JaxTestLoader, - DeprecatedBufferDonationTestCase as BufferDonationTestCase, - ) + import jax._src.test_util as _private_test_util except ImportError: pass +else: + del _private_test_util + +# Use module-level getattr to add warnings to imports of deprecated names. +# pylint: disable=import-outside-toplevel +def __getattr__(attr): + try: + from jax._src import test_util + except ImportError: + raise AttributeError(f"module {__name__} has no attribute {attr}") + if attr in ['cases_from_list', 'check_close', 'check_eq', 'device_under_test', + 'format_shape_dtype_string', 'rand_uniform', 'skip_on_devices', + 'with_config', 'xla_bridge', '_default_tolerance']: + import warnings + warnings.warn(f"jax.test_util.{attr} is deprecated and will soon be removed.", FutureWarning) + return getattr(test_util, attr) + elif attr in ['JaxTestCase', 'JaxTestLoader', 'BufferDonationTestCase']: + # Do the TestCase imports separately, since they were previously deprecated via a different + # mechanism & we don't want to annoy projects who may have temporarily filtered a specific warning. + return getattr(test_util, 'Deprecated' + attr) + else: + raise AttributeError(f"module {__name__} has no attribute {attr}")