mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Finalize deprecation of jnp.round_
PiperOrigin-RevId: 705998500
This commit is contained in:
parent
078c7e4444
commit
c73f306099
@ -26,6 +26,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
`non_negative_dim`.
|
||||
* from {mod}`jax.lib.xla_bridge`: `xla_client` and `default_backend`.
|
||||
* from {mod}`jax.lib.xla_client`: `_xla` and `bfloat16`.
|
||||
* from {mod}`jax.numpy`: `round_`.
|
||||
|
||||
## jax 0.4.37 (Dec 9, 2024)
|
||||
|
||||
|
@ -359,7 +359,6 @@ namespace; they are listed below.
|
||||
roots
|
||||
rot90
|
||||
round
|
||||
round_
|
||||
s_
|
||||
save
|
||||
savez
|
||||
|
@ -479,17 +479,15 @@ del register_jax_array_methods
|
||||
|
||||
|
||||
_deprecations = {
|
||||
# Deprecated 03 Sept 2024
|
||||
# Finalized 2024-12-13; remove after 2024-3-13
|
||||
"round_": (
|
||||
"jnp.round_ is deprecated; use jnp.round instead.",
|
||||
round
|
||||
"jnp.round_ was deprecated in JAX 0.4.38; use jnp.round instead.",
|
||||
None
|
||||
),
|
||||
}
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
round_ = round
|
||||
else:
|
||||
if not typing.TYPE_CHECKING:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
|
Loading…
x
Reference in New Issue
Block a user