Deprecate remaining functionality in jax.test_util

This commit is contained in:
Jake VanderPlas 2022-04-21 12:12:40 -07:00
parent bef5e02816
commit d9508304e4
2 changed files with 38 additions and 20 deletions

View File

@ -25,6 +25,14 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
are not of an integer type, matching the behavior of are not of an integer type, matching the behavior of
{func}`numpy.take_along_axis`. Previously non-integer indices were silently {func}`numpy.take_along_axis`. Previously non-integer indices were silently
cast to integers. cast to integers.
* Deprecations
* Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a
warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`,
`format_shape_dtype_string`, `rand_uniform`, `skip_on_devices`, `with_config`, `xla_bridge`, and
`_default_tolerance` ({jax-issue}`#10389`). These, along with previously-deprecated `JaxTestCase`,
`JaxTestLoader`, and `BufferDonationTestCase`, will be removed in a future JAX release.
Many of these utilities still exist in `jax._src.test_util`, but these are not public APIs and
as such may be changed or removed without notice.
## jaxlib 0.3.8 (Unreleased) ## jaxlib 0.3.8 (Unreleased)
* [GitHub * [GitHub
@ -56,7 +64,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
pod. Fixes [#10218](https://github.com/google/jax/issues/10218). pod. Fixes [#10218](https://github.com/google/jax/issues/10218).
* Deprecations: * Deprecations:
* {mod}`jax.experimental.loops` is being deprecated. See {jax-issue}`#10278` * {mod}`jax.experimental.loops` is being deprecated. See {jax-issue}`#10278`
for an alternative API. for an alternative API.
## jax 0.3.5 (April 7, 2022) ## jax 0.3.5 (April 7, 2022)
* [GitHub * [GitHub

View File

@ -12,30 +12,40 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# flake8: noqa: F401 from jax._src.public_test_util import ( # noqa: F401
from jax._src.public_test_util import (
check_grads as check_grads, check_grads as check_grads,
check_jvp as check_jvp, check_jvp as check_jvp,
check_vjp as check_vjp, check_vjp as check_vjp,
) )
# Conditional imports of private test utilities; these require their own BUILD target. # TODO(jakevdp): remove everything below once downstream callers are fixed.
# TODO(jakevdp): remove these imports once downstream dependencies are cleaned.
# Unconditionally import private test_util because it contains flag definitions.
# In bazel, jax._src.test_util requires its own BUILD target so it may not be present.
# pytype: disable=import-error
try: try:
from jax._src.test_util import ( # pytype: disable=import-error import jax._src.test_util as _private_test_util
cases_from_list,
check_close,
check_eq,
device_under_test,
format_shape_dtype_string,
rand_uniform,
skip_on_devices,
with_config as with_config,
xla_bridge,
_default_tolerance,
DeprecatedJaxTestCase as JaxTestCase,
DeprecatedJaxTestLoader as JaxTestLoader,
DeprecatedBufferDonationTestCase as BufferDonationTestCase,
)
except ImportError: except ImportError:
pass pass
else:
del _private_test_util
# Use module-level getattr to add warnings to imports of deprecated names.
# pylint: disable=import-outside-toplevel
def __getattr__(attr):
try:
from jax._src import test_util
except ImportError:
raise AttributeError(f"module {__name__} has no attribute {attr}")
if attr in ['cases_from_list', 'check_close', 'check_eq', 'device_under_test',
'format_shape_dtype_string', 'rand_uniform', 'skip_on_devices',
'with_config', 'xla_bridge', '_default_tolerance']:
import warnings
warnings.warn(f"jax.test_util.{attr} is deprecated and will soon be removed.", FutureWarning)
return getattr(test_util, attr)
elif attr in ['JaxTestCase', 'JaxTestLoader', 'BufferDonationTestCase']:
# Do the TestCase imports separately, since they were previously deprecated via a different
# mechanism & we don't want to annoy projects who may have temporarily filtered a specific warning.
return getattr(test_util, 'Deprecated' + attr)
else:
raise AttributeError(f"module {__name__} has no attribute {attr}")