mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Migrate regularized_incomplete_beta_p off xla_fallback
PiperOrigin-RevId: 519244597
This commit is contained in:
parent
ac44d2c2e3
commit
6842e98ca1
@ -23,12 +23,12 @@ from functools import partial
|
||||
|
||||
from jax._src.lax.lax import (add, bitwise_and, bitwise_not, bitwise_or,
|
||||
broadcast_in_dim, broadcast_shapes,
|
||||
convert_element_type, div, eq, exp, full_like,
|
||||
convert_element_type, div, eq, exp, full_like, ge,
|
||||
gt, le, log, log1p, lt, mul, ne, neg, reciprocal,
|
||||
reduce, select, sign, sqrt, square,
|
||||
standard_naryop, standard_unop, sub, xla, xops,
|
||||
_broadcast_translate, _const, _dtype, _float,
|
||||
_nary_lower_hlo, _ones, _isnan, _reduce)
|
||||
standard_naryop, standard_unop, sub,
|
||||
_const, _dtype,
|
||||
_float, _nary_lower_hlo, _ones, _isnan, _reduce)
|
||||
from jax._src.lax.control_flow import while_loop
|
||||
|
||||
from jax._src import dtypes
|
||||
@ -90,13 +90,6 @@ def erf_inv(x: ArrayLike) -> Array:
|
||||
r"""Elementwise inverse error function: :math:`\mathrm{erf}^{-1}(x)`."""
|
||||
return erf_inv_p.bind(x)
|
||||
|
||||
|
||||
regularized_incomplete_beta_p = standard_naryop(
|
||||
[_float, _float, _float], 'regularized_incomplete_beta')
|
||||
xla.register_translation(
|
||||
regularized_incomplete_beta_p,
|
||||
partial(_broadcast_translate, xops.RegularizedIncompleteBeta))
|
||||
|
||||
def betainc_gradx(g, a, b, x):
|
||||
lbeta = lgamma(a) + lgamma(b) - lgamma(a + b)
|
||||
partial_x = exp((b - 1) * log1p(-x) +
|
||||
@ -106,11 +99,6 @@ def betainc_gradx(g, a, b, x):
|
||||
def betainc_grad_not_implemented(g, a, b, x):
  """Differentiation rule used for the `a` and `b` operands of betainc.

  This is registered with `ad.defjvp` for the first two operands of
  `regularized_incomplete_beta_p`; it never returns.

  Args:
    g: incoming tangent (unused).
    a: first operand of the primitive (unused).
    b: second operand of the primitive (unused).
    x: third operand of the primitive (unused).

  Raises:
    ValueError: always.
  """
  raise ValueError("Betainc gradient with respect to a and b not supported.")
|
||||
|
||||
ad.defjvp(regularized_incomplete_beta_p,
|
||||
betainc_grad_not_implemented,
|
||||
betainc_grad_not_implemented,
|
||||
betainc_gradx)
|
||||
|
||||
def igamma_gradx(g, a, x):
  """Scale the tangent `g` by `x**(a - 1) * exp(-x) / Gamma(a)`.

  The factor is evaluated in log space, as
  `exp(-x + (a - 1) * log(x) - lgamma(a))`, rather than by forming the power
  and the gamma function directly.

  Args:
    g: incoming tangent.
    a: first operand of igamma.
    x: second operand of igamma.

  Returns:
    `g` multiplied by the factor above.
  """
  log_factor = -x + (a - _ones(a)) * log(x) - lgamma(a)
  return g * exp(log_factor)
|
||||
|
||||
@ -126,6 +114,117 @@ def igammac_grada(g, a, x):
|
||||
# The below is directly ported from tensorflow/compiler/xla/client/lib/math.cc
|
||||
# We try to follow the corresponding functions as closely as possible, so that
|
||||
# we can quickly incorporate changes.
|
||||
def lentz_thompson_barnett_algorithm(*, num_iterations, small, threshold,
                                     nth_partial_numerator,
                                     nth_partial_denominator, inputs):
  """Evaluate a continued fraction with the Lentz-Thompson-Barnett iteration.

  With partial numerators `a_n` and partial denominators `b_n`, the loop
  starts from `h = b_0` and repeatedly applies

    c_n = b_n + a_n / c_{n-1}
    d_n = 1 / (b_n + a_n * d_{n-1})
    h  *= c_n * d_n

  which approximates `b_0 + a_1 / (b_1 + a_2 / (b_2 + ...))`.

  Args:
    num_iterations: the loop keeps running only while the iteration counter
      (which starts at 1) is strictly less than this value.
    small: replacement magnitude; `b_0`, `c_n` and the denominator of `d_n`
      are set to `small` whenever their absolute value is below it, which
      keeps the divisions away from zero.
    threshold: convergence tolerance; iteration continues while
      `abs(c_n * d_n - 1) >= threshold` for any element.
    nth_partial_numerator: callable `(iteration, *inputs) -> a_n`.
    nth_partial_denominator: callable `(iteration, *inputs) -> b_n`. It is
      also called once with the Python integer 0 to obtain `b_0`.
    inputs: sequence of extra arguments forwarded to both callables.

  Returns:
    The final approximant `h`, with the shape of `b_0`.
  """
  # Position in the evaluation.
  kIterationIdx = 0
  # Whether or not we have reached the desired tolerance.
  kValuesUnconvergedIdx = 1
  # Ratio between nth canonical numerator and the nth-1 canonical numerator.
  kCIdx = 2
  # Ratio between nth-1 canonical denominator and the nth canonical denominator.
  kDIdx = 3
  # Computed approximant in the evaluation.
  kHIdx = 4

  def while_cond_fn(values):
    iteration = values[kIterationIdx]
    iterations_remain_cond = lt(iteration, num_iterations)
    values_unconverged_cond = values[kValuesUnconvergedIdx]
    return bitwise_and(iterations_remain_cond, values_unconverged_cond)

  def while_body_fn(values):
    iteration = values[kIterationIdx]
    partial_numerator = nth_partial_numerator(iteration, *inputs)
    partial_denominator = nth_partial_denominator(iteration, *inputs)

    c = add(partial_denominator, div(partial_numerator, values[kCIdx]))
    small_constant = full_like(c, small)
    # Clamp near-zero values so the next division stays finite.
    c = select(lt(abs(c), small_constant), small_constant, c)
    d = add(partial_denominator, mul(partial_numerator, values[kDIdx]))
    d = select(lt(abs(d), small_constant), small_constant, d)
    d = reciprocal(d)
    delta = mul(c, d)
    h = mul(values[kHIdx], delta)

    # Update values
    values[kIterationIdx] = iteration + 1
    values[kCIdx] = c
    values[kDIdx] = d
    values[kHIdx] = h
    # If any values are greater than the tolerance, we have not converged.
    tolerance_comparison = ge(abs(sub(delta, _const(delta, 1.0))), threshold)
    values[kValuesUnconvergedIdx] = _any(tolerance_comparison)
    return values

  partial_denominator = nth_partial_denominator(0, *inputs)
  h = select(lt(abs(partial_denominator), small),
             broadcast_in_dim(small, partial_denominator.shape, ()),
             partial_denominator)
  # Loop carry, in index order: iteration, unconverged flag, C, D, H.
  values = [1, True, h, full_like(h, 0), h]
  values = while_loop(while_cond_fn, while_body_fn, values)
  return values[kHIdx]
|
||||
|
||||
|
||||
def regularized_incomplete_beta_impl(a, b, x, dtype):
  """Compute the regularized incomplete beta function elementwise.

  The value is obtained from the continued fraction of
  http://dlmf.nist.gov/8.17.E22, evaluated with
  `lentz_thompson_barnett_algorithm`. Where the fraction would converge
  slowly, the symmetry relation http://dlmf.nist.gov/8.17.E4 is applied
  first: `a` and `b` are swapped, `x` is replaced by `1 - x`, and the final
  result is replaced by one minus itself.

  Args:
    a: first shape parameter. `a.shape` is used to broadcast the scalar
      iteration counter, so `a`, `b` and `x` are expected to already share
      one shape.
    b: second shape parameter.
    x: upper integration limit.
    dtype: floating-point dtype of the computation. It selects the iteration
      budget (200 for float32, 600 otherwise) and the tolerance
      (`finfo(dtype).eps / 2`).

  Returns:
    An array of the same shape; entries are NaN wherever `a <= 0`, `b <= 0`,
    `x < 0` or `x > 1`.
  """
  shape = a.shape

  def nth_partial_betainc_numerator(iteration, a, b, x):
    """Return the partial numerator for the given iteration.

    The partial numerator for the incomplete beta function is given
    here: http://dlmf.nist.gov/8.17.E23 Note that there is a special
    case: the partial numerator for the first iteration is one.
    """
    iteration_bcast = broadcast_in_dim(iteration, shape, [])
    iteration_is_even = eq(iteration_bcast % full_like(iteration_bcast, 2),
                           full_like(iteration_bcast, 0))
    iteration_is_one = eq(iteration_bcast, full_like(iteration_bcast, 1))
    iteration_minus_one = iteration_bcast - full_like(iteration_bcast, 1)
    m = iteration_minus_one // full_like(iteration_minus_one, 2)
    m = convert_element_type(m, dtype)
    one = full_like(a, 1)
    two = full_like(a, 2.0)
    # Partial numerator terms
    even_numerator = -(a + m) * (a + b + m) * x / (
        (a + two * m) * (a + two * m + one))
    odd_numerator = m * (b - m) * x / ((a + two * m - one) * (a + two * m))
    one_numerator = full_like(x, 1.0)
    numerator = select(iteration_is_even, even_numerator, odd_numerator)
    return select(iteration_is_one, one_numerator, numerator)

  def nth_partial_betainc_denominator(iteration, a, b, x):
    """Return zero for iteration 0 and one for every later iteration."""
    iteration_bcast = broadcast_in_dim(iteration, shape, [])
    return select(eq(iteration_bcast, full_like(iteration_bcast, 0)),
                  full_like(x, 0), full_like(x, 1))

  # Computed from the caller's a, b and x, before the symmetry swap below.
  result_is_nan = bitwise_or(bitwise_or(bitwise_or(
      le(a, full_like(a, 0)), le(b, full_like(b, 0))),
      lt(x, full_like(x, 0))), gt(x, full_like(x, 1)))

  # The continued fraction will converge rapidly when x < (a+1)/(a+b+2)
  # as per: http://dlmf.nist.gov/8.17.E23
  #
  # Otherwise, we can rewrite using the symmetry relation as per:
  # http://dlmf.nist.gov/8.17.E4
  converges_rapidly = lt(x, (a + full_like(a, 1)) / (a + b + full_like(b, 2.0)))
  a_orig = a
  a = select(converges_rapidly, a, b)
  b = select(converges_rapidly, b, a_orig)
  x = select(converges_rapidly, x, sub(full_like(x, 1), x))

  continued_fraction = lentz_thompson_barnett_algorithm(
      num_iterations=200 if dtype == np.float32 else 600,
      small=(dtypes.finfo(dtype).eps / 2).astype(dtype),
      threshold=(dtypes.finfo(dtype).eps / 2).astype(dtype),
      nth_partial_numerator=nth_partial_betainc_numerator,
      nth_partial_denominator=nth_partial_betainc_denominator,
      inputs=[a, b, x]
  )

  # The prefactor x**a * (1 - x)**b / (a * B(a, b)) is formed in log space.
  lbeta_ab = lgamma(a) + lgamma(b) - lgamma(a + b)
  result = continued_fraction * exp(log(x) * a + log1p(-x) * b - lbeta_ab) / a
  result = select(result_is_nan, full_like(a, float('nan')), result)
  # Undo the symmetry rewrite where it was applied; NaN stays NaN.
  return select(converges_rapidly, result, sub(full_like(result, 1), result))
|
||||
|
||||
class IgammaMode(Enum):
|
||||
VALUE = 1
|
||||
DERIVATIVE = 2
|
||||
@ -378,21 +477,18 @@ def random_gamma_grad_impl(a, x, dtype):
|
||||
return output
|
||||
|
||||
def _up_and_broadcast(doit):
|
||||
def up_and_broadcast(a, x):
|
||||
broadcasted_shape = broadcast_shapes(a.shape, x.shape)
|
||||
a = broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim)))
|
||||
x = broadcast_in_dim(x, broadcasted_shape, list(range(x.ndim)))
|
||||
def up_and_broadcast(*args):
|
||||
broadcasted_shape = broadcast_shapes(*(a.shape for a in args))
|
||||
args = [broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) for a in args]
|
||||
|
||||
needs_upcast = a.dtype == dtypes.bfloat16 or a.dtype == np.float16
|
||||
a_dtype = args[0].dtype
|
||||
needs_upcast = a_dtype == dtypes.bfloat16 or a_dtype == np.float16
|
||||
if needs_upcast:
|
||||
a_dtype = a.dtype
|
||||
a = convert_element_type(a, np.float32)
|
||||
x = convert_element_type(x, np.float32)
|
||||
args = [convert_element_type(a, np.float32) for a in args]
|
||||
a_x_type = np.float32
|
||||
else:
|
||||
a_x_type = a.dtype
|
||||
|
||||
result = doit(a, x, a_x_type)
|
||||
a_x_type = a_dtype
|
||||
result = doit(*args, a_x_type)
|
||||
if needs_upcast:
|
||||
result = convert_element_type(result, a_dtype)
|
||||
return result
|
||||
@ -501,6 +597,18 @@ def bessel_i0e_impl(x):
|
||||
x = x.astype(np.float32)
|
||||
return convert_element_type(_i0e_impl32(x), x_dtype)
|
||||
|
||||
|
||||
# Primitive taking three operands, each restricted to the `_float` dtype set.
regularized_incomplete_beta_p = standard_naryop(
    [_float, _float, _float], 'regularized_incomplete_beta')
