Deprecate jax.numpy.round_

NumPy removed np.round in version 2.0; jax.numpy.round is drop-in
replacement.
This commit is contained in:
Jake VanderPlas 2024-09-03 06:52:07 -07:00
parent ccabd21084
commit f2ffe7f8f2
3 changed files with 15 additions and 10 deletions

View File

@ -51,6 +51,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* The internal utilities `jax.core.check_eqn`, `jax.core.check_type`, and
`jax.core.check_valid_jaxtype` are now deprecated, and will be removed in
the future.
* `jax.numpy.round_` has been deprecated, following removal of the corresponding
API in NumPy 2.0. Use {func}`jax.numpy.round` instead.
## jaxlib 0.4.32

View File

@ -2722,12 +2722,6 @@ def around(a: ArrayLike, decimals: int = 0, out: None = None) -> Array:
return round(a, decimals, out)
@partial(jit, static_argnames=('decimals',))
def round_(a: ArrayLike, decimals: int = 0, out: None = None) -> Array:
"""Alias of :func:`jax.numpy.round`"""
return round(a, decimals, out)
@jit
def fix(x: ArrayLike, out: None = None) -> Array:
"""Round input to the nearest integer towards zero.

View File

@ -212,7 +212,6 @@ from jax._src.numpy.lax_numpy import (
rollaxis as rollaxis,
rot90 as rot90,
round as round,
round_ as round_,
save as save,
savez as savez,
searchsorted as searchsorted,
@ -466,6 +465,11 @@ del register_jax_array_methods
_deprecations = {
# Deprecated 03 Sept 2024
"round_": (
"jnp.round_ is deprecated; use jnp.round instead.",
round
),
# Deprecated 18 Sept 2023 and removed 06 Feb 2024
"trapz": (
"jnp.trapz is deprecated; use jnp.trapezoid instead.",
@ -473,6 +477,11 @@ _deprecations = {
),
}
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
import typing
if typing.TYPE_CHECKING:
round_ = round
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing