Merge pull request #23310 from rajasekharporeddy:testbranch2

PiperOrigin-RevId: 669379680
This commit is contained in:
jax authors 2024-08-30 10:57:15 -07:00
commit a55a94de24

View File

@ -203,9 +203,45 @@ def ceil(x: ArrayLike, /) -> Array:
return lax.asarray(x)
return lax.ceil(*promote_args_inexact('ceil', x))
@implements(np.exp, module='numpy')
@partial(jit, inline=True)
def exp(x: ArrayLike, /) -> Array:
"""Calculate element-wise exponential of the input.
JAX implementation of :obj:`numpy.exp`.
Args:
x: input array or scalar
Returns:
An array containing the exponential of each element in ``x``, promotes to
inexact dtype.
See also:
- :func:`jax.numpy.log`: Calculates element-wise logarithm of the input.
- :func:`jax.numpy.expm1`: Calculates :math:`e^x-1` of each element of the
input.
- :func:`jax.numpy.exp2`: Calculates base-2 exponential of each element of
the input.
Examples:
``jnp.exp`` follows the properties of exponential such as :math:`e^{(a+b)}
= e^a * e^b`.
>>> x1 = jnp.array([2, 4, 3, 1])
>>> x2 = jnp.array([1, 3, 2, 3])
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.exp(x1+x2))
[ 20.09 1096.63 148.41 54.6 ]
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.exp(x1)*jnp.exp(x2))
[ 20.09 1096.63 148.41 54.6 ]
This property holds for complex input also:
>>> jnp.allclose(jnp.exp(3-4j), jnp.exp(3)*jnp.exp(-4j))
Array(True, dtype=bool)
"""
return lax.exp(*promote_args_inexact('exp', x))
@implements(np.log, module='numpy')
@ -213,9 +249,48 @@ def exp(x: ArrayLike, /) -> Array:
def log(x: ArrayLike, /) -> Array:
return lax.log(*promote_args_inexact('log', x))
@implements(np.expm1, module='numpy')
@partial(jit, inline=True)
def expm1(x: ArrayLike, /) -> Array:
"""Calculate ``exp(x)-1`` of each element of the input.
JAX implementation of :obj:`numpy.expm1`.
Args:
x: input array or scalar.
Returns:
An array containing ``exp(x)-1`` of each element in ``x``, promotes to inexact
dtype.
Note:
``jnp.expm1`` has much higher precision than the naive computation of
``exp(x)-1`` for small values of ``x``.
See also:
- :func:`jax.numpy.log1p`: Calculates element-wise logarithm of one plus input.
- :func:`jax.numpy.exp`: Calculates element-wise exponential of the input.
- :func:`jax.numpy.exp2`: Calculates base-2 exponential of each element of
the input.
Examples:
>>> x = jnp.array([2, -4, 3, -1])
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.expm1(x))
[ 6.39 -0.98 19.09 -0.63]
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.exp(x)-1)
[ 6.39 -0.98 19.09 -0.63]
For values very close to 0, ``jnp.expm1(x)`` is much more accurate than
``jnp.exp(x)-1``:
>>> x1 = jnp.array([1e-4, 1e-6, 2e-10])
>>> jnp.expm1(x1)
Array([1.0000500e-04, 1.0000005e-06, 2.0000000e-10], dtype=float32)
>>> jnp.exp(x1)-1
Array([1.00016594e-04, 9.53674316e-07, 0.00000000e+00], dtype=float32)
"""
return lax.expm1(*promote_args_inexact('expm1', x))
@implements(np.log1p, module='numpy')
@ -970,9 +1045,36 @@ def log10(x: ArrayLike, /) -> Array:
return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))
@implements(np.exp2, module='numpy')
@partial(jit, inline=True)
def exp2(x: ArrayLike, /) -> Array:
"""Calculate element-wise base-2 exponential of input.
JAX implementation of :obj:`numpy.exp2`.
Args:
x: input array or scalar
Returns:
An array containing the base-2 exponential of each element in ``x``, promotes
to inexact dtype.
See also:
- :func:`jax.numpy.log2`: Calculates base-2 logarithm of each element of input.
- :func:`jax.numpy.exp`: Calculates exponential of each element of the input.
- :func:`jax.numpy.expm1`: Calculates :math:`e^x-1` of each element of the
input.
Examples:
``jnp.exp2`` follows the properties of the exponential such as :math:`2^{a+b}
= 2^a * 2^b`.
>>> x1 = jnp.array([2, -4, 3, -1])
>>> x2 = jnp.array([-1, 3, -2, 3])
>>> jnp.exp2(x1+x2)
Array([2. , 0.5, 2. , 4. ], dtype=float32)
>>> jnp.exp2(x1)*jnp.exp2(x2)
Array([2. , 0.5, 2. , 4. ], dtype=float32)
"""
x, = promote_args_inexact("exp2", x)
return lax.exp2(x)