From f2ffe7f8f27c6179bd4a3d73c3d12e1ce5f99396 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 3 Sep 2024 06:52:07 -0700 Subject: [PATCH] Deprecate jax.numpy.round_ NumPy removed np.round in version 2.0; jax.numpy.round is drop-in replacement. --- CHANGELOG.md | 2 ++ jax/_src/numpy/lax_numpy.py | 6 ------ jax/numpy/__init__.py | 17 +++++++++++++---- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f59b07cd2..c633ebb40 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * The internal utilities `jax.core.check_eqn`, `jax.core.check_type`, and `jax.core.check_valid_jaxtype` are now deprecated, and will be removed in the future. + * `jax.numpy.round_` has been deprecated, following removal of the corresponding + API in NumPy 2.0. Use {func}`jax.numpy.round` instead. ## jaxlib 0.4.32 diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index e7bcc9ecc..7c0162d78 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2722,12 +2722,6 @@ def around(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: return round(a, decimals, out) -@partial(jit, static_argnames=('decimals',)) -def round_(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: - """Alias of :func:`jax.numpy.round`""" - return round(a, decimals, out) - - @jit def fix(x: ArrayLike, out: None = None) -> Array: """Round input to the nearest integer towards zero. diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 88e1840ef..da79f7859 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -212,7 +212,6 @@ from jax._src.numpy.lax_numpy import ( rollaxis as rollaxis, rot90 as rot90, round as round, - round_ as round_, save as save, savez as savez, searchsorted as searchsorted, @@ -466,6 +465,11 @@ del register_jax_array_methods _deprecations = { + # Deprecated 03 Sept 2024 + "round_": ( + "jnp.round_ is deprecated; use jnp.round instead.", + round + ), # Deprecated 18 Sept 2023 and removed 06 Feb 2024 "trapz": ( "jnp.trapz is deprecated; use jnp.trapezoid instead.", @@ -473,6 +477,11 @@ _deprecations = { ), } -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr +import typing +if typing.TYPE_CHECKING: + round_ = round +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing