mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove the initial
argument to jax.nn.softmax
and jax.nn.log_softmax
.
This argument was deprecated in JAX v0.4.27 and has no effect in JAX v0.4.27 and later. PiperOrigin-RevId: 693023366
This commit is contained in:
parent
26c0c5c764
commit
e9acaa8484
@ -23,6 +23,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* The deprecated module `jax.experimental.export` has been removed. It was replaced
|
||||
by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export)
|
||||
for information on migrating to the new API.
|
||||
* The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
|
||||
has been removed, after being deprecated in v0.4.27.
|
||||
* The following deprecated methods and functions in {mod}`jax.export` have
|
||||
been removed:
|
||||
* `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect
|
||||
|
@ -22,7 +22,6 @@ import operator
|
||||
import math
|
||||
import numpy as np
|
||||
from typing import Any, Literal
|
||||
import warnings
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@ -502,7 +501,7 @@ logsumexp = _logsumexp
|
||||
def log_softmax(x: ArrayLike,
|
||||
axis: int | tuple[int, ...] | None = -1,
|
||||
where: ArrayLike | None = None,
|
||||
initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array:
|
||||
initial: Unspecified = _UNSPECIFIED) -> Array:
|
||||
r"""Log-Softmax function.
|
||||
|
||||
Computes the logarithm of the :code:`softmax` function, which rescales
|
||||
@ -528,10 +527,9 @@ def log_softmax(x: ArrayLike,
|
||||
See also:
|
||||
:func:`softmax`
|
||||
"""
|
||||
# TODO(jakevdp): remove the initial argument after JAX v0.4.40.
|
||||
if initial is not _UNSPECIFIED:
|
||||
# Added 2024-4-10
|
||||
warnings.warn("The initial argument to log_softmax is deprecated, and no longer has any effect.",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
raise TypeError("The initial argument to jax.nn.log_softmax was removed in JAX v0.4.36.")
|
||||
del initial
|
||||
numpy_util.check_arraylike("log_softmax", x)
|
||||
x_arr = jnp.asarray(x)
|
||||
@ -551,7 +549,7 @@ def log_softmax(x: ArrayLike,
|
||||
def softmax(x: ArrayLike,
|
||||
axis: int | tuple[int, ...] | None = -1,
|
||||
where: ArrayLike | None = None,
|
||||
initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array:
|
||||
initial: Unspecified = _UNSPECIFIED) -> Array:
|
||||
r"""Softmax function.
|
||||
|
||||
Computes the function which rescales elements to the range :math:`[0, 1]`
|
||||
@ -577,10 +575,9 @@ def softmax(x: ArrayLike,
|
||||
See also:
|
||||
:func:`log_softmax`
|
||||
"""
|
||||
# TODO(jakevdp): remove the initial argument after JAX v0.4.40.
|
||||
if initial is not _UNSPECIFIED:
|
||||
# Added 2024-4-10
|
||||
warnings.warn("The initial argument to softmax is deprecated, and no longer has any effect.",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
raise TypeError("The initial argument to jax.nn.softmax was removed in JAX v0.4.36.")
|
||||
del initial
|
||||
if config.softmax_custom_jvp.value:
|
||||
# mypy is confused by the `functools.partial` application in the definition
|
||||
|
Loading…
x
Reference in New Issue
Block a user