Migrate besseli0e off xla_fallback

PiperOrigin-RevId: 519241252
This commit is contained in:
Anish Tondwalkar 2023-03-24 14:39:05 -07:00 committed by jax authors
parent 257ac6a993
commit ac44d2c2e3
2 changed files with 111 additions and 12 deletions

View File

@ -21,16 +21,15 @@ from enum import Enum
import numpy as np
from functools import partial
from jax._src.lax.lax import (bitwise_and, bitwise_not, bitwise_or,
from jax._src.lax.lax import (add, bitwise_and, bitwise_not, bitwise_or,
broadcast_in_dim, broadcast_shapes,
convert_element_type, eq, exp, full_like,
gt, le, log, log1p, lt, mul, neg, reciprocal,
reduce, select, sign, square, standard_naryop,
standard_unop, xla, xops, ne, div, sub, add,
convert_element_type, div, eq, exp, full_like,
gt, le, log, log1p, lt, mul, ne, neg, reciprocal,
reduce, select, sign, sqrt, square,
standard_naryop, standard_unop, sub, xla, xops,
_broadcast_translate, _const, _dtype, _float,
_nary_lower_hlo, _ones, _isnan, _reduce)
from jax._src.lax.control_flow import while_loop
from jax._src.lax.utils import (standard_translate)
from jax._src import dtypes
from jax._src.interpreters import ad
@ -399,6 +398,109 @@ def _up_and_broadcast(doit):
return result
return up_and_broadcast
def evaluate_chebyshev_polynomial(x, coefficients):
b0 = full_like(x,0)
b1 = full_like(x,0)
b2 = full_like(x,0)
for c in coefficients:
b2 = b1
b1 = b0
b0 = x * b1 - b2 + full_like(x, c)
return 0.5 * (b0 - b2)
def _i0e_impl32(x):
"""
Computes an approximation to the modified Bessel function of the first kind,
zeroth order. The following implementation follows Cephes' F32 and F64
implementation of i0e.
"""
i0e_coeffs_a = np.array(
[-1.30002500998624804212E-8, 6.04699502254191894932E-8,
-2.67079385394061173391E-7, 1.11738753912010371815E-6,
-4.41673835845875056359E-6, 1.64484480707288970893E-5,
-5.75419501008210370398E-5, 1.88502885095841655729E-4,
-5.76375574538582365885E-4, 1.63947561694133579842E-3,
-4.32430999505057594430E-3, 1.05464603945949983183E-2,
-2.37374148058994688156E-2, 4.93052842396707084878E-2,
-9.49010970480476444210E-2, 1.71620901522208775349E-1,
-3.04682672343198398683E-1, 6.76795274409476084995E-1]
)
i0e_coeffs_b = np.array(
[3.39623202570838634515E-9, 2.26666899049817806459E-8,
2.04891858946906374183E-7, 2.89137052083475648297E-6,
6.88975834691682398426E-5, 3.36911647825569408990E-3,
8.04490411014108831608E-1]
)
x = abs(x)
half = full_like(x, 0.5)
two = full_like(x, 2.0)
thirty_two = full_like(x, 32.0)
result_le_8 = evaluate_chebyshev_polynomial(half * x - two, i0e_coeffs_a)
result_gt_8 = div(evaluate_chebyshev_polynomial(thirty_two / x - two,
i0e_coeffs_b), sqrt(x))
return select(x <= 8.0, result_le_8, result_gt_8)
def _i0e_impl64(x):
i0e_coeffs_a = np.array(
[-4.41534164647933937950E-18, 3.33079451882223809783E-17,
-2.43127984654795469359E-16, 1.71539128555513303061E-15,
-1.16853328779934516808E-14, 7.67618549860493561688E-14,
-4.85644678311192946090E-13, 2.95505266312963983461E-12,
-1.72682629144155570723E-11, 9.67580903537323691224E-11,
-5.18979560163526290666E-10, 2.65982372468238665035E-9,
-1.30002500998624804212E-8, 6.04699502254191894932E-8,
-2.67079385394061173391E-7, 1.11738753912010371815E-6,
-4.41673835845875056359E-6, 1.64484480707288970893E-5,
-5.75419501008210370398E-5, 1.88502885095841655729E-4,
-5.76375574538582365885E-4, 1.63947561694133579842E-3,
-4.32430999505057594430E-3, 1.05464603945949983183E-2,
-2.37374148058994688156E-2, 4.93052842396707084878E-2,
-9.49010970480476444210E-2, 1.71620901522208775349E-1,
-3.04682672343198398683E-1, 6.76795274409476084995E-1]
)
i0e_coeffs_b = np.array(
[-7.23318048787475395456E-18, -4.83050448594418207126E-18,
4.46562142029675999901E-17, 3.46122286769746109310E-17,
-2.82762398051658348494E-16, -3.42548561967721913462E-16,
1.77256013305652638360E-15, 3.81168066935262242075E-15,
-9.55484669882830764870E-15, -4.15056934728722208663E-14,
1.54008621752140982691E-14, 3.85277838274214270114E-13,
7.18012445138366623367E-13, -1.79417853150680611778E-12,
-1.32158118404477131188E-11, -3.14991652796324136454E-11,
1.18891471078464383424E-11, 4.94060238822496958910E-10,
3.39623202570838634515E-9, 2.26666899049817806459E-8,
2.04891858946906374183E-7, 2.89137052083475648297E-6,
6.88975834691682398426E-5, 3.36911647825569408990E-3,
8.04490411014108831608E-1]
)
x = abs(x)
half = full_like(x, 0.5)
two = full_like(x, 2.0)
thirty_two = full_like(x, 32.0)
result_le_8 = evaluate_chebyshev_polynomial(half * x - two, i0e_coeffs_a)
result_gt_8 = div(evaluate_chebyshev_polynomial(thirty_two / x - two,
i0e_coeffs_b), sqrt(x))
return select(x <= 8.0, result_le_8, result_gt_8)
def bessel_i0e_impl(x):
if x.dtype == np.float64:
return _i0e_impl64(x)
elif x.dtype == np.float32:
return _i0e_impl32(x)
else:
# Have to upcast f16 because the magic Cephes coefficents don't have enough
# precision for it.
x_dtype = x.dtype
x = x.astype(np.float32)
return convert_element_type(_i0e_impl32(x), x_dtype)
lgamma_p = standard_unop(_float, 'lgamma')
ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))
mlir.register_lowering(lgamma_p, partial(_nary_lower_hlo, chlo.LgammaOp))
@ -430,7 +532,9 @@ mlir.register_lowering(random_gamma_grad_p,
multiple_results=False))
bessel_i0e_p = standard_unop(_float, 'bessel_i0e')
xla.register_translation(bessel_i0e_p, standard_translate(bessel_i0e_p))
mlir.register_lowering(bessel_i0e_p,
mlir.lower_fun(bessel_i0e_impl,
multiple_results=False))
ad.defjvp2(bessel_i0e_p, lambda g, y, x: g * (bessel_i1e(x) - sign(x) * y))
bessel_i1e_p = standard_unop(_float, 'bessel_i1e')

View File

@ -75,11 +75,6 @@ def main(_):
# CHECK-SAME: tensor<f64>
print_ir(np.float64(1), np.float64(2))(lax.atan2)
# CHECK-LABEL: TEST: bessel_i0e float32[]
# CHECK: xla_fallback_bessel_i0e
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.bessel_i0e)
# CHECK-LABEL: TEST: bessel_i1e float32[]
# CHECK: chlo.bessel_i1e
# CHECK-SAME: tensor<f32>