mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Improve precision of chlo.sinh.
Update chlo.sinh lowering to match xla::Sinh(), see https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1311?q=xla%20sinh [JAX] Use chlo.sinh instead of the XLA client library HLO lowering. PiperOrigin-RevId: 441851170
This commit is contained in:
parent
78fb120f86
commit
c2fe97ae01
@ -1740,9 +1740,8 @@ mlir.register_lowering(atan2_p, partial(_nary_lower_mhlo, mhlo.Atan2Op))
|
||||
|
||||
sinh_p = standard_unop(_float | _complex, 'sinh')
|
||||
ad.defjvp(sinh_p, lambda g, x: mul(g, cosh(x)))
|
||||
# TODO(b/209505237): the CHLO lowering of chlo.sinh is less accurate than that
|
||||
# in the XLA client library. Use the fallback path for now.
|
||||
# mlir.register_lowering(sinh_p, partial(_nary_lower_mhlo, chlo.SinhOp))
|
||||
if jax._src.lib.mlir_api_version >= 7:
|
||||
mlir.register_lowering(sinh_p, partial(_nary_lower_mhlo, chlo.SinhOp))
|
||||
|
||||
cosh_p = standard_unop(_float | _complex, 'cosh')
|
||||
ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x)))
|
||||
|
@ -443,7 +443,7 @@ def main(_):
|
||||
print_ir(np.float32(0))(lax.sin)
|
||||
|
||||
# CHECK-LABEL: TEST: sinh float32[]
|
||||
# CHECK: xla_fallback_sinh
|
||||
# CHECK: chlo.sinh
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.sinh)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user