mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
[CHLO] Add erf_inv and lowering to mhlo
PiperOrigin-RevId: 513183138
This commit is contained in:
parent
713bc2687d
commit
3bad6fa223
@ -1988,7 +1988,10 @@ mlir.register_lowering(erfc_p, partial(_nary_lower_hlo, chlo.ErfcOp))
|
||||
erf_inv_p = standard_unop(_float, 'erf_inv')
|
||||
ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.),
|
||||
mul(g, exp(square(ans)))))
|
||||
xla.register_translation(erf_inv_p, standard_translate(erf_inv_p))
|
||||
if xla_client.mlir_api_version >= 45:
|
||||
mlir.register_lowering(erf_inv_p, partial(_nary_lower_hlo, chlo.ErfInvOp))
|
||||
else:
|
||||
xla.register_translation(erf_inv_p, standard_translate(erf_inv_p))
|
||||
|
||||
real_p = unop(_complex_basetype, _complex, 'real')
|
||||
ad.deflinear2(real_p, lambda t, _: [complex(t, np.zeros((), _dtype(t)))])
|
||||
|
@ -231,7 +231,7 @@ def main(_):
|
||||
print_ir(np.float32(0))(lax.erfc)
|
||||
|
||||
# CHECK-LABEL: TEST: erf_inv float32[]
|
||||
# CHECK: xla_fallback_erf_inv
|
||||
# CHECK: chlo.erf_inv
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.erf_inv)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user