Improve precision of chlo.sinh.

Update chlo.sinh lowering to match xla::Sinh(), see https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1311?q=xla%20sinh

[JAX] Use chlo.sinh instead of the XLA client library HLO lowering.

PiperOrigin-RevId: 441851170
This commit is contained in:
Peter Hawkins 2022-04-14 14:09:54 -07:00 committed by jax authors
parent 78fb120f86
commit c2fe97ae01
2 changed files with 3 additions and 4 deletions

View File

@ -1740,9 +1740,8 @@ mlir.register_lowering(atan2_p, partial(_nary_lower_mhlo, mhlo.Atan2Op))
sinh_p = standard_unop(_float | _complex, 'sinh')
ad.defjvp(sinh_p, lambda g, x: mul(g, cosh(x)))
# TODO(b/209505237): the CHLO lowering of chlo.sinh is less accurate than that
# in the XLA client library. Use the fallback path for now.
# mlir.register_lowering(sinh_p, partial(_nary_lower_mhlo, chlo.SinhOp))
if jax._src.lib.mlir_api_version >= 7:
mlir.register_lowering(sinh_p, partial(_nary_lower_mhlo, chlo.SinhOp))
cosh_p = standard_unop(_float | _complex, 'cosh')
ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x)))

View File

@ -443,7 +443,7 @@ def main(_):
print_ir(np.float32(0))(lax.sin)
# CHECK-LABEL: TEST: sinh float32[]
# CHECK: xla_fallback_sinh
# CHECK: chlo.sinh
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.sinh)