mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[MHLO] Add MHLO lowering for erf and erfc
erf implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=319-336;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e PiperOrigin-RevId: 443665435
This commit is contained in:
parent
5f16873aad
commit
0ed29b63f0
@ -1806,12 +1806,18 @@ ad.defjvp2(bessel_i1e_p, _bessel_i1e_jvp)
|
||||
erf_p = standard_unop(_float, 'erf')
|
||||
ad.defjvp(erf_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)),
|
||||
mul(g, exp(neg(square(x))))))
|
||||
xla.register_translation(erf_p, standard_translate(erf_p))
|
||||
if jax._src.lib.mlir_api_version >= 12:
|
||||
mlir.register_lowering(erf_p, partial(_nary_lower_mhlo, chlo.ErfOp))
|
||||
else:
|
||||
xla.register_translation(erf_p, standard_translate(erf_p))
|
||||
|
||||
erfc_p = standard_unop(_float, 'erfc')
|
||||
ad.defjvp(erfc_p, lambda g, x: mul(_const(x, -2. / np.sqrt(np.pi)),
|
||||
mul(g, exp(neg(square(x))))))
|
||||
xla.register_translation(erfc_p, standard_translate(erfc_p))
|
||||
if jax._src.lib.mlir_api_version >= 12:
|
||||
mlir.register_lowering(erfc_p, partial(_nary_lower_mhlo, chlo.ErfcOp))
|
||||
else:
|
||||
xla.register_translation(erfc_p, standard_translate(erfc_p))
|
||||
|
||||
erf_inv_p = standard_unop(_float, 'erf_inv')
|
||||
ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.),
|
||||
|
@ -225,12 +225,12 @@ def main(_):
|
||||
print_ir(np.uint16(1), np.uint16(2))(lax.eq)
|
||||
|
||||
# CHECK-LABEL: TEST: erf float32[]
|
||||
# CHECK: xla_fallback_erf
|
||||
# CHECK: chlo.erf
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.erf)
|
||||
|
||||
# CHECK-LABEL: TEST: erfc float32[]
|
||||
# CHECK: xla_fallback_erfc
|
||||
# CHECK: chlo.erfc
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.erfc)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user