Migrate igamma_grad_a_p off xla_fallback

PiperOrigin-RevId: 519148548
This commit is contained in:
Anish Tondwalkar 2023-03-24 08:20:46 -07:00 committed by jax authors
parent 4a9b09485e
commit 8d1d522618
2 changed files with 84 additions and 83 deletions

View File

@ -189,49 +189,31 @@ def _igamma_series(ax, x, a, enabled, dtype, mode):
else:
raise ValueError("Invalid IgammaMode")
# NOTE(review): this span comes from a rendered diff whose +/- markers and
# indentation were stripped, so two versions of igamma_impl are interleaved.
# Presumably the (a, x) variant with the nested doit() is the pre-change code
# and the (a, x, dtype) variant is the post-change code -- confirm against the
# commit before relying on this.

# Presumably pre-change header: broadcasts a and x to their common shape
# before doing any work.
def igamma_impl(a, x):
broadcasted_shape = broadcast_shapes(a.shape, x.shape)
a = broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim)))
x = broadcast_in_dim(x, broadcasted_shape, list(range(x.ndim)))
# Presumably post-change definition: the compute dtype is passed in by the
# caller; no broadcasting or dtype conversion happens in this variant.
def igamma_impl(a, x, dtype):
# Element-wise masks for the special cases handled at the end.
is_nan = bitwise_or(_isnan(a), _isnan(x))
x_is_zero = eq(x, _const(x, 0))
x_is_infinity = eq(x, _const(x, float('inf')))
# Domain error where x < 0 or a <= 0.
domain_error = bitwise_or(lt(x, _const(x, 0)), le(a, _const(a, 0)))
# Use the continued-fraction helper where x > 1 and x > a, the series otherwise.
use_igammac = bitwise_and(gt(x, _const(x, 1)), gt(x, a))
# ax is first formed in log space; the underflow mask is taken before exp().
ax = a * log(x) - x - lgamma(a)
underflow = lt(ax, -log(dtypes.finfo(dtype).max))
ax = exp(ax)
# Elements flagged by any mask are disabled for the helper calls.
enabled = bitwise_not(
_reduce(bitwise_or,[x_is_zero, domain_error, underflow, is_nan]))
# Presumably pre-change nested helper: same computation as above, taking the
# compute dtype explicitly.
def doit(a, x, dtype):
is_nan = bitwise_or(_isnan(a), _isnan(x))
x_is_zero = eq(x, _const(x, 0))
x_is_infinity = eq(x, _const(x, float('inf')))
domain_error = bitwise_or(lt(x, _const(x, 0)), le(a, _const(a, 0)))
use_igammac = bitwise_and(gt(x, _const(x, 1)), gt(x, a))
ax = a * log(x) - x - lgamma(a)
underflow = lt(ax, -log(dtypes.finfo(dtype).max))
ax = exp(ax)
enabled = bitwise_not(
_reduce(bitwise_or,[x_is_zero, domain_error, underflow, is_nan]))
output = select(
use_igammac,
_const(a, 1) -
_igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac),
dtype, IgammaMode.VALUE),
_igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)),
dtype, IgammaMode.VALUE)
)
output = select(x_is_zero, full_like(a, 0), output)
output = select(x_is_infinity, full_like(a, 1), output)
output = select(bitwise_or(domain_error, is_nan),
full_like(a, float('nan')), output)
return output
# Presumably pre-change tail of igamma_impl(a, x): bfloat16/float16 inputs are
# converted to float32 for doit() and the result is converted back afterwards.
needs_upcast = a.dtype == dtypes.bfloat16 or a.dtype == np.float16
if needs_upcast:
a_dtype = a.dtype
a = convert_element_type(a, np.float32)
x = convert_element_type(x, np.float32)
a_x_type = np.float32
else:
a_x_type = a.dtype
result = doit(a, x, a_x_type)
if needs_upcast:
result = convert_element_type(result, a_dtype)
return result
# Presumably the remainder of the post-change igamma_impl(a, x, dtype).
# Where use_igammac holds the value is 1 - continued fraction, else the series.
output = select(
use_igammac,
_const(a, 1) -
_igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac),
dtype, IgammaMode.VALUE),
_igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)),
dtype, IgammaMode.VALUE)
)
# Special cases applied in order: 0 at x == 0, 1 at x == inf, and NaN for
# domain errors or NaN inputs (the last select wins where masks overlap).
output = select(x_is_zero, full_like(a, 0), output)
output = select(x_is_infinity, full_like(a, 1), output)
output = select(bitwise_or(domain_error, is_nan),
full_like(a, float('nan')), output)
return output
def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode):
eps = dtypes.finfo(dtype).eps
@ -339,42 +321,64 @@ def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode):
raise ValueError(f"Invalid mode: {mode}")
# NOTE(review): rendered diff with +/- markers and indentation stripped; two
# versions of igammac_impl are interleaved below. Presumably the (a, x)
# variant with the nested doit() is the pre-change code and the
# (a, x, dtype) variant is the post-change code -- confirm against the commit.

