mirror of https://github.com/ROCm/jax.git, synced 2025-04-17 12:26:07 +00:00
Merge pull request #20731 from NeilGirdhar:softmax
PiperOrigin-RevId: 645774372
This commit is contained in: commit 8c602cc3d0
docs/jax.scipy.rst
@@ -174,7 +174,9 @@ jax.scipy.special
    i0e
    i1
    i1e
+   kl_div
    log_ndtr
+   log_softmax
    logit
    logsumexp
    lpmn
@@ -184,13 +186,13 @@ jax.scipy.special
    ndtri
    poch
    polygamma
+   rel_entr
+   softmax
    spence
    sph_harm
    xlog1py
    xlogy
    zeta
-   kl_div
-   rel_entr
 
 jax.scipy.stats
 ---------------
jax/_src/scipy/special.py
@@ -35,6 +35,8 @@ from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact
 from jax._src.ops import special as ops_special
 from jax._src.third_party.scipy.betaln import betaln as _betaln_impl
 from jax._src.typing import Array, ArrayLike
+from jax._src.nn.functions import softmax as nn_softmax
+from jax._src.nn.functions import log_softmax as nn_log_softmax
 
 
 def gammaln(x: ArrayLike) -> Array:
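The new wrappers delegate to the existing jax.nn implementations rather than reimplementing the math. A quick illustrative check of that equivalence (not part of the diff; note that jax.nn.softmax defaults to axis=-1 while the scipy-compatible wrapper defaults to axis=None, so the two coincide below only because the input is 1-D):

    import jax.numpy as jnp
    import jax.nn
    import jax.scipy.special

    x = jnp.arange(3.0)
    # For 1-D input, the scipy-style wrapper and the jax.nn primitive agree.
    assert jnp.allclose(jax.scipy.special.softmax(x), jax.nn.softmax(x))
    assert jnp.allclose(jax.scipy.special.log_softmax(x), jax.nn.log_softmax(x))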
jax/_src/scipy/special.py
@@ -2582,3 +2584,72 @@ hyp1f1.defjvps(
     lambda b_dot, primal_out, a, b, x: _hyp1f1_b_derivative(a, b, x) * b_dot,
     lambda x_dot, primal_out, a, b, x: _hyp1f1_x_derivative(a, b, x) * x_dot
 )
+
+
+def softmax(x: ArrayLike,
+            /,
+            *,
+            axis: int | tuple[int, ...] | None = None,
+            ) -> Array:
+  r"""Softmax function.
+
+  JAX implementation of :func:`scipy.special.softmax`.
+
+  Computes the function which rescales elements to the range :math:`[0, 1]`
+  such that the elements along :code:`axis` sum to :math:`1`.
+
+  .. math::
+    \mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
+
+  Args:
+    x: input array
+    axis: the axis or axes along which the softmax should be computed. The
+      output will sum to :math:`1` along these dimensions.
+
+  Returns:
+    An array of the same shape as ``x``.
+
+  Note:
+    If any input values are ``+inf``, the result will be all ``NaN``: this
+    reflects the fact that ``inf / inf`` is not well-defined in the context of
+    floating-point math.
+
+  See also:
+    :func:`log_softmax`
+  """
+  return nn_softmax(x, axis=axis)
+
+
+def log_softmax(x: ArrayLike,
+                /,
+                *,
+                axis: int | tuple[int, ...] | None = None,
+                ) -> Array:
+  r"""Log-Softmax function.
+
+  JAX implementation of :func:`scipy.special.log_softmax`.
+
+  Computes the logarithm of the :code:`softmax` function, which rescales
+  elements to the range :math:`[-\infty, 0)`.
+
+  .. math::
+    \mathrm{log\_softmax}(x)_i = \log\left(\frac{\exp(x_i)}{\sum_j \exp(x_j)}
+    \right)
+
+  Args:
+    x: input array
+    axis: the axis or axes along which the :code:`log_softmax` should be
+      computed.
+
+  Returns:
+    An array of the same shape as ``x``.
+
+  Note:
+    If any input values are ``+inf``, the result will be all ``NaN``: this
+    reflects the fact that ``inf / inf`` is not well-defined in the context of
+    floating-point math.
+
+  See also:
+    :func:`softmax`
+  """
+  return nn_log_softmax(x, axis=axis)
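A short usage sketch of the two new functions (illustrative, not part of the diff), including the ``+inf`` behavior called out in the docstrings:

    import jax.numpy as jnp
    from jax.scipy.special import softmax, log_softmax

    x = jnp.array([[1.0, 2.0, 3.0],
                   [0.0, 0.0, 10.0]])

    p = softmax(x, axis=-1)
    print(p.sum(axis=-1))                                      # each row sums to 1
    print(jnp.allclose(log_softmax(x, axis=-1), jnp.log(p)))   # True (up to rounding)

    # As documented: any +inf in the input makes the whole result NaN,
    # because inf / inf is undefined in floating-point math.
    print(softmax(jnp.array([1.0, jnp.inf])))                  # [nan nan]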
jax/scipy/special.py
@@ -17,10 +17,10 @@
 
 from jax._src.scipy.special import (
     bernoulli as bernoulli,
+    bessel_jn as bessel_jn,
+    beta as beta,
     betainc as betainc,
     betaln as betaln,
-    beta as beta,
-    bessel_jn as bessel_jn,
     digamma as digamma,
     entr as entr,
     erf as erf,
@@ -31,31 +31,33 @@ from jax._src.scipy.special import (
     expit as expit,
     expn as expn,
     factorial as factorial,
+    gamma as gamma,
     gammainc as gammainc,
     gammaincc as gammaincc,
     gammaln as gammaln,
     gammasgn as gammasgn,
-    gamma as gamma,
+    hyp1f1 as hyp1f1,
     i0 as i0,
     i0e as i0e,
     i1 as i1,
     i1e as i1e,
+    kl_div as kl_div,
+    log_ndtr as log_ndtr,
+    log_softmax as log_softmax,
     logit as logit,
     logsumexp as logsumexp,
     lpmn as lpmn,
     lpmn_values as lpmn_values,
     multigammaln as multigammaln,
-    log_ndtr as log_ndtr,
     ndtr as ndtr,
     ndtri as ndtri,
+    poch as poch,
     polygamma as polygamma,
+    rel_entr as rel_entr,
+    softmax as softmax,
     spence as spence,
     sph_harm as sph_harm,
-    xlogy as xlogy,
     xlog1py as xlog1py,
+    xlogy as xlogy,
     zeta as zeta,
-    kl_div as kl_div,
-    rel_entr as rel_entr,
-    poch as poch,
-    hyp1f1 as hyp1f1,
 )
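The seemingly redundant ``name as name`` aliases are deliberate: PEP 484 treats ``import x as x`` as an explicit re-export, so strict type checkers (e.g. mypy with --no-implicit-reexport) keep these names public. A minimal illustration, with hypothetical module and function names:

    # pkg/api.py -- hypothetical facade module
    from pkg._impl import helper as helper  # "x as x": explicit re-export, stays public
    from pkg._impl import other             # no alias: flagged as private under --no-implicit-reexport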
tests/lax_scipy_special_functions_test.py
@@ -148,7 +148,12 @@ JAX_SPECIAL_FUNCTION_RECORDS = [
         "rel_entr", 2, float_dtypes, jtu.rand_positive, True,
     ),
     op_record("poch", 2, float_dtypes, jtu.rand_positive, True),
-    op_record("hyp1f1", 3, float_dtypes, functools.partial(jtu.rand_uniform, low=0.5, high=30), True)
+    op_record(
+        "hyp1f1", 3, float_dtypes,
+        functools.partial(jtu.rand_uniform, low=0.5, high=30), True
+    ),
+    op_record("log_softmax", 1, float_dtypes, jtu.rand_default, True),
+    op_record("softmax", 1, float_dtypes, jtu.rand_default, True),
 ]
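Each op_record entry drives a comparison of the JAX function against its scipy.special counterpart on random inputs. Roughly what the harness checks for the two new records (an illustrative sketch, not the actual test code):

    import numpy as np
    import scipy.special
    import jax.scipy.special

    x = np.random.randn(4, 5).astype(np.float32)
    np.testing.assert_allclose(
        jax.scipy.special.softmax(x, axis=-1),
        scipy.special.softmax(x, axis=-1),
        rtol=1e-5, atol=1e-6,
    )
    np.testing.assert_allclose(
        jax.scipy.special.log_softmax(x, axis=-1),
        scipy.special.log_softmax(x, axis=-1),
        rtol=1e-5, atol=1e-6,
    )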