[MHLO] Add MHLO lowerings of remaining ops blocked by the lack of complex support in CHLO

The affected ops are: acosh, asinh and atanh
(in addition to cosh which was fixed a few days ago).

acosh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1181-1216;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

asinh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1218-1270;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

atanh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1272-1292;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

PiperOrigin-RevId: 443590210
This commit is contained in:
Eugene Burmako 2022-04-22 00:38:30 -07:00 committed by jax authors
parent 5013bd2e3a
commit 636345fd67
3 changed files with 20 additions and 17 deletions

View File

@ -1699,23 +1699,35 @@ mlir.register_lowering(sinh_p, partial(_nary_lower_mhlo, chlo.SinhOp))
cosh_p = standard_unop(_float | _complex, 'cosh')
ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x)))
xla.register_translation(cosh_p, standard_translate(cosh_p))
if jax._src.lib.mlir_api_version >= 8:
if jax._src.lib.mlir_api_version >= 10:
mlir.register_lowering(cosh_p, partial(_nary_lower_mhlo, chlo.CoshOp))
else:
xla.register_translation(cosh_p, standard_translate(cosh_p))
if jax._src.lib.mlir_api_version >= 8:
mlir.register_lowering(cosh_p, partial(_nary_lower_mhlo, chlo.CoshOp))
asinh_p = standard_unop(_float | _complex, 'asinh')
ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x))))
xla.register_translation(asinh_p, standard_translate(asinh_p))
if jax._src.lib.mlir_api_version >= 10:
mlir.register_lowering(asinh_p, partial(_nary_lower_mhlo, chlo.AsinhOp))
else:
xla.register_translation(asinh_p, standard_translate(asinh_p))
acosh_p = standard_unop(_float | _complex, 'acosh')
xla.register_translation(acosh_p, standard_translate(acosh_p))
ad.defjvp(acosh_p,
lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x)))))
if jax._src.lib.mlir_api_version >= 10:
mlir.register_lowering(acosh_p, partial(_nary_lower_mhlo, chlo.AcoshOp))
else:
xla.register_translation(acosh_p, standard_translate(acosh_p))
atanh_p = standard_unop(_float | _complex, 'atanh')
xla.register_translation(atanh_p, standard_translate(atanh_p))
ad.defjvp(atanh_p,
lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x))))
if jax._src.lib.mlir_api_version >= 10:
mlir.register_lowering(atanh_p, partial(_nary_lower_mhlo, chlo.AtanhOp))
else:
xla.register_translation(atanh_p, standard_translate(atanh_p))
regularized_incomplete_beta_p = standard_naryop(
[_float, _float, _float], 'regularized_incomplete_beta')

View File

@ -1219,15 +1219,6 @@ register_lowering(ad.custom_lin_p, ad._raise_custom_vjp_error_on_jvp)
# lax.erf_inv_p,
# lax.regularized_incomplete_beta_p,
# # CHLO doesn't have complex lowerings of these primitives (b/203718937)
# lax.acos_p,
# lax.acosh_p,
# lax.asin_p,
# lax.asinh_p,
# lax.atan_p,
# lax.atanh_p,
# lax.tan_p,
# # CHLO doesn't have a legalization for bf16 (b/203774470)
# lax.erf_p,
# lax.erfc_p,

View File

@ -46,7 +46,7 @@ def main(_):
print_ir(np.float32(1))(lax.acos)
# CHECK-LABEL: TEST: acosh float32[]
# CHECK: xla_fallback_acosh
# CHECK: chlo.acosh
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.acosh)
@ -56,7 +56,7 @@ def main(_):
print_ir(np.float32(1))(lax.asin)
# CHECK-LABEL: TEST: asinh float32[]
# CHECK: xla_fallback_asinh
# CHECK: chlo.asinh
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.asinh)
@ -66,7 +66,7 @@ def main(_):
print_ir(np.float32(1))(lax.atan)
# CHECK-LABEL: TEST: atanh float32[]
# CHECK: xla_fallback_atanh
# CHECK: chlo.atanh
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.atanh)