mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Migrate random_gamma_grad off xla_fallback
PiperOrigin-RevId: 519154537
This commit is contained in:
parent
8d1d522618
commit
8c75e27f67
@ -320,7 +320,6 @@ def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode):
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}")
|
||||
|
||||
|
||||
def igammac_impl(a, x, dtype):
|
||||
out_of_range = bitwise_or(le(x, _const(x, 0)), le(a, _const(a, 0)))
|
||||
use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a))
|
||||
@ -359,6 +358,26 @@ def igamma_grad_a_impl(a, x, dtype):
|
||||
full_like(a, float('nan')), output)
|
||||
return output
|
||||
|
||||
def random_gamma_grad_impl(a, x, dtype):
|
||||
is_nan = bitwise_or(_isnan(a), _isnan(x))
|
||||
x_is_zero = eq(x, full_like(x,0))
|
||||
domain_error = bitwise_or(lt(x, full_like(x,0)), le(a, full_like(a,0)))
|
||||
use_igammac = bitwise_and(gt(x, full_like(x,1)), gt(x, a))
|
||||
ax = a * log(x) - x - lgamma(a)
|
||||
underflow = lt(ax, -log(dtypes.finfo(a.dtype).max))
|
||||
ax = exp(ax)
|
||||
enabled = bitwise_not(bitwise_or(bitwise_or(bitwise_or
|
||||
(x_is_zero, domain_error), underflow), is_nan))
|
||||
output = select(use_igammac,
|
||||
-_igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac),
|
||||
dtype, IgammaMode.SAMPLE_DERIVATIVE),
|
||||
_igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)),
|
||||
dtype, IgammaMode.SAMPLE_DERIVATIVE))
|
||||
output = select(x_is_zero, full_like(output,0), output)
|
||||
output = select(bitwise_or(domain_error, is_nan),
|
||||
full_like(a, float('nan')), output)
|
||||
return output
|
||||
|
||||
def _up_and_broadcast(doit):
|
||||
def up_and_broadcast(a, x):
|
||||
broadcasted_shape = broadcast_shapes(a.shape, x.shape)
|
||||
@ -406,8 +425,9 @@ mlir.register_lowering(igammac_p,
|
||||
ad.defjvp(igammac_p, igammac_grada, igammac_gradx)
|
||||
|
||||
random_gamma_grad_p = standard_naryop([_float, _float], 'random_gamma_grad')
|
||||
xla.register_translation(random_gamma_grad_p,
|
||||
partial(_broadcast_translate, xops.RandomGammaGrad))
|
||||
mlir.register_lowering(random_gamma_grad_p,
|
||||
mlir.lower_fun(_up_and_broadcast(random_gamma_grad_impl),
|
||||
multiple_results=False))
|
||||
|
||||
bessel_i0e_p = standard_unop(_float, 'bessel_i0e')
|
||||
xla.register_translation(bessel_i0e_p, standard_translate(bessel_i0e_p))
|
||||
|
@ -356,11 +356,6 @@ def main(_):
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.pow)
|
||||
|
||||
# CHECK-LABEL: TEST: random_gamma_grad float32[] float32[]
|
||||
# CHECK: xla_fallback_random_gamma_grad
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0), np.float32(0))(lax.random_gamma_grad)
|
||||
|
||||
# CHECK-LABEL: TEST: real complex128[]
|
||||
# CHECK: hlo.real
|
||||
# CHECK-SAME: tensor<complex<f64>>
|
||||
|
Loading…
x
Reference in New Issue
Block a user