mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #23365 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 670380496
This commit is contained in:
commit
826e661347
@ -428,19 +428,95 @@ def expm1(x: ArrayLike, /) -> Array:
|
||||
def log1p(x: ArrayLike, /) -> Array:
|
||||
return lax.log1p(*promote_args_inexact('log1p', x))
|
||||
|
||||
@implements(np.sin, module='numpy')
|
||||
|
||||
@partial(jit, inline=True)
|
||||
def sin(x: ArrayLike, /) -> Array:
|
||||
"""Compute a trigonometric sine of each element of input.
|
||||
|
||||
JAX implementation of :obj:`numpy.sin`.
|
||||
|
||||
Args:
|
||||
x: array or scalar. Angle in radians.
|
||||
|
||||
Returns:
|
||||
An array containing the sine of each element in ``x``, promotes to inexact
|
||||
dtype.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.cos`: Computes a trigonometric cosine of each element of
|
||||
input.
|
||||
- :func:`jax.numpy.tan`: Computes a trigonometric tangent of each element of
|
||||
input.
|
||||
- :func:`jax.numpy.arcsin` and :func:`jax.numpy.asin`: Computes the inverse of
|
||||
trigonometric sine of each element of input.
|
||||
|
||||
Examples:
|
||||
>>> pi = jnp.pi
|
||||
>>> x = jnp.array([pi/4, pi/2, 3*pi/4, pi])
|
||||
>>> with jnp.printoptions(precision=3, suppress=True):
|
||||
... print(jnp.sin(x))
|
||||
[ 0.707 1. 0.707 -0. ]
|
||||
"""
|
||||
return lax.sin(*promote_args_inexact('sin', x))
|
||||
|
||||
@implements(np.cos, module='numpy')
|
||||
|
||||
@partial(jit, inline=True)
|
||||
def cos(x: ArrayLike, /) -> Array:
|
||||
"""Compute a trigonometric cosine of each element of input.
|
||||
|
||||
JAX implementation of :obj:`numpy.cos`.
|
||||
|
||||
Args:
|
||||
x: scalar or array. Angle in radians.
|
||||
|
||||
Returns:
|
||||
An array containing the cosine of each element in ``x``, promotes to inexact
|
||||
dtype.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.sin`: Computes a trigonometric sine of each element of input.
|
||||
- :func:`jax.numpy.tan`: Computes a trigonometric tangent of each element of
|
||||
input.
|
||||
- :func:`jax.numpy.arccos` and :func:`jax.numpy.acos`: Computes the inverse of
|
||||
trigonometric cosine of each element of input.
|
||||
|
||||
Examples:
|
||||
>>> pi = jnp.pi
|
||||
>>> x = jnp.array([pi/4, pi/2, 3*pi/4, 5*pi/6])
|
||||
>>> with jnp.printoptions(precision=3, suppress=True):
|
||||
... print(jnp.cos(x))
|
||||
[ 0.707 -0. -0.707 -0.866]
|
||||
"""
|
||||
return lax.cos(*promote_args_inexact('cos', x))
|
||||
|
||||
@implements(np.tan, module='numpy')
|
||||
|
||||
@partial(jit, inline=True)
|
||||
def tan(x: ArrayLike, /) -> Array:
|
||||
"""Compute a trigonometric tangent of each element of input.
|
||||
|
||||
JAX implementation of :obj:`numpy.tan`.
|
||||
|
||||
Args:
|
||||
x: scalar or array. Angle in radians.
|
||||
|
||||
Returns:
|
||||
An array containing the tangent of each element in ``x``, promotes to inexact
|
||||
dtype.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.sin`: Computes a trigonometric sine of each element of input.
|
||||
- :func:`jax.numpy.cos`: Computes a trigonometric cosine of each element of
|
||||
input.
|
||||
- :func:`jax.numpy.arctan` and :func:`jax.numpy.atan`: Computes the inverse of
|
||||
trigonometric tangent of each element of input.
|
||||
|
||||
Examples:
|
||||
>>> pi = jnp.pi
|
||||
>>> x = jnp.array([0, pi/6, pi/4, 3*pi/4, 5*pi/6])
|
||||
>>> with jnp.printoptions(precision=3, suppress=True):
|
||||
... print(jnp.tan(x))
|
||||
[ 0. 0.577 1. -1. -0.577]
|
||||
"""
|
||||
return lax.tan(*promote_args_inexact('tan', x))
|
||||
|
||||
@implements(np.arcsin, module='numpy')
|
||||
|
Loading…
x
Reference in New Issue
Block a user