make random_gamma_grad not a primitive anymore

Fixes #16076

Co-authored-by: Roy Frostig <frostig@google.com>
This commit is contained in:
Matthew Johnson 2025-04-01 00:17:30 +00:00
parent 76271d638a
commit a80f6279e9
4 changed files with 23 additions and 28 deletions

View File

@ -600,7 +600,7 @@ nan_primitives = [lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p,
lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p,
lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p,
lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p,
lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p,
lax.reduce_p, lax.reduce_prod_p,
lax.reduce_sum_p, lax.reduce_window_p,
lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p,
lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p,

View File

@ -38,6 +38,25 @@ from jax._src.interpreters import mlir
from jax._src.lib.mlir.dialects import chlo
from jax._src.typing import Array, ArrayLike
# TODO(mattjj): this function sucks, delete it
def _up_and_broadcast(doit):
  """Wrap ``doit`` so it receives broadcast, sufficiently wide operands.

  The returned function broadcasts every positional argument to the shape
  produced by ``broadcast_shapes`` over all argument shapes. If the first
  argument's dtype is bfloat16 or float16, all arguments are converted to
  float32 before ``doit`` runs and the result is converted back to the
  original dtype afterwards.

  Args:
    doit: callable invoked as ``doit(*operands, dtype=compute_dtype)``.

  Returns:
    A function taking the same positional arguments as ``doit`` (without the
    ``dtype`` keyword, which is supplied by the wrapper).
  """
  def up_and_broadcast(*args):
    common_shape = broadcast_shapes(*(arg.shape for arg in args))
    operands = [
        broadcast_in_dim(arg, common_shape, list(range(arg.ndim)))
        for arg in args
    ]
    # Only the first operand's dtype is inspected; it decides both the
    # compute dtype and the dtype of the returned value.
    in_dtype = operands[0].dtype
    is_half = in_dtype == dtypes.bfloat16 or in_dtype == np.float16
    compute_dtype = np.float32 if is_half else in_dtype
    if is_half:
      operands = [convert_element_type(op, np.float32) for op in operands]
    out = doit(*operands, dtype=compute_dtype)
    return convert_element_type(out, in_dtype) if is_half else out
  return up_and_broadcast
def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
r"""Elementwise regularized incomplete beta integral."""
a, b, x = core.standard_insert_pbroadcast(a, b, x)
@ -71,10 +90,11 @@ def igamma_grad_a(a: ArrayLike, x: ArrayLike) -> Array:
a, x = core.standard_insert_pbroadcast(a, x)
return igamma_grad_a_p.bind(a, x)
def random_gamma_grad(a: ArrayLike, x: ArrayLike) -> Array:
@_up_and_broadcast
def random_gamma_grad(a: ArrayLike, x: ArrayLike, *, dtype) -> Array:
r"""Elementwise derivative of samples from `Gamma(a, 1)`."""
a, x = core.standard_insert_pbroadcast(a, x)
return random_gamma_grad_p.bind(a, x)
return random_gamma_grad_impl(a, x, dtype=dtype)
def zeta(x: ArrayLike, q: ArrayLike) -> Array:
r"""Elementwise Hurwitz zeta function: :math:`\zeta(x, q)`"""
@ -531,24 +551,6 @@ def random_gamma_grad_impl(a, x, *, dtype):
full_like(a, float('nan')), output)
return output
def _up_and_broadcast(doit):
  """Return a version of ``doit`` that broadcasts and, if needed, upcasts.

  All positional arguments are broadcast to the shape given by
  ``broadcast_shapes``. When the first argument is bfloat16 or float16 the
  computation is carried out in float32 and the result is converted back;
  otherwise ``doit`` runs in the first argument's own dtype.

  Args:
    doit: callable accepting the operands positionally plus a ``dtype``
      keyword naming the dtype to compute in.

  Returns:
    The wrapping function, which accepts only the positional operands.
  """
  def up_and_broadcast(*args):
    target_shape = broadcast_shapes(*(operand.shape for operand in args))
    args = [broadcast_in_dim(operand, target_shape, list(range(operand.ndim)))
            for operand in args]
    orig_dtype = args[0].dtype
    if orig_dtype == dtypes.bfloat16 or orig_dtype == np.float16:
      # Half-precision path: compute in float32, then restore the input dtype.
      widened = [convert_element_type(operand, np.float32) for operand in args]
      return convert_element_type(doit(*widened, dtype=np.float32), orig_dtype)
    return doit(*args, dtype=orig_dtype)
  return up_and_broadcast
def evaluate_chebyshev_polynomial(x, coefficients):
b0 = full_like(x,0)
@ -694,11 +696,6 @@ mlir.register_lowering(igammac_p,
ad.defjvp(igammac_p, igammac_grada, igammac_gradx)
random_gamma_grad_p = standard_naryop([_float, _float], 'random_gamma_grad')
mlir.register_lowering(random_gamma_grad_p,
mlir.lower_fun(_up_and_broadcast(random_gamma_grad_impl),
multiple_results=False))
zeta_p = standard_naryop([_float, _float], 'zeta')
mlir.register_lowering(zeta_p, partial(_nary_lower_hlo, chlo.zeta))

View File

@ -149,7 +149,6 @@ from jax._src.lax.special import (
igamma_p as igamma_p,
lgamma_p as lgamma_p,
polygamma_p as polygamma_p,
random_gamma_grad_p as random_gamma_grad_p,
regularized_incomplete_beta_p as regularized_incomplete_beta_p,
zeta_p as zeta_p,
)

View File

@ -261,7 +261,6 @@ from jax._src.lax.special import (
polygamma as polygamma,
polygamma_p as polygamma_p,
random_gamma_grad as random_gamma_grad,
random_gamma_grad_p as random_gamma_grad_p,
regularized_incomplete_beta_p as regularized_incomplete_beta_p,
zeta as zeta,
zeta_p as zeta_p,