Remove the initial argument to jax.nn.softmax and jax.nn.log_softmax.

This argument was deprecated in JAX v0.4.27 and has no effect in JAX v0.4.27 and later.

PiperOrigin-RevId: 693023366
This commit is contained in:
Jake VanderPlas 2024-11-04 10:54:41 -08:00 committed by jax authors
parent 26c0c5c764
commit e9acaa8484
2 changed files with 8 additions and 9 deletions

View File

@ -23,6 +23,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* The deprecated module `jax.experimental.export` has been removed. It was replaced * The deprecated module `jax.experimental.export` has been removed. It was replaced
by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export) by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export)
for information on migrating to the new API. for information on migrating to the new API.
* The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
has been removed, after being deprecated in v0.4.27.
* The following deprecated methods and functions in {mod}`jax.export` have * The following deprecated methods and functions in {mod}`jax.export` have
been removed: been removed:
* `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect * `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect

View File

@ -22,7 +22,6 @@ import operator
import math import math
import numpy as np import numpy as np
from typing import Any, Literal from typing import Any, Literal
import warnings
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
@ -502,7 +501,7 @@ logsumexp = _logsumexp
def log_softmax(x: ArrayLike, def log_softmax(x: ArrayLike,
axis: int | tuple[int, ...] | None = -1, axis: int | tuple[int, ...] | None = -1,
where: ArrayLike | None = None, where: ArrayLike | None = None,
initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array: initial: Unspecified = _UNSPECIFIED) -> Array:
r"""Log-Softmax function. r"""Log-Softmax function.
Computes the logarithm of the :code:`softmax` function, which rescales Computes the logarithm of the :code:`softmax` function, which rescales
@ -528,10 +527,9 @@ def log_softmax(x: ArrayLike,
See also: See also:
:func:`softmax` :func:`softmax`
""" """
# TODO(jakevdp): remove the initial argument after JAX v0.4.40.
if initial is not _UNSPECIFIED: if initial is not _UNSPECIFIED:
# Added 2024-4-10 raise TypeError("The initial argument to jax.nn.log_softmax was removed in JAX v0.4.36.")
warnings.warn("The initial argument to log_softmax is deprecated, and no longer has any effect.",
DeprecationWarning, stacklevel=2)
del initial del initial
numpy_util.check_arraylike("log_softmax", x) numpy_util.check_arraylike("log_softmax", x)
x_arr = jnp.asarray(x) x_arr = jnp.asarray(x)
@ -551,7 +549,7 @@ def log_softmax(x: ArrayLike,
def softmax(x: ArrayLike, def softmax(x: ArrayLike,
axis: int | tuple[int, ...] | None = -1, axis: int | tuple[int, ...] | None = -1,
where: ArrayLike | None = None, where: ArrayLike | None = None,
initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array: initial: Unspecified = _UNSPECIFIED) -> Array:
r"""Softmax function. r"""Softmax function.
Computes the function which rescales elements to the range :math:`[0, 1]` Computes the function which rescales elements to the range :math:`[0, 1]`
@ -577,10 +575,9 @@ def softmax(x: ArrayLike,
See also: See also:
:func:`log_softmax` :func:`log_softmax`
""" """
# TODO(jakevdp): remove the initial argument after JAX v0.4.40.
if initial is not _UNSPECIFIED: if initial is not _UNSPECIFIED:
# Added 2024-4-10 raise TypeError("The initial argument to jax.nn.softmax was removed in JAX v0.4.36.")
warnings.warn("The initial argument to softmax is deprecated, and no longer has any effect.",
DeprecationWarning, stacklevel=2)
del initial del initial
if config.softmax_custom_jvp.value: if config.softmax_custom_jvp.value:
# mypy is confused by the `functools.partial` application in the definition # mypy is confused by the `functools.partial` application in the definition