Migrate igamma_p off xla_fallback

We decompose it into a series or a call to igammac.

PiperOrigin-RevId: 518993077
Anish Tondwalkar 2023-03-23 16:26:26 -07:00 committed by jax authors
parent d777cf229e
commit f981243af5
2 changed files with 126 additions and 10 deletions
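
For context on the decomposition (an editorial sketch, not text from the commit): igamma(a, x) is the regularized lower incomplete gamma function P(a, x), and the new igamma_impl below evaluates it with the standard identities, choosing the branch the same way use_igammac is computed:

  P(a, x) = \frac{x^{a} e^{-x}}{\Gamma(a+1)} \sum_{k=0}^{\infty} \frac{x^{k}}{(a+1)(a+2)\cdots(a+k)}   (series branch, used when x <= 1 or x <= a)

  P(a, x) = 1 - Q(a, x), \quad Q(a, x) = \mathrm{igammac}(a, x)   (continued-fraction branch, used when x > 1 and x > a)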


@@ -17,13 +17,19 @@
LAX decompositions for special functions into their StableHLO counterparts.
"""
from enum import Enum
import numpy as np
from functools import partial
from jax._src.lax.lax import (exp, full_like, log, log1p, mul, neg, np, reciprocal,
                              select, sign, square, standard_naryop, standard_unop,
                              xla, xops,
                              _broadcast_translate, _const, _dtype, _float,
                              _nary_lower_hlo, _ones)
from jax._src.lax.lax import (bitwise_and, bitwise_not, bitwise_or,
                              broadcast_in_dim, broadcast_shapes,
                              convert_element_type, eq, exp, full_like,
                              gt, le, log, log1p, lt, mul, neg, reciprocal,
                              reduce, select, sign, square, standard_naryop,
                              standard_unop, xla, xops,
                              _broadcast_translate, _const, _dtype, _float,
                              _nary_lower_hlo, _ones, _isnan, _reduce)
from jax._src.lax.control_flow import while_loop
from jax._src.lax.utils import (standard_translate)
from jax._src import dtypes
@@ -118,6 +124,120 @@ def igammac_gradx(g, a, x):

def igammac_grada(g, a, x):
  return -igamma_grada(g, a, x)


# The below is directly ported from tensorflow/compiler/xla/client/lib/math.cc
# We try to follow the corresponding functions as closely as possible, so that
# we can quickly incorporate changes.
class IgammaMode(Enum):
  VALUE = 1
  DERIVATIVE = 2
  SAMPLE_DERIVATIVE = 3


def _any(predicates: Array) -> Array:
  f = _const(predicates, False)
  predicates_shape = predicates.shape
  all_dimensions = tuple(range(len(predicates_shape)))
  return reduce(predicates, f, bitwise_or, all_dimensions)

def _igamma_series(ax, x, a, enabled, dtype, mode):
  def cond_fn(vals):
    return _any(vals[0])

  def body_fn(vals):
    enabled, r, c, ans, x, dc_da, dans_da = vals

    r = r + _const(r, 1.)
    dc_da = dc_da * (x / r) - (c * x) / (r * r)
    dans_da = dans_da + dc_da
    c = c * (x / r)
    ans = ans + c

    if mode == IgammaMode.VALUE:
      conditional = bitwise_and(enabled, c / ans > dtypes.finfo(dtype).eps)
    else:
      conditional = bitwise_and(enabled,
                                abs(dc_da / dans_da) > dtypes.finfo(dtype).eps)

    # TODO: Make this a vmap. Might be tricky with the imports.
    return (
        conditional,
        select(enabled, r, vals[1]),
        select(enabled, c, vals[2]),
        select(enabled, ans, vals[3]),
        select(enabled, x, vals[4]),
        select(enabled, dc_da, vals[5]),
        select(enabled, dans_da, vals[6]),
    )

  init_vals = (
      enabled, a, full_like(a, 1), full_like(a, 1), x, full_like(a, 0),
      full_like(a, 0),
  )

  vals = while_loop(cond_fn, body_fn, init_vals)
  ans = vals[3]
  dans_da = vals[6]

  if mode == IgammaMode.VALUE:
    return (ans * ax) / a

  dlogax_da = log(x) - digamma(a + _const(a, 1))

  if mode == IgammaMode.DERIVATIVE:
    return ax * (ans * dlogax_da + dans_da) / a
  elif mode == IgammaMode.SAMPLE_DERIVATIVE:
    return -(dans_da + ans * dlogax_da) * x / a
  else:
    raise ValueError("Invalid IgammaMode")

def igamma_impl(a, x):
  broadcasted_shape = broadcast_shapes(a.shape, x.shape)
  a = broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim)))
  x = broadcast_in_dim(x, broadcasted_shape, list(range(x.ndim)))

  def doit(a, x, dtype):
    is_nan = bitwise_or(_isnan(a), _isnan(x))
    x_is_zero = eq(x, _const(x, 0))
    x_is_infinity = eq(x, _const(x, float('inf')))
    domain_error = bitwise_or(lt(x, _const(x, 0)), le(a, _const(a, 0)))
    use_igammac = bitwise_and(gt(x, _const(x, 1)), gt(x, a))
    ax = a * log(x) - x - lgamma(a)
    underflow = lt(ax, -log(dtypes.finfo(dtype).max))
    ax = exp(ax)
    enabled = bitwise_not(
        _reduce(bitwise_or, [x_is_zero, domain_error, underflow, is_nan]))

    output = select(
        use_igammac,
        _const(a, 1) -
          _igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac),
                                      dtype, IgammaMode.VALUE),
        _igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)),
                       dtype, IgammaMode.VALUE))
    output = select(x_is_zero, full_like(a, 0), output)
    output = select(x_is_infinity, full_like(a, 1), output)
    output = select(bitwise_or(domain_error, is_nan),
                    full_like(a, float('nan')), output)
    return output

  needs_upcast = a.dtype == dtypes.bfloat16 or a.dtype == np.float16
  if needs_upcast:
    a_dtype = a.dtype
    a = convert_element_type(a, np.float32)
    x = convert_element_type(x, np.float32)
    a_x_type = np.float32
  else:
    a_x_type = a.dtype

  result = doit(a, x, a_x_type)
  if needs_upcast:
    result = convert_element_type(result, a_dtype)
  return result

def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode):
  # TODO(atondwal): implement _igammac_continued_fraction in JAX.
  # Right now we fall back to the XLA implementation of IgammacContinuedFraction.
  return igammac(a, x)

lgamma_p = standard_unop(_float, 'lgamma')
ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))
mlir.register_lowering(lgamma_p, partial(_nary_lower_hlo, chlo.LgammaOp))
@@ -126,7 +246,8 @@ digamma_p = standard_unop(_float, 'digamma')
mlir.register_lowering(digamma_p, partial(_nary_lower_hlo, chlo.DigammaOp))
igamma_p = standard_naryop([_float, _float], 'igamma')
xla.register_translation(igamma_p, partial(_broadcast_translate, xops.Igamma))
mlir.register_lowering(igamma_p,
                       mlir.lower_fun(igamma_impl, multiple_results=False))
igamma_grad_a_p = standard_naryop([_float, _float], 'igamma_grad_a')
xla.register_translation(igamma_grad_a_p,


@@ -262,11 +262,6 @@ def main(_):
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.gt)

  # CHECK-LABEL: TEST: igamma float32[] float32[]
  # CHECK: xla_fallback_igamma
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.igamma)

  # CHECK-LABEL: TEST: igammac float32[] float32[]
  # CHECK: xla_fallback_igammac
  # CHECK-SAME: tensor<f32>
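
As a quick spot-check outside the filecheck harness (not part of this commit, and only a sketch; it assumes SciPy is available for comparison), one can inspect the StableHLO that jit now emits for lax.igamma and sanity-check the values:

import numpy as np
import jax
from jax import lax
from scipy.special import gammainc  # regularized lower incomplete gamma P(a, x)

a, x = np.float32(2.0), np.float32(3.0)

# The lowering is now the series/igammac decomposition added in this commit
# rather than a single xla_fallback_igamma custom call.
print(jax.jit(lax.igamma).lower(a, x).as_text())

# Numerical sanity check: lax.igamma computes P(a, x), as does scipy's gammainc.
print(lax.igamma(a, x), gammainc(a, x))  # should agree to float32 precision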