Merge pull request #20731 from NeilGirdhar:softmax

PiperOrigin-RevId: 645774372
This commit is contained in:
jax authors 2024-06-22 21:51:22 -07:00
commit 8c602cc3d0
4 changed files with 92 additions and 12 deletions

View File

@ -174,7 +174,9 @@ jax.scipy.special
i0e
i1
i1e
kl_div
log_ndtr
log_softmax
logit
logsumexp
lpmn
@ -184,13 +186,13 @@ jax.scipy.special
ndtri
poch
polygamma
rel_entr
softmax
spence
sph_harm
xlog1py
xlogy
zeta
kl_div
rel_entr
jax.scipy.stats

View File

@ -35,6 +35,8 @@ from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact
from jax._src.ops import special as ops_special
from jax._src.third_party.scipy.betaln import betaln as _betaln_impl
from jax._src.typing import Array, ArrayLike
from jax._src.nn.functions import softmax as nn_softmax
from jax._src.nn.functions import log_softmax as nn_log_softmax
def gammaln(x: ArrayLike) -> Array:
@ -2582,3 +2584,72 @@ hyp1f1.defjvps(
lambda b_dot, primal_out, a, b, x: _hyp1f1_b_derivative(a, b, x) * b_dot,
lambda x_dot, primal_out, a, b, x: _hyp1f1_x_derivative(a, b, x) * x_dot
)
def softmax(x: ArrayLike,
/,
*,
axis: int | tuple[int, ...] | None = None,
) -> Array:
r"""Softmax function.
JAX implementation of :func:`scipy.special.softmax`.
Computes the function which rescales elements to the range :math:`[0, 1]`
such that the elements along :code:`axis` sum to :math:`1`.
.. math ::
\mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
Args:
x : input array
axis: the axis or axes along which the softmax should be computed. The
softmax output summed across these dimensions should sum to :math:`1`.
Returns:
An array of the same shape as ``x``.
Note:
If any input values are ``+inf``, the result will be all ``NaN``: this
reflects the fact that ``inf / inf`` is not well-defined in the context of
floating-point math.
See also:
:func:`log_softmax`
"""
return nn_softmax(x, axis=axis)
def log_softmax(x: ArrayLike,
/,
*,
axis: int | tuple[int, ...] | None = None,
) -> Array:
r"""Log-Softmax function.
JAX implementation of :func:`scipy.special.log_softmax`
Computes the logarithm of the :code:`softmax` function, which rescales
elements to the range :math:`[-\infty, 0)`.
.. math ::
\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
\right)
Args:
x : input array
axis: the axis or axes along which the :code:`log_softmax` should be
computed.
Returns:
An array of the same shape as ``x``
Note:
If any input values are ``+inf``, the result will be all ``NaN``: this
reflects the fact that ``inf / inf`` is not well-defined in the context of
floating-point math.
See also:
:func:`softmax`
"""
return nn_log_softmax(x, axis=axis)

View File

@ -17,10 +17,10 @@
from jax._src.scipy.special import (
bernoulli as bernoulli,
bessel_jn as bessel_jn,
beta as beta,
betainc as betainc,
betaln as betaln,
beta as beta,
bessel_jn as bessel_jn,
digamma as digamma,
entr as entr,
erf as erf,
@ -31,31 +31,33 @@ from jax._src.scipy.special import (
expit as expit,
expn as expn,
factorial as factorial,
gamma as gamma,
gammainc as gammainc,
gammaincc as gammaincc,
gammaln as gammaln,
gammasgn as gammasgn,
gamma as gamma,
hyp1f1 as hyp1f1,
i0 as i0,
i0e as i0e,
i1 as i1,
i1e as i1e,
kl_div as kl_div,
log_ndtr as log_ndtr,
log_softmax as log_softmax,
logit as logit,
logsumexp as logsumexp,
lpmn as lpmn,
lpmn_values as lpmn_values,
multigammaln as multigammaln,
log_ndtr as log_ndtr,
ndtr as ndtr,
ndtri as ndtri,
poch as poch,
polygamma as polygamma,
rel_entr as rel_entr,
softmax as softmax,
spence as spence,
sph_harm as sph_harm,
xlogy as xlogy,
xlog1py as xlog1py,
xlogy as xlogy,
zeta as zeta,
kl_div as kl_div,
rel_entr as rel_entr,
poch as poch,
hyp1f1 as hyp1f1,
)

View File

@ -148,7 +148,12 @@ JAX_SPECIAL_FUNCTION_RECORDS = [
"rel_entr", 2, float_dtypes, jtu.rand_positive, True,
),
op_record("poch", 2, float_dtypes, jtu.rand_positive, True),
op_record("hyp1f1", 3, float_dtypes, functools.partial(jtu.rand_uniform, low=0.5, high=30), True)
op_record(
"hyp1f1", 3, float_dtypes,
functools.partial(jtu.rand_uniform, low=0.5, high=30), True
),
op_record("log_softmax", 1, float_dtypes, jtu.rand_default, True),
op_record("softmax", 1, float_dtypes, jtu.rand_default, True),
]