Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Add random_gamma_grad and use in jax.random.gamma (#3281)
parent 0da7b4d1bd
commit 927c209148
@@ -190,6 +190,8 @@ from .lax import (
   pow,
   pow_p,
   prod,
+  random_gamma_grad,
+  random_gamma_grad_p,
   real,
   real_p,
   reciprocal,
@@ -196,6 +196,10 @@ def igamma_grad_a(a: Array, x: Array) -> Array:
   r"""Elementwise derivative of the regularized incomplete gamma function."""
   return igamma_grad_a_p.bind(a, x)

+def random_gamma_grad(a: Array, x: Array) -> Array:
+  r"""Elementwise derivative of samples from `Gamma(a, 1)`."""
+  return random_gamma_grad_p.bind(a, x)
+
 def bessel_i0e(x: Array) -> Array:
   r"""Exponentially scaled modified Bessel function of order 0:
   :math:`\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)`
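For orientation, a minimal sketch of calling the new wrapper directly (it is exported from `jax.lax` per the import hunk above); the concrete values are illustrative assumptions, not from the commit:

```python
# Usage sketch for the new lax.random_gamma_grad(a, x); values are
# illustrative: a is the Gamma concentration, x a sample from Gamma(a, 1).
import jax.numpy as jnp
from jax import lax

a = jnp.float32(2.0)  # concentration of Gamma(a, 1)
x = jnp.float32(1.5)  # a sample assumed drawn from Gamma(a, 1)
print(lax.random_gamma_grad(a, x))  # elementwise d(sample)/d(concentration)
```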
@@ -1999,6 +2003,10 @@ def igammac_grada(g, a, x):

 ad.defjvp(igammac_p, igammac_grada, igammac_gradx)

+random_gamma_grad_p = standard_naryop([_float, _float], 'random_gamma_grad',
+                                      translation_rule=_broadcast_translate(partial(standard_translate,
+                                                                                    'random_gamma_grad')))
+
 bessel_i0e_p = standard_unop(_float, 'bessel_i0e')
 ad.defjvp2(bessel_i0e_p, lambda g, y, x: g * (bessel_i1e(x) - sign(x) * y))
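This hunk registers `random_gamma_grad` as a binary primitive on floats whose lowering delegates to the backend via the standard translation helpers. One hedged way to confirm the primitive shows up in traced computations (the printed jaxpr text is indicative only; its formatting varies across JAX versions):

```python
# Sketch: trace a call to see the registered primitive by name in the jaxpr.
import jax
import jax.numpy as jnp
from jax import lax

print(jax.make_jaxpr(lax.random_gamma_grad)(jnp.float32(2.0), jnp.float32(1.5)))
# e.g. { lambda ; a b. let c = random_gamma_grad a b in (c,) }
```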
jax/random.py (105 changed lines)
@@ -904,115 +904,14 @@ def _gamma_one(key, alpha):
   z = lax.mul(lax.mul(d, V), boost)
   return lax.select(lax.eq(z, zero), jnp.finfo(z.dtype).tiny, z)

-_bivariate_coef = [[0.16009398, -0.094634816, 0.025146379, -0.0030648348,
-                    1, 0.3266811, 0.10406087, 0.0014179033],
-                   [0.53487893, 0.12980707, 0.06573594, -0.0015649787,
-                    0.16639465, 0.020070098, -0.0035938937, -0.00058392601],
-                   [0.040121005, -0.0065914079, -0.002628604, -0.0013441777,
-                    0.017050642, -0.0021309345, 0.00085092385, -1.5248239e-07]]
-
-def _gamma_grad_one(z, alpha):
-  # Ref 1: Pathwise Derivatives Beyond the Reparameterization Trick, Martin & Fritz
-  # Ref 2: Case 4 follows https://github.com/fritzo/notebooks/blob/master/gamma-reparameterized.ipynb
-
-  # TODO: use lax.cond instead of lax.while_loop when its batching rule is available
-  # See https://github.com/google/jax/issues/490
-  def _case1(zagf):
-    z, alpha, _, flag = zagf
-
-    # dz = - dCDF(z; a) / pdf(z; a)
-    # pdf = z^(a-1) * e^(-z) / Gamma(a)
-    # CDF(z; a) = IncompleteGamma(a, z) / Gamma(a)
-    # dCDF(z; a) = (dIncompleteGamma - IncompleteGamma * Digamma(a)) / Gamma(a)
-    #            =: unnormalized_dCDF / Gamma(a)
-    # IncompleteGamma ~ z^a [ 1/a - z/(a+1) + z^2/2!(a+2) - z^3/3!(a+3) + z^4/4!(a+4) - z^5/5!(a+5) ]
-    #                 =: z^a * term1
-    # dIncompleteGamma ~ z^a * log(z) * term1 - z^a [1/a^2 - z/(a+1)^2 + z^2/2!(a+2)^2
-    #                                                - z^3/3!(a+3)^2 + z^4/4!(a+4)^2 - z^5/5!(a+5)^2 ]
-    #                  =: z^a * log(z) * term1 - z^a * term2
-    # unnormalized_dCDF = z^a { [log(z) - Digamma(a)] * term1 - term2 }
-    zi = 1.0
-    update = zi / alpha
-    term1 = update
-    term2 = update / alpha
-    for i in range(1, 6):
-      zi = -zi * z / i
-      update = zi / (alpha + i)
-      term1 = term1 + update
-      term2 = term2 + update / (alpha + i)
-
-    unnormalized_cdf_dot = jnp.power(z, alpha) * ((jnp.log(z) - lax.digamma(alpha)) * term1 - term2)
-    unnormalized_pdf = jnp.power(z, alpha - 1) * jnp.exp(-z)
-    grad = -unnormalized_cdf_dot / unnormalized_pdf
-
-    return z, alpha, grad, ~flag
-
-  def _cond2(zagf):
-    z, alpha, _, flag = zagf
-    return (~flag) & (alpha > 8.0) & ((z < 0.9 * alpha) | (z > 1.1 * alpha))
-
-  def _case2(zagf):
-    z, alpha, _, flag = zagf
-
-    # Formula 58 of [1]
-    sqrt_8a = jnp.sqrt(8 * alpha)
-    z_minus_a = z - alpha
-    log_z_div_a = jnp.log(z / alpha)
-    sign = jnp.where(z < alpha, lax._const(z, 1.0), lax._const(z, -1.0))
-    term1 = 4 * (z + alpha) / (sqrt_8a * z_minus_a * z_minus_a)
-    term2 = log_z_div_a * (sqrt_8a / z_minus_a + sign * jnp.power(z_minus_a - alpha * log_z_div_a, -1.5))
-    term3 = z * (1.0 + 1.0 / (12 * alpha) + 1.0 / (288 * alpha * alpha)) / sqrt_8a
-    grad = (term1 + term2) * term3
-
-    return z, alpha, grad, ~flag
-
-  def _cond3(zagf):
-    z, alpha, _, flag = zagf
-    return (~flag) & (alpha > 8.0) & (z >= 0.9 * alpha) & (z <= 1.1 * alpha)
-
-  def _case3(zagf):
-    z, alpha, _, flag = zagf
-
-    # Formula 59 of [1]
-    z_div_a = jnp.divide(z, alpha)
-    aa = alpha * alpha
-    term1 = 1440 * alpha + 6 * z_div_a * (53 - 120 * z) - 65 * z_div_a * z_div_a + 3600 * z + 107
-    term2 = 1244160 * alpha * aa
-    term3 = 1 + 24 * alpha + 288 * aa
-    grad = term1 * term3 / term2
-
-    return z, alpha, grad, ~flag
-
-  def _case4(zagf):
-    z, alpha, _, flag = zagf
-
-    # Ref [2]
-    u = jnp.log(z / alpha)
-    v = jnp.log(alpha)
-    c = []
-    for i in range(8):
-      c.append(_bivariate_coef[0][i] + u * (_bivariate_coef[1][i] + u * _bivariate_coef[2][i]))
-    p = c[0] + v * (c[1] + v * (c[2] + v * c[3]))
-    q = c[4] + v * (c[5] + v * (c[6] + v * c[7]))
-    grad = jnp.exp(p / jnp.maximum(q, 0.01))
-
-    return z, alpha, grad, ~flag
-
-  _, _, grad, flag = lax.while_loop(lambda zagf: (~zagf[3]) & (zagf[0] < 0.8),
-                                    _case1,
-                                    (z, alpha, lax._const(alpha, 0.0), False))
-  _, _, grad, flag = lax.while_loop(_cond2, _case2, (z, alpha, grad, flag))
-  _, _, grad, flag = lax.while_loop(_cond3, _case3, (z, alpha, grad, flag))
-  _, _, grad, flag = lax.while_loop(lambda zagf: ~zagf[3], _case4, (z, alpha, grad, flag))
-  return grad
-
 def _gamma_grad(sample, a):
   samples = jnp.reshape(sample, -1)
   alphas = jnp.reshape(a, -1)
   if xla_bridge.get_backend().platform == 'cpu':
-    grads = lax.map(lambda args: _gamma_grad_one(*args), (samples, alphas))
+    grads = lax.map(lambda args: lax.random_gamma_grad(*args), (alphas, samples))
   else:
-    grads = vmap(_gamma_grad_one)(samples, alphas)
+    grads = vmap(lax.random_gamma_grad)(alphas, samples)
   return grads.reshape(np.shape(a))

 def _gamma_impl(key, a):
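The removed `_case1` branch implements implicit differentiation of the Gamma CDF; restating its comment block in standard notation (this is just a transcription of the removed comments, with γ the lower incomplete gamma function and the six-term truncation used for small z):

```latex
% Differentiate U = F(z(\alpha); \alpha) implicitly, where F is the
% Gamma(\alpha, 1) CDF and p its density:
\frac{dz}{d\alpha} = -\,\frac{\partial F(z;\alpha)/\partial\alpha}{p(z;\alpha)},
\qquad p(z;\alpha) = \frac{z^{\alpha-1} e^{-z}}{\Gamma(\alpha)},
\qquad F(z;\alpha) = \frac{\gamma(\alpha, z)}{\Gamma(\alpha)},
\qquad \gamma(\alpha, z) \approx z^{\alpha} \sum_{k=0}^{5} \frac{(-z)^{k}}{k!\,(\alpha+k)}.
```

The other branches ("Formula 58", "Formula 59", the bivariate rational fit of Ref [2]) are asymptotic and fitted approximations of the same derivative for large alpha and for the remaining region; after this commit, all of that work moves behind the single `lax.random_gamma_grad` primitive.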
@@ -440,8 +440,8 @@ class LaxRandomTest(jtu.JaxTestCase):
     pdf = scipy.stats.gamma.pdf(z, alpha)
     expected_grad = -cdf_dot / pdf

-    self.assertAllClose(actual_grad, expected_grad,
-                        rtol=2e-2 if jtu.device_under_test() == "tpu" else 5e-4)
+    self.assertAllClose(actual_grad, expected_grad, check_dtypes=True,
+                        rtol=2e-2 if jtu.device_under_test() == "tpu" else 7e-4)

   def testGammaGradType(self):
     # Regression test for https://github.com/google/jax/issues/2130
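A hedged re-creation of the identity this test exercises: the autodiff gradient of a Gamma sample should match -dCDF/dα divided by the pdf. The test's actual `cdf_dot` construction lies outside this excerpt, so a central finite difference stands in for it here, and `alpha`, `eps`, and the PRNG key are illustrative choices:

```python
# Sketch: compare jax.grad of a Gamma sample against the implicit-function
# identity d z / d alpha = -(dCDF/dalpha) / pdf, using scipy as reference.
import jax
import numpy as np
import scipy.stats

alpha, eps = 2.0, 1e-4
key = jax.random.PRNGKey(0)
z = jax.random.gamma(key, alpha)
actual_grad = jax.grad(lambda a: jax.random.gamma(key, a))(alpha)

# Central finite difference of the Gamma CDF in alpha (stand-in for cdf_dot).
cdf_dot = (scipy.stats.gamma.cdf(z, alpha + eps)
           - scipy.stats.gamma.cdf(z, alpha - eps)) / (2 * eps)
pdf = scipy.stats.gamma.pdf(z, alpha)
expected_grad = -cdf_dot / pdf

np.testing.assert_allclose(actual_grad, expected_grad, rtol=7e-4)
```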