Lower bessel_i1e primitive to chlo.bessel_i1e.

PiperOrigin-RevId: 467996329
This commit is contained in:
Bixia Zheng 2022-08-16 12:38:09 -07:00 committed by jax authors
parent 78c231e825
commit 0f089e1901
2 changed files with 6 additions and 1 deletions

View File

@ -1887,6 +1887,11 @@ ad.defjvp2(bessel_i0e_p, lambda g, y, x: g * (bessel_i1e(x) - sign(x) * y))
bessel_i1e_p = standard_unop(_float, 'bessel_i1e')
xla.register_translation(bessel_i1e_p, standard_translate(bessel_i1e_p))
if jax._src.lib.mlir_api_version < 32:
xla.register_translation(bessel_i1e_p, standard_translate(bessel_i1e_p))
else:
mlir.register_lowering(bessel_i1e_p,
partial(_nary_lower_mhlo, chlo.BesselI1eOp))
def _bessel_i1e_jvp(g, y, x):
eps = dtypes.finfo(_dtype(x)).eps
x_is_not_tiny = abs(x) > eps

View File

@ -81,7 +81,7 @@ def main(_):
print_ir(np.float32(0))(lax.bessel_i0e)
# CHECK-LABEL: TEST: bessel_i1e float32[]
# CHECK: xla_fallback_bessel_i1e
# CHECK: chlo.bessel_i1e
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.bessel_i1e)