Remove the initial argument to jax.nn.softmax and jax.nn.log_softmax.

This argument was deprecated in JAX v0.4.27 and has no effect in JAX v0.4.27 and later.

PiperOrigin-RevId: 693023366
This commit is contained in:
Jake VanderPlas 2024-11-04 10:54:41 -08:00 committed by jax authors
parent 26c0c5c764
commit e9acaa8484
2 changed files with 8 additions and 9 deletions

View File

@ -23,6 +23,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* The deprecated module `jax.experimental.export` has been removed. It was replaced
by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export)
for information on migrating to the new API.
* The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
has been removed, after being deprecated in v0.4.27.
* The following deprecated methods and functions in {mod}`jax.export` have
been removed:
* `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect

View File

@ -22,7 +22,6 @@ import operator
import math
import numpy as np
from typing import Any, Literal
import warnings
import jax
import jax.numpy as jnp
@ -502,7 +501,7 @@ logsumexp = _logsumexp
def log_softmax(x: ArrayLike,
axis: int | tuple[int, ...] | None = -1,
where: ArrayLike | None = None,
initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array:
initial: Unspecified = _UNSPECIFIED) -> Array:
r"""Log-Softmax function.
Computes the logarithm of the :code:`softmax` function, which rescales
@ -528,10 +527,9 @@ def log_softmax(x: ArrayLike,
See also:
:func:`softmax`
"""
# TODO(jakevdp): remove the initial argument after JAX v0.4.40.
if initial is not _UNSPECIFIED:
# Added 2024-4-10
warnings.warn("The initial argument to log_softmax is deprecated, and no longer has any effect.",
DeprecationWarning, stacklevel=2)
raise TypeError("The initial argument to jax.nn.log_softmax was removed in JAX v0.4.36.")
del initial
numpy_util.check_arraylike("log_softmax", x)
x_arr = jnp.asarray(x)
@ -551,7 +549,7 @@ def log_softmax(x: ArrayLike,
def softmax(x: ArrayLike,
axis: int | tuple[int, ...] | None = -1,
where: ArrayLike | None = None,
initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array:
initial: Unspecified = _UNSPECIFIED) -> Array:
r"""Softmax function.
Computes the function which rescales elements to the range :math:`[0, 1]`
@ -577,10 +575,9 @@ def softmax(x: ArrayLike,
See also:
:func:`log_softmax`
"""
# TODO(jakevdp): remove the initial argument after JAX v0.4.40.
if initial is not _UNSPECIFIED:
# Added 2024-4-10
warnings.warn("The initial argument to softmax is deprecated, and no longer has any effect.",
DeprecationWarning, stacklevel=2)
raise TypeError("The initial argument to jax.nn.softmax was removed in JAX v0.4.36.")
del initial
if config.softmax_custom_jvp.value:
# mypy is confused by the `functools.partial` application in the definition