Migrate random_gamma_grad off xla_fallback

PiperOrigin-RevId: 519154537
This commit is contained in:
Anish Tondwalkar 2023-03-24 08:48:55 -07:00 committed by jax authors
parent 8d1d522618
commit 8c75e27f67
2 changed files with 23 additions and 8 deletions

View File

@ -320,7 +320,6 @@ def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode):
else:
raise ValueError(f"Invalid mode: {mode}")
def igammac_impl(a, x, dtype):
out_of_range = bitwise_or(le(x, _const(x, 0)), le(a, _const(a, 0)))
use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a))
@ -359,6 +358,26 @@ def igamma_grad_a_impl(a, x, dtype):
full_like(a, float('nan')), output)
return output
def random_gamma_grad_impl(a, x, dtype):
is_nan = bitwise_or(_isnan(a), _isnan(x))
x_is_zero = eq(x, full_like(x,0))
domain_error = bitwise_or(lt(x, full_like(x,0)), le(a, full_like(a,0)))
use_igammac = bitwise_and(gt(x, full_like(x,1)), gt(x, a))
ax = a * log(x) - x - lgamma(a)
underflow = lt(ax, -log(dtypes.finfo(a.dtype).max))
ax = exp(ax)
enabled = bitwise_not(bitwise_or(bitwise_or(bitwise_or
(x_is_zero, domain_error), underflow), is_nan))
output = select(use_igammac,
-_igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac),
dtype, IgammaMode.SAMPLE_DERIVATIVE),
_igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)),
dtype, IgammaMode.SAMPLE_DERIVATIVE))
output = select(x_is_zero, full_like(output,0), output)
output = select(bitwise_or(domain_error, is_nan),
full_like(a, float('nan')), output)
return output
def _up_and_broadcast(doit):
def up_and_broadcast(a, x):
broadcasted_shape = broadcast_shapes(a.shape, x.shape)
@ -406,8 +425,9 @@ mlir.register_lowering(igammac_p,
ad.defjvp(igammac_p, igammac_grada, igammac_gradx)
random_gamma_grad_p = standard_naryop([_float, _float], 'random_gamma_grad')
xla.register_translation(random_gamma_grad_p,
partial(_broadcast_translate, xops.RandomGammaGrad))
mlir.register_lowering(random_gamma_grad_p,
mlir.lower_fun(_up_and_broadcast(random_gamma_grad_impl),
multiple_results=False))
bessel_i0e_p = standard_unop(_float, 'bessel_i0e')
xla.register_translation(bessel_i0e_p, standard_translate(bessel_i0e_p))

View File

@ -356,11 +356,6 @@ def main(_):
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.pow)
# CHECK-LABEL: TEST: random_gamma_grad float32[] float32[]
# CHECK: xla_fallback_random_gamma_grad
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0), np.float32(0))(lax.random_gamma_grad)
# CHECK-LABEL: TEST: real complex128[]
# CHECK: hlo.real
# CHECK-SAME: tensor<complex<f64>>