mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
make random_gamma_grad not a primitive anymore
Fixes #16076 Co-authored-by: Roy Frostig <frostig@google.com>
This commit is contained in:
parent
76271d638a
commit
a80f6279e9
@ -600,7 +600,7 @@ nan_primitives = [lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p,
|
||||
lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p,
|
||||
lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p,
|
||||
lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p,
|
||||
lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p,
|
||||
lax.reduce_p, lax.reduce_prod_p,
|
||||
lax.reduce_sum_p, lax.reduce_window_p,
|
||||
lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p,
|
||||
lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p,
|
||||
|
@ -38,6 +38,25 @@ from jax._src.interpreters import mlir
|
||||
from jax._src.lib.mlir.dialects import chlo
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
# TODO(mattjj): this function sucks, delete it
|
||||
def _up_and_broadcast(doit):
|
||||
def up_and_broadcast(*args):
|
||||
broadcasted_shape = broadcast_shapes(*(a.shape for a in args))
|
||||
args = [broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) for a in args]
|
||||
|
||||
a_dtype = args[0].dtype
|
||||
needs_upcast = a_dtype == dtypes.bfloat16 or a_dtype == np.float16
|
||||
if needs_upcast:
|
||||
args = [convert_element_type(a, np.float32) for a in args]
|
||||
a_x_type = np.float32
|
||||
else:
|
||||
a_x_type = a_dtype
|
||||
result = doit(*args, dtype=a_x_type)
|
||||
if needs_upcast:
|
||||
result = convert_element_type(result, a_dtype)
|
||||
return result
|
||||
return up_and_broadcast
|
||||
|
||||
def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
|
||||
r"""Elementwise regularized incomplete beta integral."""
|
||||
a, b, x = core.standard_insert_pbroadcast(a, b, x)
|
||||
@ -71,10 +90,11 @@ def igamma_grad_a(a: ArrayLike, x: ArrayLike) -> Array:
|
||||
a, x = core.standard_insert_pbroadcast(a, x)
|
||||
return igamma_grad_a_p.bind(a, x)
|
||||
|
||||
def random_gamma_grad(a: ArrayLike, x: ArrayLike) -> Array:
|
||||
@_up_and_broadcast
|
||||
def random_gamma_grad(a: ArrayLike, x: ArrayLike, *, dtype) -> Array:
|
||||
r"""Elementwise derivative of samples from `Gamma(a, 1)`."""
|
||||
a, x = core.standard_insert_pbroadcast(a, x)
|
||||
return random_gamma_grad_p.bind(a, x)
|
||||
return random_gamma_grad_impl(a, x, dtype=dtype)
|
||||
|
||||
def zeta(x: ArrayLike, q: ArrayLike) -> Array:
|
||||
r"""Elementwise Hurwitz zeta function: :math:`\zeta(x, q)`"""
|
||||
@ -531,24 +551,6 @@ def random_gamma_grad_impl(a, x, *, dtype):
|
||||
full_like(a, float('nan')), output)
|
||||
return output
|
||||
|
||||
def _up_and_broadcast(doit):
|
||||
def up_and_broadcast(*args):
|
||||
broadcasted_shape = broadcast_shapes(*(a.shape for a in args))
|
||||
args = [broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) for a in args]
|
||||
|
||||
a_dtype = args[0].dtype
|
||||
needs_upcast = a_dtype == dtypes.bfloat16 or a_dtype == np.float16
|
||||
if needs_upcast:
|
||||
args = [convert_element_type(a, np.float32) for a in args]
|
||||
a_x_type = np.float32
|
||||
else:
|
||||
a_x_type = a_dtype
|
||||
result = doit(*args, dtype=a_x_type)
|
||||
if needs_upcast:
|
||||
result = convert_element_type(result, a_dtype)
|
||||
return result
|
||||
return up_and_broadcast
|
||||
|
||||
|
||||
def evaluate_chebyshev_polynomial(x, coefficients):
|
||||
b0 = full_like(x,0)
|
||||
@ -694,11 +696,6 @@ mlir.register_lowering(igammac_p,
|
||||
|
||||
ad.defjvp(igammac_p, igammac_grada, igammac_gradx)
|
||||
|
||||
random_gamma_grad_p = standard_naryop([_float, _float], 'random_gamma_grad')
|
||||
mlir.register_lowering(random_gamma_grad_p,
|
||||
mlir.lower_fun(_up_and_broadcast(random_gamma_grad_impl),
|
||||
multiple_results=False))
|
||||
|
||||
zeta_p = standard_naryop([_float, _float], 'zeta')
|
||||
mlir.register_lowering(zeta_p, partial(_nary_lower_hlo, chlo.zeta))
|
||||
|
||||
|
@ -149,7 +149,6 @@ from jax._src.lax.special import (
|
||||
igamma_p as igamma_p,
|
||||
lgamma_p as lgamma_p,
|
||||
polygamma_p as polygamma_p,
|
||||
random_gamma_grad_p as random_gamma_grad_p,
|
||||
regularized_incomplete_beta_p as regularized_incomplete_beta_p,
|
||||
zeta_p as zeta_p,
|
||||
)
|
||||
|
@ -261,7 +261,6 @@ from jax._src.lax.special import (
|
||||
polygamma as polygamma,
|
||||
polygamma_p as polygamma_p,
|
||||
random_gamma_grad as random_gamma_grad,
|
||||
random_gamma_grad_p as random_gamma_grad_p,
|
||||
regularized_incomplete_beta_p as regularized_incomplete_beta_p,
|
||||
zeta as zeta,
|
||||
zeta_p as zeta_p,
|
||||
|
Loading…
x
Reference in New Issue
Block a user