[MHLO] Add MHLO lowering for erf and erfc

erf implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=319-336;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

PiperOrigin-RevId: 443665435
This commit is contained in:
Eugene Burmako 2022-04-22 07:53:49 -07:00 committed by jax authors
parent 5f16873aad
commit 0ed29b63f0
2 changed files with 10 additions and 4 deletions

View File

@ -1806,12 +1806,18 @@ ad.defjvp2(bessel_i1e_p, _bessel_i1e_jvp)
erf_p = standard_unop(_float, 'erf')
ad.defjvp(erf_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)),
mul(g, exp(neg(square(x))))))
xla.register_translation(erf_p, standard_translate(erf_p))
if jax._src.lib.mlir_api_version >= 12:
mlir.register_lowering(erf_p, partial(_nary_lower_mhlo, chlo.ErfOp))
else:
xla.register_translation(erf_p, standard_translate(erf_p))
erfc_p = standard_unop(_float, 'erfc')
ad.defjvp(erfc_p, lambda g, x: mul(_const(x, -2. / np.sqrt(np.pi)),
mul(g, exp(neg(square(x))))))
xla.register_translation(erfc_p, standard_translate(erfc_p))
if jax._src.lib.mlir_api_version >= 12:
mlir.register_lowering(erfc_p, partial(_nary_lower_mhlo, chlo.ErfcOp))
else:
xla.register_translation(erfc_p, standard_translate(erfc_p))
erf_inv_p = standard_unop(_float, 'erf_inv')
ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.),

View File

@ -225,12 +225,12 @@ def main(_):
print_ir(np.uint16(1), np.uint16(2))(lax.eq)
# CHECK-LABEL: TEST: erf float32[]
# CHECK: xla_fallback_erf
# CHECK: chlo.erf
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.erf)
# CHECK-LABEL: TEST: erfc float32[]
# CHECK: xla_fallback_erfc
# CHECK: chlo.erfc
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.erfc)