mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
[MHLO] Add MHLO lowerings of remaining ops blocked by the lack of complex support in CHLO
The affected ops are: acosh, asinh and atanh (in addition to cosh which was fixed a few days ago). acosh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1181-1216;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e asinh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1218-1270;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e atanh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1272-1292;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e PiperOrigin-RevId: 443590210
This commit is contained in:
parent
5013bd2e3a
commit
636345fd67
@ -1699,23 +1699,35 @@ mlir.register_lowering(sinh_p, partial(_nary_lower_mhlo, chlo.SinhOp))
|
||||
|
||||
cosh_p = standard_unop(_float | _complex, 'cosh')
|
||||
ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x)))
|
||||
xla.register_translation(cosh_p, standard_translate(cosh_p))
|
||||
if jax._src.lib.mlir_api_version >= 8:
|
||||
if jax._src.lib.mlir_api_version >= 10:
|
||||
mlir.register_lowering(cosh_p, partial(_nary_lower_mhlo, chlo.CoshOp))
|
||||
else:
|
||||
xla.register_translation(cosh_p, standard_translate(cosh_p))
|
||||
if jax._src.lib.mlir_api_version >= 8:
|
||||
mlir.register_lowering(cosh_p, partial(_nary_lower_mhlo, chlo.CoshOp))
|
||||
|
||||
asinh_p = standard_unop(_float | _complex, 'asinh')
|
||||
ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x))))
|
||||
xla.register_translation(asinh_p, standard_translate(asinh_p))
|
||||
if jax._src.lib.mlir_api_version >= 10:
|
||||
mlir.register_lowering(asinh_p, partial(_nary_lower_mhlo, chlo.AsinhOp))
|
||||
else:
|
||||
xla.register_translation(asinh_p, standard_translate(asinh_p))
|
||||
|
||||
acosh_p = standard_unop(_float | _complex, 'acosh')
|
||||
xla.register_translation(acosh_p, standard_translate(acosh_p))
|
||||
ad.defjvp(acosh_p,
|
||||
lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x)))))
|
||||
if jax._src.lib.mlir_api_version >= 10:
|
||||
mlir.register_lowering(acosh_p, partial(_nary_lower_mhlo, chlo.AcoshOp))
|
||||
else:
|
||||
xla.register_translation(acosh_p, standard_translate(acosh_p))
|
||||
|
||||
atanh_p = standard_unop(_float | _complex, 'atanh')
|
||||
xla.register_translation(atanh_p, standard_translate(atanh_p))
|
||||
ad.defjvp(atanh_p,
|
||||
lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x))))
|
||||
if jax._src.lib.mlir_api_version >= 10:
|
||||
mlir.register_lowering(atanh_p, partial(_nary_lower_mhlo, chlo.AtanhOp))
|
||||
else:
|
||||
xla.register_translation(atanh_p, standard_translate(atanh_p))
|
||||
|
||||
regularized_incomplete_beta_p = standard_naryop(
|
||||
[_float, _float, _float], 'regularized_incomplete_beta')
|
||||
|
@ -1219,15 +1219,6 @@ register_lowering(ad.custom_lin_p, ad._raise_custom_vjp_error_on_jvp)
|
||||
# lax.erf_inv_p,
|
||||
# lax.regularized_incomplete_beta_p,
|
||||
|
||||
# # CHLO doesn't have complex lowerings of these primitives (b/203718937)
|
||||
# lax.acos_p,
|
||||
# lax.acosh_p,
|
||||
# lax.asin_p,
|
||||
# lax.asinh_p,
|
||||
# lax.atan_p,
|
||||
# lax.atanh_p,
|
||||
# lax.tan_p,
|
||||
|
||||
# # CHLO doesn't have a legalization for bf16 (b/203774470)
|
||||
# lax.erf_p,
|
||||
# lax.erfc_p,
|
||||
|
@ -46,7 +46,7 @@ def main(_):
|
||||
print_ir(np.float32(1))(lax.acos)
|
||||
|
||||
# CHECK-LABEL: TEST: acosh float32[]
|
||||
# CHECK: xla_fallback_acosh
|
||||
# CHECK: chlo.acosh
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.acosh)
|
||||
|
||||
@ -56,7 +56,7 @@ def main(_):
|
||||
print_ir(np.float32(1))(lax.asin)
|
||||
|
||||
# CHECK-LABEL: TEST: asinh float32[]
|
||||
# CHECK: xla_fallback_asinh
|
||||
# CHECK: chlo.asinh
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.asinh)
|
||||
|
||||
@ -66,7 +66,7 @@ def main(_):
|
||||
print_ir(np.float32(1))(lax.atan)
|
||||
|
||||
# CHECK-LABEL: TEST: atanh float32[]
|
||||
# CHECK: xla_fallback_atanh
|
||||
# CHECK: chlo.atanh
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.atanh)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user