diff --git a/CHANGELOG.md b/CHANGELOG.md index f59b07cd2..c633ebb40 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * The internal utilities `jax.core.check_eqn`, `jax.core.check_type`, and `jax.core.check_valid_jaxtype` are now deprecated, and will be removed in the future. + * `jax.numpy.round_` has been deprecated, following removal of the corresponding + API in NumPy 2.0. Use {func}`jax.numpy.round` instead. ## jaxlib 0.4.32 diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index e7bcc9ecc..7c0162d78 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2722,12 +2722,6 @@ def around(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: return round(a, decimals, out) -@partial(jit, static_argnames=('decimals',)) -def round_(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: - """Alias of :func:`jax.numpy.round`""" - return round(a, decimals, out) - - @jit def fix(x: ArrayLike, out: None = None) -> Array: """Round input to the nearest integer towards zero. diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 88e1840ef..da79f7859 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -212,7 +212,6 @@ from jax._src.numpy.lax_numpy import ( rollaxis as rollaxis, rot90 as rot90, round as round, - round_ as round_, save as save, savez as savez, searchsorted as searchsorted, @@ -466,6 +465,11 @@ del register_jax_array_methods _deprecations = { + # Deprecated 03 Sept 2024 + "round_": ( + "jnp.round_ is deprecated; use jnp.round instead.", + round + ), # Deprecated 18 Sept 2023 and removed 06 Feb 2024 "trapz": ( "jnp.trapz is deprecated; use jnp.trapezoid instead.", @@ -473,6 +477,11 @@ _deprecations = { ), } -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr +import typing +if typing.TYPE_CHECKING: + round_ = round +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing