Migrate regularized_incomplete_beta_p off xla_fallback

PiperOrigin-RevId: 519244597
This commit is contained in:
Anish Tondwalkar 2023-03-24 14:52:45 -07:00 committed by jax authors
parent ac44d2c2e3
commit 6842e98ca1
2 changed files with 135 additions and 32 deletions

View File

@ -23,12 +23,12 @@ from functools import partial
from jax._src.lax.lax import (add, bitwise_and, bitwise_not, bitwise_or,
broadcast_in_dim, broadcast_shapes,
convert_element_type, div, eq, exp, full_like,
convert_element_type, div, eq, exp, full_like, ge,
gt, le, log, log1p, lt, mul, ne, neg, reciprocal,
reduce, select, sign, sqrt, square,
standard_naryop, standard_unop, sub, xla, xops,
_broadcast_translate, _const, _dtype, _float,
_nary_lower_hlo, _ones, _isnan, _reduce)
standard_naryop, standard_unop, sub,
_const, _dtype,
_float, _nary_lower_hlo, _ones, _isnan, _reduce)
from jax._src.lax.control_flow import while_loop
from jax._src import dtypes
@ -90,13 +90,6 @@ def erf_inv(x: ArrayLike) -> Array:
r"""Elementwise inverse error function: :math:`\mathrm{erf}^{-1}(x)`."""
return erf_inv_p.bind(x)
regularized_incomplete_beta_p = standard_naryop(
[_float, _float, _float], 'regularized_incomplete_beta')
xla.register_translation(
regularized_incomplete_beta_p,
partial(_broadcast_translate, xops.RegularizedIncompleteBeta))
def betainc_gradx(g, a, b, x):
lbeta = lgamma(a) + lgamma(b) - lgamma(a + b)
partial_x = exp((b - 1) * log1p(-x) +
@ -106,11 +99,6 @@ def betainc_gradx(g, a, b, x):
def betainc_grad_not_implemented(g, a, b, x):
raise ValueError("Betainc gradient with respect to a and b not supported.")
ad.defjvp(regularized_incomplete_beta_p,
betainc_grad_not_implemented,
betainc_grad_not_implemented,
betainc_gradx)
def igamma_gradx(g, a, x):
return g * exp(-x + (a - _ones(a)) * log(x) - lgamma(a))
@ -126,6 +114,117 @@ def igammac_grada(g, a, x):
# The below is directly ported from tensorflow/compiler/xla/client/lib/math.cc
# We try to follow the corresponding functions as closely as possible, so that
# we can quickly incorporate changes.
def lentz_thompson_barnett_algorithm(*,num_iterations, small, threshold, nth_partial_numerator, nth_partial_denominator, inputs):
# Position in the evaluation.
kIterationIdx = 0
# Whether or not we have reached the desired tolerance.
kValuesUnconvergedIdx = 1
# Ratio between nth canonical numerator and the nth-1 canonical numerator.
kCIdx = 2
# Ratio between nth-1 canonical denominator and the nth canonical denominator.
kDIdx = 3
# Computed approximant in the evaluation.
kHIdx = 4
def while_cond_fn(values):
iteration = values[kIterationIdx]
iterations_remain_cond = lt(iteration, num_iterations)
values_unconverged_cond = values[kValuesUnconvergedIdx]
return bitwise_and(iterations_remain_cond, values_unconverged_cond)
def while_body_fn(values):
iteration = values[kIterationIdx]
partial_numerator = nth_partial_numerator(iteration, *inputs)
partial_denominator = nth_partial_denominator(iteration, *inputs)
c = add(partial_denominator, div(partial_numerator, values[kCIdx]))
small_constant = full_like(c, small)
c = select(lt(abs(c), small_constant), small_constant, c)
d = add(partial_denominator, mul(partial_numerator, values[kDIdx]))
d = select(lt(abs(d), small_constant), small_constant, d)
d = reciprocal(d)
delta = mul(c, d)
h = mul(values[kHIdx], delta)
# Update values
values[kIterationIdx] = iteration + 1
values[kCIdx] = c
values[kDIdx] = d
values[kHIdx] = h
# If any values are greater than the tolerance, we have not converged.
tolerance_comparison = ge(abs(sub(delta, _const(delta, 1.0))), threshold)
values[kValuesUnconvergedIdx] = _any(tolerance_comparison)
return values
partial_denominator = nth_partial_denominator(0, *inputs)
h = select(lt(abs(partial_denominator), small),
broadcast_in_dim(small, partial_denominator.shape, ()),
partial_denominator)
values = [1,True,h,full_like(h,0),h]
values = while_loop(while_cond_fn, while_body_fn, values)
return values[kHIdx]
def regularized_incomplete_beta_impl(a, b, x, dtype):
shape = a.shape
def nth_partial_betainc_numerator(iteration, a, b, x):
"""
The partial numerator for the incomplete beta function is given
here: http://dlmf.nist.gov/8.17.E23 Note that there is a special
case: the partial numerator for the first iteration is one.
"""
iteration_bcast = broadcast_in_dim(iteration, shape, [])
iteration_is_even = eq(iteration_bcast % full_like(iteration_bcast, 2),
full_like(iteration_bcast, 0))
iteration_is_one = eq(iteration_bcast, full_like(iteration_bcast, 1))
iteration_minus_one = iteration_bcast - full_like(iteration_bcast, 1)
m = iteration_minus_one // full_like(iteration_minus_one, 2)
m = convert_element_type(m, dtype)
one = full_like(a, 1)
two = full_like(a, 2.0)
# Partial numerator terms
even_numerator = -(a + m) * (a + b + m) * x / (
(a + two * m) * (a + two * m + one))
odd_numerator = m * (b - m) * x / ((a + two * m - one) * (a + two * m))
one_numerator = full_like(x, 1.0)
numerator = select(iteration_is_even, even_numerator, odd_numerator)
return select(iteration_is_one, one_numerator, numerator)
def nth_partial_betainc_denominator(iteration, a, b, x):
iteration_bcast = broadcast_in_dim(iteration, shape, [])
return select(eq(iteration_bcast, full_like(iteration_bcast, 0)),
full_like(x, 0), full_like(x, 1))
result_is_nan = bitwise_or(bitwise_or(bitwise_or(
le(a, full_like(a, 0)), le(b, full_like(b, 0))),
lt(x, full_like(x, 0))), gt(x, full_like(x, 1)))
# The continued fraction will converge rapidly when x < (a+1)/(a+b+2)
# as per: http://dlmf.nist.gov/8.17.E23
#
# Otherwise, we can rewrite using the symmetry relation as per:
# http://dlmf.nist.gov/8.17.E4
converges_rapidly = lt(x, (a + full_like(a, 1)) / (a + b + full_like(b, 2.0)))
a_orig = a
a = select(converges_rapidly, a, b)
b = select(converges_rapidly, b, a_orig)
x = select(converges_rapidly, x, sub(full_like(x, 1), x))
continued_fraction = lentz_thompson_barnett_algorithm(
num_iterations=200 if dtype == np.float32 else 600,
small=(dtypes.finfo(dtype).eps / 2).astype(dtype),
threshold=(dtypes.finfo(dtype).eps / 2).astype(dtype),
nth_partial_numerator=nth_partial_betainc_numerator,
nth_partial_denominator=nth_partial_betainc_denominator,
inputs=[a, b, x]
)
lbeta_ab = lgamma(a) + lgamma(b) - lgamma(a + b)
result = continued_fraction * exp(log(x) * a + log1p(-x) * b - lbeta_ab) / a
result = select(result_is_nan, full_like(a, float('nan')), result)
return select(converges_rapidly, result, sub(full_like(result, 1), result))
class IgammaMode(Enum):
VALUE = 1
DERIVATIVE = 2
@ -378,21 +477,18 @@ def random_gamma_grad_impl(a, x, dtype):
return output
def _up_and_broadcast(doit):
def up_and_broadcast(a, x):
broadcasted_shape = broadcast_shapes(a.shape, x.shape)
a = broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim)))
x = broadcast_in_dim(x, broadcasted_shape, list(range(x.ndim)))
def up_and_broadcast(*args):
broadcasted_shape = broadcast_shapes(*(a.shape for a in args))
args = [broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) for a in args]
needs_upcast = a.dtype == dtypes.bfloat16 or a.dtype == np.float16
a_dtype = args[0].dtype
needs_upcast = a_dtype == dtypes.bfloat16 or a_dtype == np.float16
if needs_upcast:
a_dtype = a.dtype
a = convert_element_type(a, np.float32)
x = convert_element_type(x, np.float32)
args = [convert_element_type(a, np.float32) for a in args]
a_x_type = np.float32
else:
a_x_type = a.dtype
result = doit(a, x, a_x_type)
a_x_type = a_dtype
result = doit(*args, a_x_type)
if needs_upcast:
result = convert_element_type(result, a_dtype)
return result
@ -501,6 +597,18 @@ def bessel_i0e_impl(x):
x = x.astype(np.float32)
return convert_element_type(_i0e_impl32(x), x_dtype)
regularized_incomplete_beta_p = standard_naryop(
[_float, _float, _float], 'regularized_incomplete_beta')
mlir.register_lowering(regularized_incomplete_beta_p,
mlir.lower_fun(_up_and_broadcast(regularized_incomplete_beta_impl),
multiple_results=False))
ad.defjvp(regularized_incomplete_beta_p,
betainc_grad_not_implemented,
betainc_grad_not_implemented,
betainc_gradx)
lgamma_p = standard_unop(_float, 'lgamma')
ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))
mlir.register_lowering(lgamma_p, partial(_nary_lower_hlo, chlo.LgammaOp))

View File

@ -80,11 +80,6 @@ def main(_):
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.bessel_i1e)
# CHECK-LABEL: TEST: betainc float32[] float32[] float32[]
# CHECK: xla_fallback_regularized_incomplete_beta
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0), np.float32(0), np.float32(0))(lax.betainc)
# CHECK-LABEL: TEST: bitcast_convert_type uint32[7]
# CHECK: hlo.bitcast_convert
# CHECK-SAME: tensor<7xui32>