mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Deprecate jax.numpy.round_
NumPy removed np.round in version 2.0; jax.numpy.round is drop-in replacement.
This commit is contained in:
parent
ccabd21084
commit
f2ffe7f8f2
@ -51,6 +51,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* The internal utilities `jax.core.check_eqn`, `jax.core.check_type`, and
|
||||
`jax.core.check_valid_jaxtype` are now deprecated, and will be removed in
|
||||
the future.
|
||||
* `jax.numpy.round_` has been deprecated, following removal of the corresponding
|
||||
API in NumPy 2.0. Use {func}`jax.numpy.round` instead.
|
||||
|
||||
## jaxlib 0.4.32
|
||||
|
||||
|
@ -2722,12 +2722,6 @@ def around(a: ArrayLike, decimals: int = 0, out: None = None) -> Array:
|
||||
return round(a, decimals, out)
|
||||
|
||||
|
||||
@partial(jit, static_argnames=('decimals',))
|
||||
def round_(a: ArrayLike, decimals: int = 0, out: None = None) -> Array:
|
||||
"""Alias of :func:`jax.numpy.round`"""
|
||||
return round(a, decimals, out)
|
||||
|
||||
|
||||
@jit
|
||||
def fix(x: ArrayLike, out: None = None) -> Array:
|
||||
"""Round input to the nearest integer towards zero.
|
||||
|
@ -212,7 +212,6 @@ from jax._src.numpy.lax_numpy import (
|
||||
rollaxis as rollaxis,
|
||||
rot90 as rot90,
|
||||
round as round,
|
||||
round_ as round_,
|
||||
save as save,
|
||||
savez as savez,
|
||||
searchsorted as searchsorted,
|
||||
@ -466,6 +465,11 @@ del register_jax_array_methods
|
||||
|
||||
|
||||
_deprecations = {
|
||||
# Deprecated 03 Sept 2024
|
||||
"round_": (
|
||||
"jnp.round_ is deprecated; use jnp.round instead.",
|
||||
round
|
||||
),
|
||||
# Deprecated 18 Sept 2023 and removed 06 Feb 2024
|
||||
"trapz": (
|
||||
"jnp.trapz is deprecated; use jnp.trapezoid instead.",
|
||||
@ -473,6 +477,11 @@ _deprecations = {
|
||||
),
|
||||
}
|
||||
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
round_ = round
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del typing
|
||||
|
Loading…
x
Reference in New Issue
Block a user