mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Lower bessel_i1e primitive to chlo.bessel_i1e.
PiperOrigin-RevId: 467996329
This commit is contained in:
parent
78c231e825
commit
0f089e1901
@ -1887,6 +1887,11 @@ ad.defjvp2(bessel_i0e_p, lambda g, y, x: g * (bessel_i1e(x) - sign(x) * y))
|
||||
|
||||
bessel_i1e_p = standard_unop(_float, 'bessel_i1e')
|
||||
xla.register_translation(bessel_i1e_p, standard_translate(bessel_i1e_p))
|
||||
if jax._src.lib.mlir_api_version < 32:
|
||||
xla.register_translation(bessel_i1e_p, standard_translate(bessel_i1e_p))
|
||||
else:
|
||||
mlir.register_lowering(bessel_i1e_p,
|
||||
partial(_nary_lower_mhlo, chlo.BesselI1eOp))
|
||||
def _bessel_i1e_jvp(g, y, x):
|
||||
eps = dtypes.finfo(_dtype(x)).eps
|
||||
x_is_not_tiny = abs(x) > eps
|
||||
|
@ -81,7 +81,7 @@ def main(_):
|
||||
print_ir(np.float32(0))(lax.bessel_i0e)
|
||||
|
||||
# CHECK-LABEL: TEST: bessel_i1e float32[]
|
||||
# CHECK: xla_fallback_bessel_i1e
|
||||
# CHECK: chlo.bessel_i1e
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.bessel_i1e)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user