mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Migrate igamma_p off xla_fallback
We decompose it into a series or a call to igammac. PiperOrigin-RevId: 518993077
This commit is contained in:
parent
d777cf229e
commit
f981243af5
@ -17,13 +17,19 @@
|
||||
LAX decompositions for special functions into their StableHLO counterparts.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
|
||||
from jax._src.lax.lax import (exp, full_like, log, log1p, mul, neg, np, reciprocal,
|
||||
select, sign, square, standard_naryop, standard_unop,
|
||||
xla, xops,
|
||||
from jax._src.lax.lax import (bitwise_and, bitwise_not, bitwise_or,
|
||||
broadcast_in_dim, broadcast_shapes,
|
||||
convert_element_type, eq, exp, full_like,
|
||||
gt, le, log, log1p, lt, mul, neg, reciprocal,
|
||||
reduce, select, sign, square, standard_naryop,
|
||||
standard_unop, xla, xops,
|
||||
_broadcast_translate, _const, _dtype, _float,
|
||||
_nary_lower_hlo, _ones)
|
||||
_nary_lower_hlo, _ones, _isnan, _reduce)
|
||||
from jax._src.lax.control_flow import while_loop
|
||||
from jax._src.lax.utils import (standard_translate)
|
||||
|
||||
from jax._src import dtypes
|
||||
@ -118,6 +124,120 @@ def igammac_gradx(g, a, x):
|
||||
def igammac_grada(g, a, x):
|
||||
return -igamma_grada(g, a, x)
|
||||
|
||||
# The below is directly ported from tensorflow/compiler/xla/client/lib/math.cc
|
||||
# We try to follow the corresponding functions as closely as possible, so that
|
||||
# we can quickly incorporate changes.
|
||||
class IgammaMode(Enum):
|
||||
VALUE = 1
|
||||
DERIVATIVE = 2
|
||||
SAMPLE_DERIVATIVE = 3
|
||||
|
||||
def _any(predicates: Array) -> Array:
|
||||
f = _const(predicates, False)
|
||||
predicates_shape = predicates.shape
|
||||
all_dimensions = tuple(range(len(predicates_shape)))
|
||||
return reduce(predicates, f, bitwise_or, all_dimensions)
|
||||
|
||||
def _igamma_series(ax, x, a, enabled, dtype, mode):
|
||||
def cond_fn(vals):
|
||||
return _any(vals[0])
|
||||
|
||||
def body_fn(vals):
|
||||
enabled, r, c, ans, x, dc_da, dans_da = vals
|
||||
|
||||
r = r + _const(r, 1.)
|
||||
dc_da = dc_da * (x / r) - (c * x) / (r * r)
|
||||
dans_da = dans_da + dc_da
|
||||
c = c * (x / r)
|
||||
ans = ans + c
|
||||
|
||||
if mode == IgammaMode.VALUE:
|
||||
conditional = bitwise_and(enabled, c / ans > dtypes.finfo(dtype).eps)
|
||||
else:
|
||||
conditional = bitwise_and(enabled,
|
||||
abs(dc_da / dans_da) > dtypes.finfo(dtype).eps)
|
||||
|
||||
# TODO: Make this a vmap. Might be tricky with the imports.
|
||||
return (
|
||||
conditional,
|
||||
select(enabled, r, vals[1]),
|
||||
select(enabled, c, vals[2]),
|
||||
select(enabled, ans, vals[3]),
|
||||
select(enabled, x, vals[4]),
|
||||
select(enabled, dc_da, vals[5]),
|
||||
select(enabled, dans_da, vals[6]),
|
||||
)
|
||||
|
||||
init_vals = (
|
||||
enabled, a, full_like(a, 1), full_like(a, 1), x, full_like(a, 0),
|
||||
full_like(a, 0),
|
||||
)
|
||||
|
||||
vals = while_loop(cond_fn, body_fn, init_vals)
|
||||
ans = vals[3]
|
||||
dans_da = vals[6]
|
||||
|
||||
if mode == IgammaMode.VALUE:
|
||||
return (ans * ax) / a
|
||||
|
||||
dlogax_da = log(x) - digamma(a + _const(a, 1))
|
||||
|
||||
if mode == IgammaMode.DERIVATIVE:
|
||||
return ax * (ans * dlogax_da + dans_da) / a
|
||||
elif mode == IgammaMode.SAMPLE_DERIVATIVE:
|
||||
return -(dans_da + ans * dlogax_da) * x / a
|
||||
else:
|
||||
raise ValueError("Invalid IgammaMode")
|
||||
|
||||
def igamma_impl(a, x):
|
||||
broadcasted_shape = broadcast_shapes(a.shape, x.shape)
|
||||
a = broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim)))
|
||||
x = broadcast_in_dim(x, broadcasted_shape, list(range(x.ndim)))
|
||||
|
||||
def doit(a, x, dtype):
|
||||
is_nan = bitwise_or(_isnan(a), _isnan(x))
|
||||
x_is_zero = eq(x, _const(x, 0))
|
||||
x_is_infinity = eq(x, _const(x, float('inf')))
|
||||
domain_error = bitwise_or(lt(x, _const(x, 0)), le(a, _const(a, 0)))
|
||||
use_igammac = bitwise_and(gt(x, _const(x, 1)), gt(x, a))
|
||||
ax = a * log(x) - x - lgamma(a)
|
||||
underflow = lt(ax, -log(dtypes.finfo(dtype).max))
|
||||
ax = exp(ax)
|
||||
enabled = bitwise_not(
|
||||
_reduce(bitwise_or,[x_is_zero, domain_error, underflow, is_nan]))
|
||||
|
||||
output = select(
|
||||
use_igammac,
|
||||
_const(a, 1) -
|
||||
_igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac),
|
||||
dtype, IgammaMode.VALUE),
|
||||
_igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)),
|
||||
dtype, IgammaMode.VALUE)
|
||||
)
|
||||
output = select(x_is_zero, full_like(a, 0), output)
|
||||
output = select(x_is_infinity, full_like(a, 1), output)
|
||||
output = select(bitwise_or(domain_error, is_nan),
|
||||
full_like(a, float('nan')), output)
|
||||
return output
|
||||
|
||||
needs_upcast = a.dtype == dtypes.bfloat16 or a.dtype == np.float16
|
||||
if needs_upcast:
|
||||
a_dtype = a.dtype
|
||||
a = convert_element_type(a, np.float32)
|
||||
x = convert_element_type(x, np.float32)
|
||||
a_x_type = np.float32
|
||||
else:
|
||||
a_x_type = a.dtype
|
||||
result = doit(a, x, a_x_type)
|
||||
if needs_upcast:
|
||||
result = convert_element_type(result, a_dtype)
|
||||
return result
|
||||
|
||||
def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode):
|
||||
# TODO(atondwal): implement _igammac_continued_fraction in JAX.
|
||||
# Right now we fallback to the XLA implementation of IgammacContinuedFraction.
|
||||
return igammac(a, x)
|
||||
|
||||
lgamma_p = standard_unop(_float, 'lgamma')
|
||||
ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))
|
||||
mlir.register_lowering(lgamma_p, partial(_nary_lower_hlo, chlo.LgammaOp))
|
||||
@ -126,7 +246,8 @@ digamma_p = standard_unop(_float, 'digamma')
|
||||
mlir.register_lowering(digamma_p, partial(_nary_lower_hlo, chlo.DigammaOp))
|
||||
|
||||
igamma_p = standard_naryop([_float, _float], 'igamma')
|
||||
xla.register_translation(igamma_p, partial(_broadcast_translate, xops.Igamma))
|
||||
mlir.register_lowering(igamma_p,
|
||||
mlir.lower_fun(igamma_impl, multiple_results=False))
|
||||
|
||||
igamma_grad_a_p = standard_naryop([_float, _float], 'igamma_grad_a')
|
||||
xla.register_translation(igamma_grad_a_p,
|
||||
|
@ -262,11 +262,6 @@ def main(_):
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.gt)
|
||||
|
||||
# CHECK-LABEL: TEST: igamma float32[] float32[]
|
||||
# CHECK: xla_fallback_igamma
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0), np.float32(0))(lax.igamma)
|
||||
|
||||
# CHECK-LABEL: TEST: igammac float32[] float32[]
|
||||
# CHECK: xla_fallback_igammac
|
||||
# CHECK-SAME: tensor<f32>
|
||||
|
Loading…
x
Reference in New Issue
Block a user