diff --git a/CHANGELOG.md b/CHANGELOG.md index ddc6d9020..7202e520d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. `non_negative_dim`. * from {mod}`jax.lib.xla_bridge`: `xla_client` and `default_backend`. * from {mod}`jax.lib.xla_client`: `_xla` and `bfloat16`. + * from {mod}`jax.numpy`: `round_`. ## jax 0.4.37 (Dec 9, 2024) diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index 7f2a5785a..c79f023e5 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -359,7 +359,6 @@ namespace; they are listed below. roots rot90 round - round_ s_ save savez diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index d0e06e68d..c447b0844 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -479,17 +479,15 @@ del register_jax_array_methods _deprecations = { - # Deprecated 03 Sept 2024 + # Finalized 2024-12-13; remove after 2024-3-13 "round_": ( - "jnp.round_ is deprecated; use jnp.round instead.", - round + "jnp.round_ was deprecated in JAX 0.4.38; use jnp.round instead.", + None ), } import typing -if typing.TYPE_CHECKING: - round_ = round -else: +if not typing.TYPE_CHECKING: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr