mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #23310 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 669379680
This commit is contained in:
commit
a55a94de24
@ -203,9 +203,45 @@ def ceil(x: ArrayLike, /) -> Array:
|
||||
return lax.asarray(x)
|
||||
return lax.ceil(*promote_args_inexact('ceil', x))
|
||||
|
||||
@implements(np.exp, module='numpy')
|
||||
|
||||
@partial(jit, inline=True)
|
||||
def exp(x: ArrayLike, /) -> Array:
|
||||
"""Calculate element-wise exponential of the input.
|
||||
|
||||
JAX implementation of :obj:`numpy.exp`.
|
||||
|
||||
Args:
|
||||
x: input array or scalar
|
||||
|
||||
Returns:
|
||||
An array containing the exponential of each element in ``x``, promotes to
|
||||
inexact dtype.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.log`: Calculates element-wise logarithm of the input.
|
||||
- :func:`jax.numpy.expm1`: Calculates :math:`e^x-1` of each element of the
|
||||
input.
|
||||
- :func:`jax.numpy.exp2`: Calculates base-2 exponential of each element of
|
||||
the input.
|
||||
|
||||
Examples:
|
||||
``jnp.exp`` follows the properties of exponential such as :math:`e^{(a+b)}
|
||||
= e^a * e^b`.
|
||||
|
||||
>>> x1 = jnp.array([2, 4, 3, 1])
|
||||
>>> x2 = jnp.array([1, 3, 2, 3])
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jnp.exp(x1+x2))
|
||||
[ 20.09 1096.63 148.41 54.6 ]
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jnp.exp(x1)*jnp.exp(x2))
|
||||
[ 20.09 1096.63 148.41 54.6 ]
|
||||
|
||||
This property holds for complex input also:
|
||||
|
||||
>>> jnp.allclose(jnp.exp(3-4j), jnp.exp(3)*jnp.exp(-4j))
|
||||
Array(True, dtype=bool)
|
||||
"""
|
||||
return lax.exp(*promote_args_inexact('exp', x))
|
||||
|
||||
@implements(np.log, module='numpy')
|
||||
@ -213,9 +249,48 @@ def exp(x: ArrayLike, /) -> Array:
|
||||
def log(x: ArrayLike, /) -> Array:
|
||||
return lax.log(*promote_args_inexact('log', x))
|
||||
|
||||
@implements(np.expm1, module='numpy')
|
||||
|
||||
@partial(jit, inline=True)
|
||||
def expm1(x: ArrayLike, /) -> Array:
|
||||
"""Calculate ``exp(x)-1`` of each element of the input.
|
||||
|
||||
JAX implementation of :obj:`numpy.expm1`.
|
||||
|
||||
Args:
|
||||
x: input array or scalar.
|
||||
|
||||
Returns:
|
||||
An array containing ``exp(x)-1`` of each element in ``x``, promotes to inexact
|
||||
dtype.
|
||||
|
||||
Note:
|
||||
``jnp.expm1`` has much higher precision than the naive computation of
|
||||
``exp(x)-1`` for small values of ``x``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.log1p`: Calculates element-wise logarithm of one plus input.
|
||||
- :func:`jax.numpy.exp`: Calculates element-wise exponential of the input.
|
||||
- :func:`jax.numpy.exp2`: Calculates base-2 exponential of each element of
|
||||
the input.
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([2, -4, 3, -1])
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jnp.expm1(x))
|
||||
[ 6.39 -0.98 19.09 -0.63]
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jnp.exp(x)-1)
|
||||
[ 6.39 -0.98 19.09 -0.63]
|
||||
|
||||
For values very close to 0, ``jnp.expm1(x)`` is much more accurate than
|
||||
``jnp.exp(x)-1``:
|
||||
|
||||
>>> x1 = jnp.array([1e-4, 1e-6, 2e-10])
|
||||
>>> jnp.expm1(x1)
|
||||
Array([1.0000500e-04, 1.0000005e-06, 2.0000000e-10], dtype=float32)
|
||||
>>> jnp.exp(x1)-1
|
||||
Array([1.00016594e-04, 9.53674316e-07, 0.00000000e+00], dtype=float32)
|
||||
"""
|
||||
return lax.expm1(*promote_args_inexact('expm1', x))
|
||||
|
||||
@implements(np.log1p, module='numpy')
|
||||
@ -970,9 +1045,36 @@ def log10(x: ArrayLike, /) -> Array:
|
||||
return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))
|
||||
|
||||
|
||||
@implements(np.exp2, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def exp2(x: ArrayLike, /) -> Array:
|
||||
"""Calculate element-wise base-2 exponential of input.
|
||||
|
||||
JAX implementation of :obj:`numpy.exp2`.
|
||||
|
||||
Args:
|
||||
x: input array or scalar
|
||||
|
||||
Returns:
|
||||
An array containing the base-2 exponential of each element in ``x``, promotes
|
||||
to inexact dtype.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.log2`: Calculates base-2 logarithm of each element of input.
|
||||
- :func:`jax.numpy.exp`: Calculates exponential of each element of the input.
|
||||
- :func:`jax.numpy.expm1`: Calculates :math:`e^x-1` of each element of the
|
||||
input.
|
||||
|
||||
Examples:
|
||||
``jnp.exp2`` follows the properties of the exponential such as :math:`2^{a+b}
|
||||
= 2^a * 2^b`.
|
||||
|
||||
>>> x1 = jnp.array([2, -4, 3, -1])
|
||||
>>> x2 = jnp.array([-1, 3, -2, 3])
|
||||
>>> jnp.exp2(x1+x2)
|
||||
Array([2. , 0.5, 2. , 4. ], dtype=float32)
|
||||
>>> jnp.exp2(x1)*jnp.exp2(x2)
|
||||
Array([2. , 0.5, 2. , 4. ], dtype=float32)
|
||||
"""
|
||||
x, = promote_args_inexact("exp2", x)
|
||||
return lax.exp2(x)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user