[CHLO] Add erf_inv and lowering to mhlo

PiperOrigin-RevId: 513183138
This commit is contained in:
Anish Tondwalkar 2023-03-01 02:52:16 -08:00 committed by jax authors
parent 713bc2687d
commit 3bad6fa223
2 changed files with 5 additions and 2 deletions

View File

@ -1988,7 +1988,10 @@ mlir.register_lowering(erfc_p, partial(_nary_lower_hlo, chlo.ErfcOp))
erf_inv_p = standard_unop(_float, 'erf_inv')
ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.),
mul(g, exp(square(ans)))))
xla.register_translation(erf_inv_p, standard_translate(erf_inv_p))
if xla_client.mlir_api_version >= 45:
mlir.register_lowering(erf_inv_p, partial(_nary_lower_hlo, chlo.ErfInvOp))
else:
xla.register_translation(erf_inv_p, standard_translate(erf_inv_p))
real_p = unop(_complex_basetype, _complex, 'real')
ad.deflinear2(real_p, lambda t, _: [complex(t, np.zeros((), _dtype(t)))])

View File

@ -231,7 +231,7 @@ def main(_):
print_ir(np.float32(0))(lax.erfc)
# CHECK-LABEL: TEST: erf_inv float32[]
# CHECK: xla_fallback_erf_inv
# CHECK: chlo.erf_inv
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.erf_inv)