mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #23261 from joaospinto:stablehlo.tan
PiperOrigin-RevId: 675973798
This commit is contained in:
commit
e15ec1e8ab
@ -340,6 +340,10 @@ def cos(x: ArrayLike) -> Array:
|
||||
r"""Elementwise cosine: :math:`\mathrm{cos}(x)`."""
|
||||
return cos_p.bind(x)
|
||||
|
||||
def tan(x: ArrayLike) -> Array:
|
||||
r"""Elementwise tangent: :math:`\mathrm{tan}(x)`."""
|
||||
return tan_p.bind(x)
|
||||
|
||||
def atan2(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
r"""Elementwise arc tangent of two variables:
|
||||
:math:`\mathrm{atan}({x \over y})`."""
|
||||
@ -1549,10 +1553,6 @@ def _upcast_fp16_for_computation(f):
|
||||
|
||||
return f_wrapped
|
||||
|
||||
def tan(x: ArrayLike) -> Array:
|
||||
r"""Elementwise tangent: :math:`\mathrm{tan}(x)`."""
|
||||
return tan_p.bind(x)
|
||||
|
||||
def asin(x: ArrayLike) -> Array:
|
||||
r"""Elementwise arc sine: :math:`\mathrm{asin}(x)`."""
|
||||
return asin_p.bind(x)
|
||||
@ -2014,7 +2014,7 @@ def _tan_impl(x):
|
||||
|
||||
tan_p = standard_unop(_float | _complex, 'tan')
|
||||
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
|
||||
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan))
|
||||
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan))
|
||||
|
||||
def asin_impl(x):
|
||||
if dtypes.issubdtype(_dtype(x), np.complexfloating):
|
||||
|
@ -419,7 +419,7 @@ def main(_):
|
||||
print_ir(jnp.bfloat16(0))(lax.sqrt)
|
||||
|
||||
# CHECK-LABEL: TEST: tan float16[]
|
||||
# CHECK: chlo.tan
|
||||
# CHECK: hlo.tan
|
||||
# CHECK-SAME: tensor<f16>
|
||||
print_ir(np.float16(0))(lax.tan)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user