diff --git a/CHANGELOG.md b/CHANGELOG.md index b122675e1..8c34b8e36 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * The deprecated module `jax.experimental.export` has been removed. It was replaced by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export) for information on migrating to the new API. + * The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax` + has been removed, after being deprecated in v0.4.27. * The following deprecated methods and functions in {mod}`jax.export` have been removed: * `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 861e3d012..5dfaa7b7e 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -22,7 +22,6 @@ import operator import math import numpy as np from typing import Any, Literal -import warnings import jax import jax.numpy as jnp @@ -502,7 +501,7 @@ logsumexp = _logsumexp def log_softmax(x: ArrayLike, axis: int | tuple[int, ...] | None = -1, where: ArrayLike | None = None, - initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array: + initial: Unspecified = _UNSPECIFIED) -> Array: r"""Log-Softmax function. Computes the logarithm of the :code:`softmax` function, which rescales @@ -528,10 +527,9 @@ def log_softmax(x: ArrayLike, See also: :func:`softmax` """ + # TODO(jakevdp): remove the initial argument after JAX v0.4.40. if initial is not _UNSPECIFIED: - # Added 2024-4-10 - warnings.warn("The initial argument to log_softmax is deprecated, and no longer has any effect.", - DeprecationWarning, stacklevel=2) + raise TypeError("The initial argument to jax.nn.log_softmax was removed in JAX v0.4.36.") del initial numpy_util.check_arraylike("log_softmax", x) x_arr = jnp.asarray(x) @@ -551,7 +549,7 @@ def log_softmax(x: ArrayLike, def softmax(x: ArrayLike, axis: int | tuple[int, ...] | None = -1, where: ArrayLike | None = None, - initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array: + initial: Unspecified = _UNSPECIFIED) -> Array: r"""Softmax function. Computes the function which rescales elements to the range :math:`[0, 1]` @@ -577,10 +575,9 @@ def softmax(x: ArrayLike, See also: :func:`log_softmax` """ + # TODO(jakevdp): remove the initial argument after JAX v0.4.40. if initial is not _UNSPECIFIED: - # Added 2024-4-10 - warnings.warn("The initial argument to softmax is deprecated, and no longer has any effect.", - DeprecationWarning, stacklevel=2) + raise TypeError("The initial argument to jax.nn.softmax was removed in JAX v0.4.36.") del initial if config.softmax_custom_jvp.value: # mypy is confused by the `functools.partial` application in the definition