mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Better docs for jax.numpy: log and log1p
This commit is contained in:
parent
22be4eafca
commit
cb45fb426a
@ -374,9 +374,41 @@ def exp(x: ArrayLike, /) -> Array:
|
||||
"""
|
||||
return lax.exp(*promote_args_inexact('exp', x))
|
||||
|
||||
@implements(np.log, module='numpy')
|
||||
|
||||
@partial(jit, inline=True)
|
||||
def log(x: ArrayLike, /) -> Array:
|
||||
"""Calculate element-wise natural logarithm of the input.
|
||||
|
||||
JAX implementation of :obj:`numpy.log`.
|
||||
|
||||
Args:
|
||||
x: input array or scalar.
|
||||
|
||||
Returns:
|
||||
An array containing the logarithm of each element in ``x``, promotes to inexact
|
||||
dtype.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.exp`: Calculates element-wise exponential of the input.
|
||||
- :func:`jax.numpy.log2`: Calculates base-2 logarithm of each element of input.
|
||||
- :func:`jax.numpy.log1p`: Calculates element-wise logarithm of one plus input.
|
||||
|
||||
Examples:
|
||||
``jnp.log`` and ``jnp.exp`` are inverse functions of each other. Applying
|
||||
``jnp.log`` on the result of ``jnp.exp(x)`` yields the original input ``x``.
|
||||
|
||||
>>> x = jnp.array([2, 3, 4, 5])
|
||||
>>> jnp.log(jnp.exp(x))
|
||||
Array([2., 3., 4., 5.], dtype=float32)
|
||||
|
||||
Using ``jnp.log`` we can demonstrate well-known properties of logarithms, such
|
||||
as :math:`log(a*b) = log(a)+log(b)`.
|
||||
|
||||
>>> x1 = jnp.array([2, 1, 3, 1])
|
||||
>>> x2 = jnp.array([1, 3, 2, 4])
|
||||
>>> jnp.allclose(jnp.log(x1*x2), jnp.log(x1)+jnp.log(x2))
|
||||
Array(True, dtype=bool)
|
||||
"""
|
||||
return lax.log(*promote_args_inexact('log', x))
|
||||
|
||||
|
||||
@ -423,9 +455,44 @@ def expm1(x: ArrayLike, /) -> Array:
|
||||
"""
|
||||
return lax.expm1(*promote_args_inexact('expm1', x))
|
||||
|
||||
@implements(np.log1p, module='numpy')
|
||||
|
||||
@partial(jit, inline=True)
|
||||
def log1p(x: ArrayLike, /) -> Array:
|
||||
"""Calculates element-wise logarithm of one plus input, ``log(x+1)``.
|
||||
|
||||
JAX implementation of :obj:`numpy.log1p`.
|
||||
|
||||
Args:
|
||||
x: input array or scalar.
|
||||
|
||||
Returns:
|
||||
An array containing the logarithm of one plus of each element in ``x``,
|
||||
promotes to inexact dtype.
|
||||
|
||||
Note:
|
||||
``jnp.log1p`` is more accurate than when using the naive computation of
|
||||
``log(x+1)`` for small values of ``x``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.expm1`: Calculates :math:`e^x-1` of each element of the
|
||||
input.
|
||||
- :func:`jax.numpy.log2`: Calculates base-2 logarithm of each element of input.
|
||||
- :func:`jax.numpy.log`: Calculates element-wise logarithm of the input.
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([2, 5, 9, 4])
|
||||
>>> jnp.allclose(jnp.log1p(x), jnp.log(x+1))
|
||||
Array(True, dtype=bool)
|
||||
|
||||
For values very close to 0, ``jnp.log1p(x)`` is more accurate than
|
||||
``jnp.log(x+1)``:
|
||||
|
||||
>>> x1 = jnp.array([1e-4, 1e-6, 2e-10])
|
||||
>>> jnp.expm1(jnp.log1p(x1)) # doctest: +SKIP
|
||||
Array([1.00000005e-04, 9.99999997e-07, 2.00000003e-10], dtype=float32)
|
||||
>>> jnp.expm1(jnp.log(x1+1)) # doctest: +SKIP
|
||||
Array([1.000166e-04, 9.536743e-07, 0.000000e+00], dtype=float32)
|
||||
"""
|
||||
return lax.log1p(*promote_args_inexact('log1p', x))
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user