mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Migrate igamma_grad_a_p off xla_fallback
PiperOrigin-RevId: 519148548
This commit is contained in:
parent
4a9b09485e
commit
8d1d522618
@ -189,49 +189,31 @@ def _igamma_series(ax, x, a, enabled, dtype, mode):
|
||||
else:
|
||||
raise ValueError("Invalid IgammaMode")
|
||||
|
||||
def igamma_impl(a, x):
|
||||
broadcasted_shape = broadcast_shapes(a.shape, x.shape)
|
||||
a = broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim)))
|
||||
x = broadcast_in_dim(x, broadcasted_shape, list(range(x.ndim)))
|
||||
def igamma_impl(a, x, dtype):
|
||||
is_nan = bitwise_or(_isnan(a), _isnan(x))
|
||||
x_is_zero = eq(x, _const(x, 0))
|
||||
x_is_infinity = eq(x, _const(x, float('inf')))
|
||||
domain_error = bitwise_or(lt(x, _const(x, 0)), le(a, _const(a, 0)))
|
||||
use_igammac = bitwise_and(gt(x, _const(x, 1)), gt(x, a))
|
||||
ax = a * log(x) - x - lgamma(a)
|
||||
underflow = lt(ax, -log(dtypes.finfo(dtype).max))
|
||||
ax = exp(ax)
|
||||
enabled = bitwise_not(
|
||||
_reduce(bitwise_or,[x_is_zero, domain_error, underflow, is_nan]))
|
||||
|
||||
def doit(a, x, dtype):
|
||||
is_nan = bitwise_or(_isnan(a), _isnan(x))
|
||||
x_is_zero = eq(x, _const(x, 0))
|
||||
x_is_infinity = eq(x, _const(x, float('inf')))
|
||||
domain_error = bitwise_or(lt(x, _const(x, 0)), le(a, _const(a, 0)))
|
||||
use_igammac = bitwise_and(gt(x, _const(x, 1)), gt(x, a))
|
||||
ax = a * log(x) - x - lgamma(a)
|
||||
underflow = lt(ax, -log(dtypes.finfo(dtype).max))
|
||||
ax = exp(ax)
|
||||
enabled = bitwise_not(
|
||||
_reduce(bitwise_or,[x_is_zero, domain_error, underflow, is_nan]))
|
||||
|
||||
output = select(
|
||||
use_igammac,
|
||||
_const(a, 1) -
|
||||
_igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac),
|
||||
dtype, IgammaMode.VALUE),
|
||||
_igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)),
|
||||
dtype, IgammaMode.VALUE)
|
||||
)
|
||||
output = select(x_is_zero, full_like(a, 0), output)
|
||||
output = select(x_is_infinity, full_like(a, 1), output)
|
||||
output = select(bitwise_or(domain_error, is_nan),
|
||||
full_like(a, float('nan')), output)
|
||||
return output
|
||||
|
||||
needs_upcast = a.dtype == dtypes.bfloat16 or a.dtype == np.float16
|
||||
if needs_upcast:
|
||||
a_dtype = a.dtype
|
||||
a = convert_element_type(a, np.float32)
|
||||
x = convert_element_type(x, np.float32)
|
||||
a_x_type = np.float32
|
||||
else:
|
||||
a_x_type = a.dtype
|
||||
result = doit(a, x, a_x_type)
|
||||
if needs_upcast:
|
||||
result = convert_element_type(result, a_dtype)
|
||||
return result
|
||||
output = select(
|
||||
use_igammac,
|
||||
_const(a, 1) -
|
||||
_igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac),
|
||||
dtype, IgammaMode.VALUE),
|
||||
_igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)),
|
||||
dtype, IgammaMode.VALUE)
|
||||
)
|
||||
output = select(x_is_zero, full_like(a, 0), output)
|
||||
output = select(x_is_infinity, full_like(a, 1), output)
|
||||
output = select(bitwise_or(domain_error, is_nan),
|
||||
full_like(a, float('nan')), output)
|
||||
return output
|
||||
|
||||
def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode):
|
||||
eps = dtypes.finfo(dtype).eps
|
||||
@ -339,42 +321,64 @@ def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode):
|
||||
raise ValueError(f"Invalid mode: {mode}")
|
||||
|
||||
|
||||
def igammac_impl(a, x):
|
||||
broadcasted_shape = broadcast_shapes(a.shape, x.shape)
|
||||
a = broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim)))
|
||||
x = broadcast_in_dim(x, broadcasted_shape, list(range(x.ndim)))
|
||||
def igammac_impl(a, x, dtype):
|
||||
out_of_range = bitwise_or(le(x, _const(x, 0)), le(a, _const(a, 0)))
|
||||
use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a))
|
||||
ax = a * log(x) - x - lgamma(a)
|
||||
underflow = lt(ax, -log(dtypes.finfo(dtype).max))
|
||||
enabled = bitwise_not(bitwise_or(out_of_range, underflow))
|
||||
ax = exp(ax)
|
||||
|
||||
def doit(a, x, dtype):
|
||||
out_of_range = bitwise_or(le(x, _const(x, 0)), le(a, _const(a, 0)))
|
||||
use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a))
|
||||
ax = a * log(x) - x - lgamma(a)
|
||||
underflow = lt(ax, -log(dtypes.finfo(dtype).max))
|
||||
enabled = bitwise_not(bitwise_or(out_of_range, underflow))
|
||||
ax = exp(ax)
|
||||
igamma_call = _igamma_series(ax, x, a, bitwise_and(enabled, use_igamma),
|
||||
dtype, IgammaMode.VALUE)
|
||||
igammac_cf_call = _igammac_continued_fraction(ax, x, a,
|
||||
bitwise_and(enabled, bitwise_not(use_igamma)), dtype, IgammaMode.VALUE)
|
||||
|
||||
igamma_call = _igamma_series(ax, x, a, bitwise_and(enabled, use_igamma),
|
||||
dtype, IgammaMode.VALUE)
|
||||
igammac_cf_call = _igammac_continued_fraction(ax, x, a,
|
||||
bitwise_and(enabled, bitwise_not(use_igamma)), dtype, IgammaMode.VALUE)
|
||||
result = select(use_igamma, _const(a, 1) - igamma_call, igammac_cf_call)
|
||||
x_is_infinity = eq(x, _const(x, float('inf')))
|
||||
result = select(x_is_infinity, full_like(result, 0), result)
|
||||
return select(out_of_range, full_like(a, 1), result)
|
||||
|
||||
result = select(use_igamma, _const(a, 1) - igamma_call, igammac_cf_call)
|
||||
x_is_infinity = eq(x, _const(x, float('inf')))
|
||||
result = select(x_is_infinity, full_like(result, 0), result);
|
||||
return select(out_of_range, full_like(a, 1), result);
|
||||
def igamma_grad_a_impl(a, x, dtype):
|
||||
is_nan = bitwise_or(_isnan(a), _isnan(x))
|
||||
x_is_zero = eq(x, full_like(x,0))
|
||||
domain_error = bitwise_or(lt(x, full_like(x, 0)), le(a, full_like(a, 0)))
|
||||
use_igammac = bitwise_and(gt(x, full_like(x,1)), gt(x, a))
|
||||
ax = a * log(x) - x - lgamma(a)
|
||||
underflow = lt(ax, -log(dtypes.finfo(dtype).max))
|
||||
ax = exp(ax)
|
||||
enabled = bitwise_not(bitwise_or(bitwise_or(bitwise_or(
|
||||
x_is_zero, domain_error), underflow), is_nan))
|
||||
output = select(use_igammac,
|
||||
-_igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac),
|
||||
dtype, IgammaMode.DERIVATIVE),
|
||||
_igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)),
|
||||
dtype, IgammaMode.DERIVATIVE))
|
||||
output = select(x_is_zero, full_like(output,0), output)
|
||||
output = select(bitwise_or(domain_error, is_nan),
|
||||
full_like(a, float('nan')), output)
|
||||
return output
|
||||
|
||||
needs_upcast = a.dtype == dtypes.bfloat16 or a.dtype == np.float16
|
||||
if needs_upcast:
|
||||
a_dtype = a.dtype
|
||||
a = convert_element_type(a, np.float32)
|
||||
x = convert_element_type(x, np.float32)
|
||||
a_x_type = np.float32
|
||||
else:
|
||||
a_x_type = a.dtype
|
||||
def _up_and_broadcast(doit):
|
||||
def up_and_broadcast(a, x):
|
||||
broadcasted_shape = broadcast_shapes(a.shape, x.shape)
|
||||
a = broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim)))
|
||||
x = broadcast_in_dim(x, broadcasted_shape, list(range(x.ndim)))
|
||||
|
||||
result = doit(a, x, a_x_type)
|
||||
if needs_upcast:
|
||||
result = convert_element_type(result, a_dtype)
|
||||
return result
|
||||
needs_upcast = a.dtype == dtypes.bfloat16 or a.dtype == np.float16
|
||||
if needs_upcast:
|
||||
a_dtype = a.dtype
|
||||
a = convert_element_type(a, np.float32)
|
||||
x = convert_element_type(x, np.float32)
|
||||
a_x_type = np.float32
|
||||
else:
|
||||
a_x_type = a.dtype
|
||||
|
||||
result = doit(a, x, a_x_type)
|
||||
if needs_upcast:
|
||||
result = convert_element_type(result, a_dtype)
|
||||
return result
|
||||
return up_and_broadcast
|
||||
|
||||
lgamma_p = standard_unop(_float, 'lgamma')
|
||||
ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))
|
||||
@ -384,18 +388,20 @@ digamma_p = standard_unop(_float, 'digamma')
|
||||
mlir.register_lowering(digamma_p, partial(_nary_lower_hlo, chlo.DigammaOp))
|
||||
|
||||
igamma_p = standard_naryop([_float, _float], 'igamma')
|
||||
mlir.register_lowering(igamma_p,
|
||||
mlir.lower_fun(igamma_impl, multiple_results=False))
|
||||
mlir.register_lowering(igamma_p, mlir.lower_fun(_up_and_broadcast(igamma_impl),
|
||||
multiple_results=False))
|
||||
|
||||
igamma_grad_a_p = standard_naryop([_float, _float], 'igamma_grad_a')
|
||||
xla.register_translation(igamma_grad_a_p,
|
||||
partial(_broadcast_translate, xops.IgammaGradA))
|
||||
mlir.register_lowering(igamma_grad_a_p,
|
||||
mlir.lower_fun(_up_and_broadcast(igamma_grad_a_impl),
|
||||
multiple_results=False))
|
||||
|
||||
ad.defjvp(igamma_p, igamma_grada, igamma_gradx)
|
||||
|
||||
igammac_p = standard_naryop([_float, _float], 'igammac')
|
||||
mlir.register_lowering(igammac_p,
|
||||
mlir.lower_fun(igammac_impl, multiple_results=False))
|
||||
mlir.lower_fun(_up_and_broadcast(igammac_impl),
|
||||
multiple_results=False))
|
||||
|
||||
ad.defjvp(igammac_p, igammac_grada, igammac_gradx)
|
||||
|
||||
|
@ -262,11 +262,6 @@ def main(_):
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.gt)
|
||||
|
||||
# CHECK-LABEL: TEST: igamma_grad_a float32[] float32[]
|
||||
# CHECK: xla_fallback_igamma_grad_a
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0), np.float32(0))(lax.igamma_grad_a)
|
||||
|
||||
# CHECK-LABEL: TEST: imag complex64[]
|
||||
# CHECK: hlo.imag
|
||||
# CHECK-SAME: tensor<complex<f32>>
|
||||
|
Loading…
x
Reference in New Issue
Block a user