# Presumably pre-change header: broadcasts a and x to their common shape.
def igammac_impl(a, x):
broadcasted_shape = broadcast_shapes(a.shape, x.shape)
a = broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim)))
x = broadcast_in_dim(x, broadcasted_shape, list(range(x.ndim)))
# Presumably post-change definition: the compute dtype is passed in by the caller.
def igammac_impl(a, x, dtype):
# Out of range where x <= 0 or a <= 0; those elements end up as 1.
out_of_range = bitwise_or(le(x, _const(x, 0)), le(a, _const(a, 0)))
# Use the series helper where x < 1 or x < a, the continued fraction otherwise.
use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a))
# ax is formed in log space; the underflow mask is taken before exp().
ax = a * log(x) - x - lgamma(a)
underflow = lt(ax, -log(dtypes.finfo(dtype).max))
enabled = bitwise_not(bitwise_or(out_of_range, underflow))
ax = exp(ax)
# Presumably pre-change nested helper with the same computation.
def doit(a, x, dtype):
out_of_range = bitwise_or(le(x, _const(x, 0)), le(a, _const(a, 0)))
use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a))
ax = a * log(x) - x - lgamma(a)
underflow = lt(ax, -log(dtypes.finfo(dtype).max))
enabled = bitwise_not(bitwise_or(out_of_range, underflow))
ax = exp(ax)
# NOTE(review): the next two four-line groups are textually identical; one is
# presumably the removed copy and the other the added (re-indented) copy.
# Both helpers are evaluated; `enabled` restricts each to its own elements.
igamma_call = _igamma_series(ax, x, a, bitwise_and(enabled, use_igamma),
dtype, IgammaMode.VALUE)
igammac_cf_call = _igammac_continued_fraction(ax, x, a,
bitwise_and(enabled, bitwise_not(use_igamma)), dtype, IgammaMode.VALUE)
igamma_call = _igamma_series(ax, x, a, bitwise_and(enabled, use_igamma),
dtype, IgammaMode.VALUE)
igammac_cf_call = _igammac_continued_fraction(ax, x, a,
bitwise_and(enabled, bitwise_not(use_igamma)), dtype, IgammaMode.VALUE)
# Result is 1 - series where use_igamma holds, else the continued fraction;
# then 0 at x == inf, and finally 1 wherever the inputs are out of range.
result = select(use_igamma, _const(a, 1) - igamma_call, igammac_cf_call)
x_is_infinity = eq(x, _const(x, float('inf')))
result = select(x_is_infinity, full_like(result, 0), result)
return select(out_of_range, full_like(a, 1), result)
# Second copy of the same four statements. The trailing ';' on the last two
# lines is legal Python but redundant; which copy is the old one is not
# determinable from this rendering -- TODO confirm.
result = select(use_igamma, _const(a, 1) - igamma_call, igammac_cf_call)
x_is_infinity = eq(x, _const(x, float('inf')))
result = select(x_is_infinity, full_like(result, 0), result);
return select(out_of_range, full_like(a, 1), result);
def igamma_grad_a_impl(a, x, dtype):
  """Element-wise lowering body for the `igamma_grad_a` primitive.

  Presumably this is the derivative of igamma with respect to `a` (inferred
  from the name and the DERIVATIVE mode passed to the helpers).

  Args:
    a: first operand.
    x: second operand; the code combines it element-wise with `a`, so the two
      are expected to already share a shape.
    dtype: floating-point type whose `finfo(...).max` sets the underflow
      threshold and which is forwarded to the series / continued-fraction
      helpers.

  Returns:
    The helper result per element, with 0 where `x == 0` and NaN where
    `x < 0`, `a <= 0`, or either input is NaN.
  """
  # Masks describing inputs that need special handling.
  nan_input = bitwise_or(_isnan(a), _isnan(x))
  zero_x = eq(x, full_like(x, 0))
  negative_x = lt(x, full_like(x, 0))
  nonpositive_a = le(a, full_like(a, 0))
  bad_domain = bitwise_or(negative_x, nonpositive_a)

  # The continued fraction is used where x > 1 and x > a, the series elsewhere.
  x_above_one = gt(x, full_like(x, 1))
  prefer_igammac = bitwise_and(x_above_one, gt(x, a))

  # Form the prefactor in log space so underflow can be detected before exp().
  log_prefactor = a * log(x) - x - lgamma(a)
  underflows = lt(log_prefactor, -log(dtypes.finfo(dtype).max))
  prefactor = exp(log_prefactor)

  # An element is active only if none of the masks above flagged it.
  flagged = bitwise_or(
      bitwise_or(bitwise_or(zero_x, bad_domain), underflows), nan_input)
  active = bitwise_not(flagged)

  # Both helpers are evaluated; each is restricted to its own elements.
  cf_grad = -_igammac_continued_fraction(
      prefactor, x, a, bitwise_and(active, prefer_igammac),
      dtype, IgammaMode.DERIVATIVE)
  series_grad = _igamma_series(
      prefactor, x, a, bitwise_and(active, bitwise_not(prefer_igammac)),
      dtype, IgammaMode.DERIVATIVE)
  grad = select(prefer_igammac, cf_grad, series_grad)

  # Special cases: 0 at x == 0, then NaN for domain errors / NaN inputs.
  grad = select(zero_x, full_like(grad, 0), grad)
  return select(bitwise_or(bad_domain, nan_input),
                full_like(a, float('nan')), grad)
