Migrate igammac_p off xla_fallback path

It is now decomposed into StableHLO ops.

PiperOrigin-RevId: 519122775
Anish Tondwalkar 2023-03-24 05:57:59 -07:00 committed by jax authors
parent 8081031c90
commit 4a9b09485e
2 changed files with 144 additions and 10 deletions
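
Not part of the commit, but a minimal sketch of how the effect can be observed: inspect the StableHLO lowered for lax.igammac and check that no xla_fallback custom call remains (this assumes the JAX version in use exposes jax.jit(...).lower(...).as_text()).

# Sketch (not from this commit): confirm igammac no longer lowers via xla_fallback.
import jax
import jax.numpy as jnp
from jax import lax

ir_text = jax.jit(lax.igammac).lower(jnp.float32(0.5), jnp.float32(1.5)).as_text()
print("xla_fallback" in ir_text)  # expected: False after this change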


@@ -26,7 +26,7 @@ from jax._src.lax.lax import (bitwise_and, bitwise_not, bitwise_or,
                              convert_element_type, eq, exp, full_like,
                              gt, le, log, log1p, lt, mul, neg, reciprocal,
                              reduce, select, sign, square, standard_naryop,
                              standard_unop, xla, xops,
                              standard_unop, xla, xops, ne, div, sub, add,
                              _broadcast_translate, _const, _dtype, _float,
                              _nary_lower_hlo, _ones, _isnan, _reduce)
from jax._src.lax.control_flow import while_loop
@@ -234,9 +234,147 @@ def igamma_impl(a, x):
  return result

def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode):
  # TODO(atondwal): implement _igammac_continued_fraction in JAX.
  # Right now we fallback to the XLA implementation of IgammacContinuedFraction.
  return igammac(a, x)
  eps = dtypes.finfo(dtype).eps

  def cond_fn(vals):
    enabled, _ans, _t, _y, _x, c, *_ = vals
    return bitwise_and(c < _const(c, 2000), _any(enabled))
  def body_fn(vals):
    (enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2,
     dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da) = vals

    c = c + _const(c, 1)
    y = y + _const(y, 1)
    z = z + _const(z, 2)
    yc = y * c
    pk = pkm1 * z - pkm2 * yc
    qk = qkm1 * z - qkm2 * yc
    qk_is_nonzero = ne(qk, _const(qk, 0))
    r = pk / qk

    t = select(qk_is_nonzero, abs(div(sub(ans, r), r)), full_like(r, 1))
    ans = select(qk_is_nonzero, r, ans)

    dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c
    dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c
    dans_da_new = select(qk_is_nonzero, div(dpk_da - ans * dqk_da, qk), dans_da)
    grad_conditional = select(qk_is_nonzero,
                              abs(dans_da_new - dans_da),
                              full_like(dans_da, 1))

    pkm2 = pkm1
    pkm1 = pk
    qkm2 = qkm1
    qkm1 = qk

    dpkm2_da = dpkm1_da
    dqkm2_da = dqkm1_da
    dpkm1_da = dpk_da
    dqkm1_da = dqk_da

    rescale = gt(abs(pk), reciprocal(_const(pk, eps)))
    pkm2 = select(rescale, mul(pkm2, _const(pkm2, eps)), pkm2)
    pkm1 = select(rescale, mul(pkm1, _const(pkm1, eps)), pkm1)
    qkm2 = select(rescale, mul(qkm2, _const(qkm2, eps)), qkm2)
    qkm1 = select(rescale, mul(qkm1, _const(qkm1, eps)), qkm1)

    dpkm2_da = select(rescale, mul(dpkm2_da, _const(dpkm2_da, eps)), dpkm2_da)
    dqkm2_da = select(rescale, mul(dqkm2_da, _const(dqkm2_da, eps)), dqkm2_da)
    dpkm1_da = select(rescale, mul(dpkm1_da, _const(dpkm1_da, eps)), dpkm1_da)
    dqkm1_da = select(rescale, mul(dqkm1_da, _const(dqkm1_da, eps)), dqkm1_da)

    if mode == IgammaMode.VALUE:
      conditional = bitwise_and(enabled, t > eps)
    else:
      conditional = bitwise_and(enabled,
                                grad_conditional > _const(grad_conditional, eps))

    return (conditional,
            select(enabled, ans, vals[1]),
            select(enabled, t, vals[2]),
            select(enabled, y, vals[3]),
            select(enabled, z, vals[4]),
            c,
            select(enabled, pkm1, vals[6]),
            select(enabled, qkm1, vals[7]),
            select(enabled, pkm2, vals[8]),
            select(enabled, qkm2, vals[9]),
            select(enabled, dpkm2_da, vals[10]),
            select(enabled, dqkm2_da, vals[11]),
            select(enabled, dpkm1_da, vals[12]),
            select(enabled, dqkm1_da, vals[13]),
            select(enabled, dans_da_new, vals[14]))
  y = _const(a, 1) - a
  z = x + y + _const(x, 1)
  c = _const(x, 0)
  pkm2 = full_like(x, 1)
  qkm2 = x
  pkm1 = x + _const(x, 1)
  qkm1 = z * x
  ans = pkm1 / qkm1
  t = full_like(x, 1)
  dpkm2_da = full_like(x, 0)
  dqkm2_da = full_like(x, 0)
  dpkm1_da = full_like(x, 0)
  dqkm1_da = -x
  dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1
  init_vals = (enabled, ans, t, y, z,
               c, pkm1, qkm1, pkm2, qkm2,
               dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da)

  vals = while_loop(cond_fn, body_fn, init_vals)
  ans = vals[1]
  if mode == IgammaMode.VALUE:
    return ans * ax
  dans_da = vals[14]
  dlogax_da = log(x) - digamma(a)

  if mode == IgammaMode.DERIVATIVE:
    return mul(ax, add(mul(ans, dlogax_da), dans_da))
  elif mode == IgammaMode.SAMPLE_DERIVATIVE:
    return neg(add(dans_da, mul(ans, dlogax_da)) * x)
  else:
    raise ValueError(f"Invalid mode: {mode}")
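
For reference (this note is not part of the diff): the loop above evaluates the standard continued-fraction expansion of the regularized upper incomplete gamma function,

Q(a, x) = \frac{\Gamma(a, x)}{\Gamma(a)}
        = \frac{e^{-x} x^{a}}{\Gamma(a)} \cdot
          \cfrac{1}{x + 1 - a - \cfrac{1\,(1 - a)}{x + 3 - a - \cfrac{2\,(2 - a)}{x + 5 - a - \cdots}}}

tracking the convergents p_k / q_k with forward recurrences, rescaling them by eps whenever they grow large enough to risk overflow, and, in the derivative modes, carrying d/da of the convergents alongside so the same loop can also produce dQ(a, x)/da.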

def igammac_impl(a, x):
  broadcasted_shape = broadcast_shapes(a.shape, x.shape)
  a = broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim)))
  x = broadcast_in_dim(x, broadcasted_shape, list(range(x.ndim)))

  def doit(a, x, dtype):
    out_of_range = bitwise_or(le(x, _const(x, 0)), le(a, _const(a, 0)))
    use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a))
    ax = a * log(x) - x - lgamma(a)
    underflow = lt(ax, -log(dtypes.finfo(dtype).max))
    enabled = bitwise_not(bitwise_or(out_of_range, underflow))
    ax = exp(ax)

    igamma_call = _igamma_series(ax, x, a, bitwise_and(enabled, use_igamma),
                                 dtype, IgammaMode.VALUE)
    igammac_cf_call = _igammac_continued_fraction(ax, x, a,
        bitwise_and(enabled, bitwise_not(use_igamma)), dtype, IgammaMode.VALUE)

    result = select(use_igamma, _const(a, 1) - igamma_call, igammac_cf_call)
    x_is_infinity = eq(x, _const(x, float('inf')))
    result = select(x_is_infinity, full_like(result, 0), result)
    return select(out_of_range, full_like(a, 1), result)

  needs_upcast = a.dtype == dtypes.bfloat16 or a.dtype == np.float16
  if needs_upcast:
    a_dtype = a.dtype
    a = convert_element_type(a, np.float32)
    x = convert_element_type(x, np.float32)
    a_x_type = np.float32
  else:
    a_x_type = a.dtype

  result = doit(a, x, a_x_type)
  if needs_upcast:
    result = convert_element_type(result, a_dtype)
  return result
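
As a quick sanity check (not part of the diff), the decomposition can be compared against SciPy, since lax.igammac computes the regularized upper incomplete gamma function Q(a, x):

# Usage sketch (not from this commit): numerical cross-check against SciPy.
import numpy as np
import scipy.special
from jax import lax

a, x = np.float32(3.0), np.float32(2.5)
print(lax.igammac(a, x))               # JAX: Q(a, x) via the decomposition above
print(scipy.special.gammaincc(a, x))   # SciPy reference value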
lgamma_p = standard_unop(_float, 'lgamma')
ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))
@@ -256,7 +394,8 @@ xla.register_translation(igamma_grad_a_p,
ad.defjvp(igamma_p, igamma_grada, igamma_gradx)

igammac_p = standard_naryop([_float, _float], 'igammac')
xla.register_translation(igammac_p, partial(_broadcast_translate, xops.Igammac))
mlir.register_lowering(igammac_p,
                       mlir.lower_fun(igammac_impl, multiple_results=False))
ad.defjvp(igammac_p, igammac_grada, igammac_gradx)
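
Here mlir.register_lowering with mlir.lower_fun(igammac_impl, multiple_results=False) makes the lowering rule trace igammac_impl, which is written entirely in terms of other lax primitives, so the emitted IR consists of StableHLO ops instead of the xla_fallback custom call produced by the removed register_translation rule.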


@@ -262,11 +262,6 @@ def main(_):
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.gt)

  # CHECK-LABEL: TEST: igammac float32[] float32[]
  # CHECK: xla_fallback_igammac
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.igammac)

  # CHECK-LABEL: TEST: igamma_grad_a float32[] float32[]
  # CHECK: xla_fallback_igamma_grad_a
  # CHECK-SAME: tensor<f32>
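
The igammac block above is deleted from the filecheck test because, with the new decomposition, the lowered IR for lax.igammac no longer contains an xla_fallback_igammac call for FileCheck to match; igamma_grad_a still lowers via the fallback at this point, so its CHECK lines remain.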