mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Migrate besseli0e off xla_fallback
PiperOrigin-RevId: 519241252
This commit is contained in:
parent
257ac6a993
commit
ac44d2c2e3
@ -21,16 +21,15 @@ from enum import Enum
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
|
||||
from jax._src.lax.lax import (bitwise_and, bitwise_not, bitwise_or,
|
||||
from jax._src.lax.lax import (add, bitwise_and, bitwise_not, bitwise_or,
|
||||
broadcast_in_dim, broadcast_shapes,
|
||||
convert_element_type, eq, exp, full_like,
|
||||
gt, le, log, log1p, lt, mul, neg, reciprocal,
|
||||
reduce, select, sign, square, standard_naryop,
|
||||
standard_unop, xla, xops, ne, div, sub, add,
|
||||
convert_element_type, div, eq, exp, full_like,
|
||||
gt, le, log, log1p, lt, mul, ne, neg, reciprocal,
|
||||
reduce, select, sign, sqrt, square,
|
||||
standard_naryop, standard_unop, sub, xla, xops,
|
||||
_broadcast_translate, _const, _dtype, _float,
|
||||
_nary_lower_hlo, _ones, _isnan, _reduce)
|
||||
from jax._src.lax.control_flow import while_loop
|
||||
from jax._src.lax.utils import (standard_translate)
|
||||
|
||||
from jax._src import dtypes
|
||||
from jax._src.interpreters import ad
|
||||
@ -399,6 +398,109 @@ def _up_and_broadcast(doit):
|
||||
return result
|
||||
return up_and_broadcast
|
||||
|
||||
|
||||
def evaluate_chebyshev_polynomial(x, coefficients):
  """Evaluates a Chebyshev series at ``x`` with a three-term recurrence.

  Starting from three zero arrays shaped like ``x``, each coefficient ``c``
  (consumed in the order given) advances the recurrence

    b0, b1, b2 = x * b0 - b1 + c, b0, b1

  and the result is ``0.5 * (b0 - b2)`` after the last coefficient.

  Args:
    x: array at which the series is evaluated; every intermediate is built
      with ``full_like(x, ...)``.
    coefficients: iterable of scalar coefficients. An empty iterable leaves
      all three terms at zero, so the result is zeros.

  Returns:
    The value of the recurrence described above, elementwise over ``x``.
  """
  current, previous, before_previous = (full_like(x, 0) for _ in range(3))
  for coefficient in coefficients:
    # Shift the two most recent terms down before computing the new one.
    before_previous, previous = previous, current
    current = x * previous - before_previous + full_like(x, coefficient)
  return 0.5 * (current - before_previous)
|
||||
|
||||
def _i0e_impl32(x):
  """Single-precision ``i0e`` built from two Chebyshev expansions.

  Follows Cephes' F32 implementation of ``i0e`` (the modified Bessel function
  of the first kind, zeroth order, in the form Cephes calls ``i0e``). Only
  ``abs(x)`` is used. Where ``abs(x) <= 8`` the result is a Chebyshev series
  in ``abs(x) / 2 - 2``; elsewhere it is a Chebyshev series in
  ``32 / abs(x) - 2`` divided by ``sqrt(abs(x))``.

  Args:
    x: input array.

  Returns:
    An array shaped like ``x`` holding the selected approximation.
  """
  # Series coefficients used on the interval abs(x) <= 8.
  small_arg_coeffs = np.array(
      [-1.30002500998624804212E-8, 6.04699502254191894932E-8,
       -2.67079385394061173391E-7, 1.11738753912010371815E-6,
       -4.41673835845875056359E-6, 1.64484480707288970893E-5,
       -5.75419501008210370398E-5, 1.88502885095841655729E-4,
       -5.76375574538582365885E-4, 1.63947561694133579842E-3,
       -4.32430999505057594430E-3, 1.05464603945949983183E-2,
       -2.37374148058994688156E-2, 4.93052842396707084878E-2,
       -9.49010970480476444210E-2, 1.71620901522208775349E-1,
       -3.04682672343198398683E-1, 6.76795274409476084995E-1]
  )
  # Series coefficients used on the interval abs(x) > 8.
  large_arg_coeffs = np.array(
      [3.39623202570838634515E-9, 2.26666899049817806459E-8,
       2.04891858946906374183E-7, 2.89137052083475648297E-6,
       6.88975834691682398426E-5, 3.36911647825569408990E-3,
       8.04490411014108831608E-1]
  )

  magnitude = abs(x)
  half = full_like(magnitude, 0.5)
  two = full_like(magnitude, 2.0)
  thirty_two = full_like(magnitude, 32.0)

  # Both branches are computed for every element; `select` then keeps the one
  # that applies, so out-of-range values in the unused branch are discarded.
  small_arg_series = evaluate_chebyshev_polynomial(half * magnitude - two,
                                                   small_arg_coeffs)
  large_arg_series = evaluate_chebyshev_polynomial(thirty_two / magnitude - two,
                                                   large_arg_coeffs)
  large_arg_result = div(large_arg_series, sqrt(magnitude))

  return select(magnitude <= 8.0, small_arg_series, large_arg_result)
|
||||
|
||||
def _i0e_impl64(x):
  """Double-precision counterpart of ``_i0e_impl32``.

  Uses the same two-interval scheme with longer coefficient tables: where
  ``abs(x) <= 8`` the result is a Chebyshev series in ``abs(x) / 2 - 2``;
  elsewhere it is a Chebyshev series in ``32 / abs(x) - 2`` divided by
  ``sqrt(abs(x))``.

  Args:
    x: input array.

  Returns:
    An array shaped like ``x`` holding the selected approximation.
  """
  # Series coefficients used on the interval abs(x) <= 8.
  small_arg_coeffs = np.array(
      [-4.41534164647933937950E-18, 3.33079451882223809783E-17,
       -2.43127984654795469359E-16, 1.71539128555513303061E-15,
       -1.16853328779934516808E-14, 7.67618549860493561688E-14,
       -4.85644678311192946090E-13, 2.95505266312963983461E-12,
       -1.72682629144155570723E-11, 9.67580903537323691224E-11,
       -5.18979560163526290666E-10, 2.65982372468238665035E-9,
       -1.30002500998624804212E-8, 6.04699502254191894932E-8,
       -2.67079385394061173391E-7, 1.11738753912010371815E-6,
       -4.41673835845875056359E-6, 1.64484480707288970893E-5,
       -5.75419501008210370398E-5, 1.88502885095841655729E-4,
       -5.76375574538582365885E-4, 1.63947561694133579842E-3,
       -4.32430999505057594430E-3, 1.05464603945949983183E-2,
       -2.37374148058994688156E-2, 4.93052842396707084878E-2,
       -9.49010970480476444210E-2, 1.71620901522208775349E-1,
       -3.04682672343198398683E-1, 6.76795274409476084995E-1]
  )
  # Series coefficients used on the interval abs(x) > 8.
  large_arg_coeffs = np.array(
      [-7.23318048787475395456E-18, -4.83050448594418207126E-18,
       4.46562142029675999901E-17, 3.46122286769746109310E-17,
       -2.82762398051658348494E-16, -3.42548561967721913462E-16,
       1.77256013305652638360E-15, 3.81168066935262242075E-15,
       -9.55484669882830764870E-15, -4.15056934728722208663E-14,
       1.54008621752140982691E-14, 3.85277838274214270114E-13,
       7.18012445138366623367E-13, -1.79417853150680611778E-12,
       -1.32158118404477131188E-11, -3.14991652796324136454E-11,
       1.18891471078464383424E-11, 4.94060238822496958910E-10,
       3.39623202570838634515E-9, 2.26666899049817806459E-8,
       2.04891858946906374183E-7, 2.89137052083475648297E-6,
       6.88975834691682398426E-5, 3.36911647825569408990E-3,
       8.04490411014108831608E-1]
  )

  magnitude = abs(x)
  half = full_like(magnitude, 0.5)
  two = full_like(magnitude, 2.0)
  thirty_two = full_like(magnitude, 32.0)

  # Both branches are computed for every element; `select` then keeps the one
  # that applies, so out-of-range values in the unused branch are discarded.
  small_arg_series = evaluate_chebyshev_polynomial(half * magnitude - two,
                                                   small_arg_coeffs)
  large_arg_series = evaluate_chebyshev_polynomial(thirty_two / magnitude - two,
                                                   large_arg_coeffs)
  large_arg_result = div(large_arg_series, sqrt(magnitude))

  return select(magnitude <= 8.0, small_arg_series, large_arg_result)
|
||||
|
||||
def bessel_i0e_impl(x):
  """Dispatches ``bessel_i0e`` to the implementation matching ``x.dtype``.

  float64 inputs use ``_i0e_impl64`` and float32 inputs use ``_i0e_impl32``.
  Any other dtype is cast to float32, evaluated with ``_i0e_impl32``, and the
  result is converted back to the original dtype.

  Args:
    x: input array; only its ``dtype`` and ``astype`` are used here.

  Returns:
    The selected implementation's result, with the same dtype as ``x``.
  """
  input_dtype = x.dtype
  if input_dtype == np.float64:
    return _i0e_impl64(x)
  if input_dtype == np.float32:
    return _i0e_impl32(x)

  # Narrower floats such as f16 must be upcast: the Cephes coefficients do not
  # have enough precision to be used at that width.
  widened = x.astype(np.float32)
  return convert_element_type(_i0e_impl32(widened), input_dtype)
|
||||
|
||||
lgamma_p = standard_unop(_float, 'lgamma')
|
||||
ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))
|
||||
mlir.register_lowering(lgamma_p, partial(_nary_lower_hlo, chlo.LgammaOp))
|
||||
@ -430,7 +532,9 @@ mlir.register_lowering(random_gamma_grad_p,
|
||||
multiple_results=False))
|
||||
|
||||
bessel_i0e_p = standard_unop(_float, 'bessel_i0e')
|
||||
xla.register_translation(bessel_i0e_p, standard_translate(bessel_i0e_p))
|
||||
mlir.register_lowering(bessel_i0e_p,
|
||||
mlir.lower_fun(bessel_i0e_impl,
|
||||
multiple_results=False))
|
||||
ad.defjvp2(bessel_i0e_p, lambda g, y, x: g * (bessel_i1e(x) - sign(x) * y))
|
||||
|
||||
bessel_i1e_p = standard_unop(_float, 'bessel_i1e')
|
||||
|
@ -75,11 +75,6 @@ def main(_):
|
||||
# CHECK-SAME: tensor<f64>
|
||||
print_ir(np.float64(1), np.float64(2))(lax.atan2)
|
||||
|
||||
# CHECK-LABEL: TEST: bessel_i0e float32[]
|
||||
# CHECK: xla_fallback_bessel_i0e
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.bessel_i0e)
|
||||
|
||||
# CHECK-LABEL: TEST: bessel_i1e float32[]
|
||||
# CHECK: chlo.bessel_i1e
|
||||
# CHECK-SAME: tensor<f32>
|
||||
|
Loading…
x
Reference in New Issue
Block a user