Merge pull request #23261 from joaospinto:stablehlo.tan

PiperOrigin-RevId: 675973798
This commit is contained in:
jax authors 2024-09-18 06:56:28 -07:00
commit e15ec1e8ab
2 changed files with 6 additions and 6 deletions

View File

@ -340,6 +340,10 @@ def cos(x: ArrayLike) -> Array:
r"""Elementwise cosine: :math:`\mathrm{cos}(x)`."""
return cos_p.bind(x)
def tan(x: ArrayLike) -> Array:
r"""Elementwise tangent: :math:`\mathrm{tan}(x)`."""
return tan_p.bind(x)
def atan2(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise arc tangent of two variables:
:math:`\mathrm{atan}({x \over y})`."""
@ -1549,10 +1553,6 @@ def _upcast_fp16_for_computation(f):
return f_wrapped
def tan(x: ArrayLike) -> Array:
r"""Elementwise tangent: :math:`\mathrm{tan}(x)`."""
return tan_p.bind(x)
def asin(x: ArrayLike) -> Array:
r"""Elementwise arc sine: :math:`\mathrm{asin}(x)`."""
return asin_p.bind(x)
@ -2014,7 +2014,7 @@ def _tan_impl(x):
tan_p = standard_unop(_float | _complex, 'tan')
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan))
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan))
def asin_impl(x):
if dtypes.issubdtype(_dtype(x), np.complexfloating):

View File

@ -419,7 +419,7 @@ def main(_):
print_ir(jnp.bfloat16(0))(lax.sqrt)
# CHECK-LABEL: TEST: tan float16[]
# CHECK: chlo.tan
# CHECK: hlo.tan
# CHECK-SAME: tensor<f16>
print_ir(np.float16(0))(lax.tan)