# NOTE(review): rendered diff with +/- markers and indentation stripped. The
# eight lines below look like part of the removed tail of the old
# igammac_impl(a, x) (its float16/bfloat16 upcast logic) -- confirm.
needs_upcast = a.dtype == dtypes.bfloat16 or a.dtype == np.float16
if needs_upcast:
a_dtype = a.dtype
a = convert_element_type(a, np.float32)
x = convert_element_type(x, np.float32)
a_x_type = np.float32
else:
a_x_type = a.dtype
# Wrapper factory: takes a three-argument implementation doit(a, x, dtype) and
# returns a two-argument function up_and_broadcast(a, x) that handles
# broadcasting and float16/bfloat16 upcasting around it.
def _up_and_broadcast(doit):
def up_and_broadcast(a, x):
# Broadcast both operands to their common shape.
broadcasted_shape = broadcast_shapes(a.shape, x.shape)
a = broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim)))
x = broadcast_in_dim(x, broadcasted_shape, list(range(x.ndim)))
# NOTE(review): the next four lines duplicate the four lines just before
# `return up_and_broadcast`; one copy is presumably the removed tail of the
# old igammac_impl and the other belongs to up_and_broadcast -- confirm.
result = doit(a, x, a_x_type)
if needs_upcast:
result = convert_element_type(result, a_dtype)
return result
# bfloat16/float16 inputs are computed in float32; the original dtype is
# remembered so the result can be converted back. The check looks only at
# a.dtype; x is converted alongside a.
needs_upcast = a.dtype == dtypes.bfloat16 or a.dtype == np.float16
if needs_upcast:
a_dtype = a.dtype
a = convert_element_type(a, np.float32)
x = convert_element_type(x, np.float32)
a_x_type = np.float32
else:
a_x_type = a.dtype
# Run the wrapped implementation in the compute dtype, then restore the dtype.
result = doit(a, x, a_x_type)
if needs_upcast:
result = convert_element_type(result, a_dtype)
return result
return up_and_broadcast
# Primitive definitions and their lowering / differentiation rules.
# NOTE(review): rendered diff with +/- markers stripped; where a registration
# appears twice, presumably the first form is the removed one and the form
# using _up_and_broadcast(...) is the added one -- confirm against the commit.
lgamma_p = standard_unop(_float, 'lgamma')
# JVP of lgamma: the incoming tangent g times digamma(x).
ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))
@ -384,18 +388,20 @@ digamma_p = standard_unop(_float, 'digamma')
mlir.register_lowering(digamma_p, partial(_nary_lower_hlo, chlo.DigammaOp))
igamma_p = standard_naryop([_float, _float], 'igamma')
# Presumably removed: lowers the two-argument igamma_impl directly.
mlir.register_lowering(igamma_p,
mlir.lower_fun(igamma_impl, multiple_results=False))
# Presumably added: igamma_impl now takes a dtype, so it is wrapped to get a
# two-argument function that also broadcasts and upcasts.
mlir.register_lowering(igamma_p, mlir.lower_fun(_up_and_broadcast(igamma_impl),
multiple_results=False))
igamma_grad_a_p = standard_naryop([_float, _float], 'igamma_grad_a')
# Presumably removed: the XLA translation rule based on xops.IgammaGradA (the
# commit title says igamma_grad_a_p is being migrated off xla_fallback).
xla.register_translation(igamma_grad_a_p,
partial(_broadcast_translate, xops.IgammaGradA))
# Presumably added: MLIR lowering built from igamma_grad_a_impl.
mlir.register_lowering(igamma_grad_a_p,
mlir.lower_fun(_up_and_broadcast(igamma_grad_a_impl),
multiple_results=False))
ad.defjvp(igamma_p, igamma_grada, igamma_gradx)
igammac_p = standard_naryop([_float, _float], 'igammac')
# The opening line below is shared; it is followed by the presumably removed
# argument (bare igammac_impl) and then the presumably added wrapped form.
mlir.register_lowering(igammac_p,
mlir.lower_fun(igammac_impl, multiple_results=False))
mlir.lower_fun(_up_and_broadcast(igammac_impl),
multiple_results=False))
ad.defjvp(igammac_p, igammac_grada, igammac_gradx)

View File

@ -262,11 +262,6 @@ def main(_):
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.gt)
# CHECK-LABEL: TEST: igamma_grad_a float32[] float32[]
# CHECK: xla_fallback_igamma_grad_a
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0), np.float32(0))(lax.igamma_grad_a)
# CHECK-LABEL: TEST: imag complex64[]
# CHECK: hlo.imag
# CHECK-SAME: tensor<complex<f32>>