[MHLO] Switch tan to use CHLO lowering

Currently, it's desugared to sin(x)/cos(x) with upcast because CHLO_TanOp
legalization doesn't support complex numbers.

tan implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1175-1177;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

PiperOrigin-RevId: 443649394
This commit is contained in:
Eugene Burmako 2022-04-22 06:27:28 -07:00 committed by jax authors
parent 636345fd67
commit 5f16873aad
2 changed files with 7 additions and 5 deletions

View File

@ -1643,7 +1643,11 @@ def _tan_impl(x):
tan_p = standard_unop(_float | _complex, 'tan')
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
mlir.register_lowering(tan_p, mlir.lower_fun(_tan_impl, multiple_results=False))
if jax._src.lib.mlir_api_version >= 11:
mlir.register_lowering(tan_p, partial(_nary_lower_mhlo, chlo.TanOp))
else:
mlir.register_lowering(tan_p,
mlir.lower_fun(_tan_impl, multiple_results=False))
def asin_impl(x):
if dtypes.issubdtype(_dtype(x), np.complexfloating):

View File

@ -458,10 +458,8 @@ def main(_):
print_ir(jnp.bfloat16(0))(lax.sqrt)
# CHECK-LABEL: TEST: tan float16[]
# CHECK: mhlo.sine
# CHECK-SAME: tensor<f32>
# CHECK: mhlo.cosine
# CHECK-SAME: tensor<f32>
# CHECK: chlo.tan
# CHECK-SAME: tensor<f16>
print_ir(np.float16(0))(lax.tan)
# CHECK-LABEL: TEST: tanh float32[]