mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Minor improvements to doc for jax.nn.logsumexp.
This commit is contained in:
parent
773b323b26
commit
2336cd1695
@ -47,16 +47,15 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
|
||||
JAX implementation of :func:`scipy.special.logsumexp`.
|
||||
|
||||
.. math::
|
||||
\mathrm{logsumexp}(a) = \mathrm{log} \sum_j b \cdot \mathrm{exp}(a_{ij})
|
||||
\operatorname{logsumexp} a = \log \sum_i b_i \exp a_i
|
||||
|
||||
where the :math:`j` indices range over one or more dimensions to be reduced.
|
||||
where the :math:`i` indices range over one or more dimensions to be reduced.
|
||||
|
||||
Args:
|
||||
a: the input array
|
||||
axis: int or sequence of ints, default=None. Axis along which the sum to be
|
||||
computed. If None, the sum is computed along all the axes.
|
||||
b: scaling factors for :math:`\mathrm{exp}(a)`. Must be broadcastable to the
|
||||
shape of `a`.
|
||||
b: scaling factors for the exponentials. Must be broadcastable to the shape of `a`.
|
||||
keepdims: If ``True``, the axes that are reduced are left in the output as
|
||||
dimensions of size 1.
|
||||
return_sign: If ``True``, the output will be a ``(result, sign)`` pair,
|
||||
|
Loading…
x
Reference in New Issue
Block a user