# Lower by tracing regularized_incomplete_beta_impl (wrapped by
# _up_and_broadcast) instead of going through a dedicated XLA op.
mlir.register_lowering(
    regularized_incomplete_beta_p,
    mlir.lower_fun(_up_and_broadcast(regularized_incomplete_beta_impl),
                   multiple_results=False))

# Only the derivative with respect to the third operand (x) is provided; the
# rules for the first two operands raise ValueError.
ad.defjvp(regularized_incomplete_beta_p,
          betainc_grad_not_implemented,
          betainc_grad_not_implemented,
          betainc_gradx)
|
||||
|
||||
# lgamma primitive: one operand restricted to the `_float` dtype set.
lgamma_p = standard_unop(_float, 'lgamma')
# JVP: the incoming tangent is multiplied by digamma(x).
ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))
# Lowered directly to the CHLO LgammaOp.
mlir.register_lowering(lgamma_p, partial(_nary_lower_hlo, chlo.LgammaOp))
|
||||
|
@ -80,11 +80,6 @@ def main(_):
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.bessel_i1e)
|
||||
|
||||
# CHECK-LABEL: TEST: betainc float32[] float32[] float32[]
|
||||
# CHECK: xla_fallback_regularized_incomplete_beta
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0), np.float32(0), np.float32(0))(lax.betainc)
|
||||
|
||||
# CHECK-LABEL: TEST: bitcast_convert_type uint32[7]
|
||||
# CHECK: hlo.bitcast_convert
|
||||
# CHECK-SAME: tensor<7xui32>
|
||||
|
Loading…
x
Reference in New Issue
Block a user