Merge pull request #23365 from rajasekharporeddy:testbranch2

PiperOrigin-RevId: 670380496
This commit is contained in:
jax authors 2024-09-02 20:26:07 -07:00
commit 826e661347

View File

@ -428,19 +428,95 @@ def expm1(x: ArrayLike, /) -> Array:
def log1p(x: ArrayLike, /) -> Array:
return lax.log1p(*promote_args_inexact('log1p', x))
@implements(np.sin, module='numpy')
@partial(jit, inline=True)
def sin(x: ArrayLike, /) -> Array:
"""Compute a trigonometric sine of each element of input.
JAX implementation of :obj:`numpy.sin`.
Args:
x: array or scalar. Angle in radians.
Returns:
An array containing the sine of each element in ``x``, promotes to inexact
dtype.
See also:
- :func:`jax.numpy.cos`: Computes a trigonometric cosine of each element of
input.
- :func:`jax.numpy.tan`: Computes a trigonometric tangent of each element of
input.
- :func:`jax.numpy.arcsin` and :func:`jax.numpy.asin`: Computes the inverse of
trigonometric sine of each element of input.
Examples:
>>> pi = jnp.pi
>>> x = jnp.array([pi/4, pi/2, 3*pi/4, pi])
>>> with jnp.printoptions(precision=3, suppress=True):
... print(jnp.sin(x))
[ 0.707 1. 0.707 -0. ]
"""
return lax.sin(*promote_args_inexact('sin', x))
@implements(np.cos, module='numpy')
@partial(jit, inline=True)
def cos(x: ArrayLike, /) -> Array:
"""Compute a trigonometric cosine of each element of input.
JAX implementation of :obj:`numpy.cos`.
Args:
x: scalar or array. Angle in radians.
Returns:
An array containing the cosine of each element in ``x``, promotes to inexact
dtype.
See also:
- :func:`jax.numpy.sin`: Computes a trigonometric sine of each element of input.
- :func:`jax.numpy.tan`: Computes a trigonometric tangent of each element of
input.
- :func:`jax.numpy.arccos` and :func:`jax.numpy.acos`: Computes the inverse of
trigonometric cosine of each element of input.
Examples:
>>> pi = jnp.pi
>>> x = jnp.array([pi/4, pi/2, 3*pi/4, 5*pi/6])
>>> with jnp.printoptions(precision=3, suppress=True):
... print(jnp.cos(x))
[ 0.707 -0. -0.707 -0.866]
"""
return lax.cos(*promote_args_inexact('cos', x))
@implements(np.tan, module='numpy')
@partial(jit, inline=True)
def tan(x: ArrayLike, /) -> Array:
"""Compute a trigonometric tangent of each element of input.
JAX implementation of :obj:`numpy.tan`.
Args:
x: scalar or array. Angle in radians.
Returns:
An array containing the tangent of each element in ``x``, promotes to inexact
dtype.
See also:
- :func:`jax.numpy.sin`: Computes a trigonometric sine of each element of input.
- :func:`jax.numpy.cos`: Computes a trigonometric cosine of each element of
input.
- :func:`jax.numpy.arctan` and :func:`jax.numpy.atan`: Computes the inverse of
trigonometric tangent of each element of input.
Examples:
>>> pi = jnp.pi
>>> x = jnp.array([0, pi/6, pi/4, 3*pi/4, 5*pi/6])
>>> with jnp.printoptions(precision=3, suppress=True):
... print(jnp.tan(x))
[ 0. 0.577 1. -1. -0.577]
"""
return lax.tan(*promote_args_inexact('tan', x))
@implements(np.arcsin, module='numpy')