Minor improvements to doc for jax.nn.logsumexp.

This commit is contained in:
carlosgmartin 2025-04-13 15:17:11 -04:00
parent 773b323b26
commit 2336cd1695

View File

@ -47,16 +47,15 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
JAX implementation of :func:`scipy.special.logsumexp`.
.. math::
\mathrm{logsumexp}(a) = \mathrm{log} \sum_j b \cdot \mathrm{exp}(a_{ij})
\operatorname{logsumexp} a = \log \sum_i b_i \exp a_i
where the :math:`j` indices range over one or more dimensions to be reduced.
where the :math:`i` indices range over one or more dimensions to be reduced.
Args:
a: the input array
axis: int or sequence of ints, default=None. Axis along which the sum to be
computed. If None, the sum is computed along all the axes.
b: scaling factors for :math:`\mathrm{exp}(a)`. Must be broadcastable to the
shape of `a`.
b: scaling factors for the exponentials. Must be broadcastable to the shape of `a`.
keepdims: If ``True``, the axes that are reduced are left in the output as
dimensions of size 1.
return_sign: If ``True``, the output will be a ``(result, sign)`` pair,