Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
add random.loggamma and improve dirichlet & beta implementation
This commit is contained in:
parent 1ffa285bd6
commit 69969ef803
@@ -8,7 +8,12 @@ Remember to align the itemized text with the first line of an item within a list
 PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
 -->

-## jax 0.3.4 (Unreleased)
+## jax 0.3.5 (Unreleased)
+* [GitHub
+  commits](https://github.com/google/jax/compare/jax-v0.3.4...main).
+* Changes:
+  * added {func}`jax.random.loggamma` & improved behavior of {func}`jax.random.beta`
+    and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`).


 ## jaxlib 0.3.3 (Unreleased)

@@ -16,7 +21,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.

 ## jax 0.3.4 (March 18, 2022)
 * [GitHub
-  commits](https://github.com/google/jax/compare/jax-v0.3.2...jax-v0.3.4).
+  commits](https://github.com/google/jax/compare/jax-v0.3.3...jax-v0.3.4).


 ## jax 0.3.3 (March 17, 2022)
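
The changelog entry above summarizes the user-visible effect of the commit. As a quick illustration (not part of the patch), the following sketch exercises the new public API with deliberately tiny parameters; the specific values are illustrative only.

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

# Beta/Dirichlet with very small concentration parameters previously behaved
# poorly because the underlying gamma samples underflow; the new log-space
# path keeps the results well defined.
beta_samples = random.beta(key, 1e-4, 2e-4, shape=(5,))
dirichlet_samples = random.dirichlet(key, 1e-4 * jnp.ones(3), shape=(5,))

# The new sampler: loggamma draws log-gamma variates directly.
log_g = random.loggamma(key, 0.5, shape=(5,))
g = random.gamma(key, 0.5, shape=(5,))
print(jnp.max(jnp.abs(jnp.exp(log_g) - g)))  # expected to be ~0 up to float tolerance
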
@@ -28,6 +28,7 @@ List of Available Functions
    gamma
    gumbel
    laplace
+   loggamma
    logistic
    maxwell
    multivariate_normal
@@ -749,6 +749,7 @@ def beta(key: KeyArray,
     shape = core.canonicalize_shape(shape)
   return _beta(key, a, b, shape, dtype)
 
+
 def _beta(key, a, b, shape, dtype):
   if shape is None:
     shape = lax.broadcast_shapes(np.shape(a), np.shape(b))
@@ -760,9 +761,13 @@ def _beta(key, a, b, shape, dtype):
   key_a, key_b = _split(key)
   a = jnp.broadcast_to(a, shape)
   b = jnp.broadcast_to(b, shape)
-  gamma_a = gamma(key_a, a, shape, dtype)
-  gamma_b = gamma(key_b, b, shape, dtype)
-  return gamma_a / (gamma_a + gamma_b)
+  log_gamma_a = loggamma(key_a, a, shape, dtype)
+  log_gamma_b = loggamma(key_b, b, shape, dtype)
+  # Compute gamma_a / (gamma_a + gamma_b) without losing precision.
+  log_max = lax.max(log_gamma_a, log_gamma_b)
+  gamma_a_scaled = jnp.exp(log_gamma_a - log_max)
+  gamma_b_scaled = jnp.exp(log_gamma_b - log_max)
+  return gamma_a_scaled / (gamma_a_scaled + gamma_b_scaled)
 
 
 def cauchy(key: KeyArray,
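
To make the rescaling in `_beta` above concrete, here is a small self-contained sketch (not taken from the patch; the helper name `beta_from_log_gamma` is invented for illustration). Dividing both gamma variates by `exp(log_max)` leaves the ratio unchanged but keeps the intermediate exponentials representable even when both log-gamma values are very negative.

import jax.numpy as jnp
from jax import lax, random

def beta_from_log_gamma(log_gamma_a, log_gamma_b):
  # gamma_a / (gamma_a + gamma_b) is invariant under scaling both terms by a
  # common factor, so factor out exp(log_max) to avoid underflow/overflow.
  log_max = lax.max(log_gamma_a, log_gamma_b)
  ga = jnp.exp(log_gamma_a - log_max)
  gb = jnp.exp(log_gamma_b - log_max)
  return ga / (ga + gb)

key_a, key_b = random.split(random.PRNGKey(0))
a, b = 1e-4, 2e-4
samples = beta_from_log_gamma(random.loggamma(key_a, a, (1000,)),
                              random.loggamma(key_b, b, (1000,)))
print(samples.min(), samples.max())  # values lie in [0, 1]; tiny a, b push them to the endpoints
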
@@ -840,8 +845,19 @@ def _dirichlet(key, alpha, shape, dtype):
   _check_shape("dirichlet", shape, np.shape(alpha)[:-1])
 
   alpha = lax.convert_element_type(alpha, dtype)
-  gamma_samples = gamma(key, alpha, shape + np.shape(alpha)[-1:], dtype)
-  return gamma_samples / jnp.sum(gamma_samples, axis=-1, keepdims=True)
+
+  # Compute gamma in log space, otherwise small alpha can lead to poor behavior.
+  log_gamma_samples = loggamma(key, alpha, shape + np.shape(alpha)[-1:], dtype)
+  return _softmax(log_gamma_samples, -1)
+
+
+def _softmax(x, axis):
+  """Utility to compute the softmax of x along a given axis."""
+  if not dtypes.issubdtype(x.dtype, np.floating):
+    raise TypeError(f"_softmax only accepts floating dtypes, got {x.dtype}")
+  x_max = jnp.max(x, axis, keepdims=True)
+  unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
+  return unnormalized / unnormalized.sum(axis, keepdims=True)
 
 
 def exponential(key: KeyArray,
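
The `_softmax` helper above is the standard max-shifted softmax; applied to log-gamma samples it reproduces the usual normalize-by-the-sum construction of a Dirichlet draw without ever forming the (possibly underflowing) gamma values, since softmax(log g)_i = g_i / sum_j g_j. A standalone sketch with an invented helper name (it does not reproduce `random.dirichlet`'s exact key handling):

import jax.numpy as jnp
from jax import random

def dirichlet_from_loggamma(key, alpha, shape=()):
  # Sample log-gamma variates and normalize via a max-shifted softmax:
  # softmax(log g) = exp(log g) / sum(exp(log g)) = g / sum(g).
  log_g = random.loggamma(key, alpha, shape + alpha.shape)
  shifted = log_g - log_g.max(axis=-1, keepdims=True)
  unnormalized = jnp.exp(shifted)
  return unnormalized / unnormalized.sum(axis=-1, keepdims=True)

key = random.PRNGKey(0)
alpha = 1e-4 * jnp.ones(3)
x = dirichlet_from_loggamma(key, alpha, shape=(4,))
print(x.sum(-1))  # rows sum to 1 even though every gamma sample underflows in linear space
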
@@ -875,7 +891,7 @@ def _exponential(key, shape, dtype):
   return lax.neg(lax.log1p(lax.neg(u)))
 
 
-def _gamma_one(key: KeyArray, alpha):
+def _gamma_one(key: KeyArray, alpha, log_space):
   # Ref: A simple method for generating gamma variables, George Marsaglia and Wai Wan Tsang
   # The algorithm can also be founded in:
   # https://en.wikipedia.org/wiki/Gamma_distribution#Generating_gamma-distributed_random_variables
@@ -887,13 +903,20 @@ def _gamma_one(key: KeyArray, alpha):
   squeeze_const = _lax_const(alpha, 0.0331)
   dtype = lax.dtype(alpha)
 
-  key, subkey = _split(key)
   # for alpha < 1, we boost alpha to alpha + 1 and get a sample according to
-  #   Gamma(alpha) ~ Gamma(alpha+1) * Uniform()^(1 / alpha)
-  boost = lax.select(lax.ge(alpha, one),
-                     one,
-                     lax.pow(uniform(subkey, (), dtype=dtype), lax.div(one, alpha)))
-  alpha = lax.select(lax.ge(alpha, one), alpha, lax.add(alpha, one))
+  #   Gamma(alpha) ~ Gamma(alpha+1) * Uniform()^(1 / alpha)
+  # When alpha is very small, this boost can be problematic because it may result
+  # in floating point underflow; for this reason we compute it in log space if
+  # specified by the `log_space` argument:
+  #   log[Gamma(alpha)] ~ log[Gamma(alpha + 1)] + log[Uniform()] / alpha
+  # Note that log[Uniform()] ~ Exponential(), but the exponential() function is
+  # computed via log[1 - Uniform()] to avoid taking log(0). We want the generated
+  # sequence to match between log_space=True and log_space=False, so we avoid this
+  # for now to maintain backward compatibility with the original implementation.
+  # TODO(jakevdp) should we change the convention to avoid -inf in log-space?
+  boost_mask = lax.ge(alpha, one)
+  alpha_orig = alpha
+  alpha = lax.select(boost_mask, alpha, lax.add(alpha, one))
 
   d = lax.sub(alpha, one_over_three)
   c = lax.div(one_over_three, lax.sqrt(d))
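
The comment block above is the heart of the change: for alpha < 1 the Marsaglia–Tsang sample is "boosted" by Uniform()**(1/alpha), and for tiny alpha that factor underflows, whereas its logarithm log(Uniform())/alpha is still representable. A tiny numeric illustration (plain NumPy, not library code; values chosen arbitrarily):

import numpy as np

# The boost factor Uniform()**(1/alpha) underflows in float32 for small alpha,
# while its log stays representable.
rng = np.random.default_rng(0)
u = np.float32(rng.uniform())
alpha = np.float32(1e-4)

boost_linear = u ** (np.float32(1) / alpha)  # underflows to 0.0
log_boost = np.log(u) / alpha                # a finite (large negative) number
print(boost_linear, log_boost)
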
@@ -926,21 +949,42 @@ def _gamma_one(key: KeyArray, alpha):
     return key, X, V, U
 
   # initial state is chosen such that _cond_fn will return True
+  key, subkey = _split(key)
+  u_boost = uniform(subkey, (), dtype=dtype)
   _, _, V, _ = lax.while_loop(_cond_fn, _body_fn, (key, zero, one, _lax_const(alpha, 2)))
-  z = lax.mul(lax.mul(d, V), boost)
-  return lax.select(lax.eq(z, zero), jnp.finfo(z.dtype).tiny, z)
+  if log_space:
+    # TODO(jakevdp): there are negative infinities here due to issues mentioned above. How should
+    # we handle those?
+    log_boost = lax.select(boost_mask, zero, lax.mul(lax.log(u_boost), lax.div(one, alpha_orig)))
+    return lax.add(lax.add(lax.log(d), lax.log(V)), log_boost)
+  else:
+    boost = lax.select(boost_mask, one, lax.pow(u_boost, lax.div(one, alpha_orig)))
+    z = lax.mul(lax.mul(d, V), boost)
+    return lax.select(lax.eq(z, zero), jnp.finfo(z.dtype).tiny, z)
 
 
-def _gamma_grad(sample, a):
+def _gamma_grad(sample, a, *, prng_impl, log_space):
+  del prng_impl  # unused
   samples = jnp.reshape(sample, -1)
   alphas = jnp.reshape(a, -1)
-  if xla_bridge.get_backend().platform == 'cpu':
-    grads = lax.map(lambda args: lax.random_gamma_grad(*args), (alphas, samples))
+  if log_space:
+    # d[log(sample)] = d[sample] / sample
+    # This requires computing exp(log_sample), which may be zero due to float roundoff.
+    # In this case, we use the same zero-correction used in gamma() above.
+    samples = lax.exp(samples)
+    zero = lax_internal._const(sample, 0)
+    tiny = lax.full_like(samples, jnp.finfo(samples.dtype).tiny)
+    samples = lax.select(lax.eq(samples, zero), tiny, samples)
+    gamma_grad = lambda alpha, sample: lax.random_gamma_grad(alpha, sample) / sample
   else:
-    grads = vmap(lax.random_gamma_grad)(alphas, samples)
+    gamma_grad = lax.random_gamma_grad
+  if xla_bridge.get_backend().platform == 'cpu':
+    grads = lax.map(lambda args: gamma_grad(*args), (alphas, samples))
+  else:
+    grads = vmap(gamma_grad)(alphas, samples)
   return grads.reshape(np.shape(a))
 
-def _gamma_impl(raw_key, a, *, prng_impl, use_vmap=False):
+def _gamma_impl(raw_key, a, *, prng_impl, log_space, use_vmap=False):
   a_shape = jnp.shape(a)
   # split key to match the shape of a
   key_ndim = len(raw_key.shape) - len(prng_impl.key_shape)
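
The `log_space` branch of `_gamma_grad` above applies the chain rule d log(g)/da = (dg/da) / g on top of `lax.random_gamma_grad`. A sketch using only public APIs, checking that differentiating through exp(loggamma(...)) agrees with differentiating through gamma(...); a moderate alpha is used so neither path underflows, and exact agreement is only expected up to float tolerance:

import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
alphas = jnp.full((5,), 2.0)

# Chain rule: d/da log(g(a)) = (d/da g(a)) / g(a), so gradients taken through
# exp(loggamma) should agree with gradients taken through gamma directly,
# since both samplers are documented to share the same underlying sequence.
grad_gamma = jax.grad(lambda a: random.gamma(key, a).sum())(alphas)
grad_via_log = jax.grad(lambda a: jnp.exp(random.loggamma(key, a)).sum())(alphas)
print(jnp.max(jnp.abs(grad_gamma - grad_via_log)))  # expected to be ~0
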
@@ -950,24 +994,24 @@ def _gamma_impl(raw_key, a, *, prng_impl, use_vmap=False):
   keys = prng.PRNGKeyArray(prng_impl, keys)
   alphas = jnp.reshape(a, -1)
   if use_vmap:
-    samples = vmap(_gamma_one)(keys, alphas)
+    samples = vmap(partial(_gamma_one, log_space=log_space))(keys, alphas)
   else:
-    samples = lax.map(lambda args: _gamma_one(*args), (keys, alphas))
+    samples = lax.map(lambda args: _gamma_one(*args, log_space=log_space), (keys, alphas))
 
   return jnp.reshape(samples, a_shape)
 
-def _gamma_batching_rule(batched_args, batch_dims, *, prng_impl):
+def _gamma_batching_rule(batched_args, batch_dims, *, prng_impl, log_space):
   k, a = batched_args
   bk, ba = batch_dims
   size = next(t.shape[i] for t, i in zip(batched_args, batch_dims) if i is not None)
   k = batching.bdim_at_front(k, bk, size)
   a = batching.bdim_at_front(a, ba, size)
-  return random_gamma_p.bind(k, a, prng_impl=prng_impl), 0
+  return random_gamma_p.bind(k, a, prng_impl=prng_impl, log_space=log_space), 0
 
 random_gamma_p = core.Primitive('random_gamma')
 random_gamma_p.def_impl(_gamma_impl)
 random_gamma_p.def_abstract_eval(lambda key, a, **_: core.raise_to_shaped(a))
-ad.defjvp2(random_gamma_p, None, lambda tangent, ans, key, a, **_: tangent * _gamma_grad(ans, a))
+ad.defjvp2(random_gamma_p, None, lambda tangent, ans, key, a, **kwds: tangent * _gamma_grad(ans, a, **kwds))
 xla.register_translation(random_gamma_p, xla.lower_fun(
     partial(_gamma_impl, use_vmap=True),
     multiple_results=False, new_style=True))
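
Because `log_space` is threaded through the primitive's impl, batching, and JVP rules above, transformations such as `vmap` work the same way for both samplers. A minimal sketch with per-example keys (shapes and values are arbitrary):

import jax.numpy as jnp
from jax import random, vmap

# vmap over per-example keys exercises the primitive's batching rule;
# this works identically for gamma and the new loggamma.
keys = random.split(random.PRNGKey(0), 4)
alphas = jnp.array([0.1, 1.0, 10.0, 100.0])

batched_gamma = vmap(random.gamma)(keys, alphas)
batched_loggamma = vmap(random.loggamma)(keys, alphas)
print(batched_gamma.shape, batched_loggamma.shape)  # (4,) and (4,)
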
@@ -995,6 +1039,10 @@ def gamma(key: KeyArray,
   Returns:
     A random array with the specified dtype and with shape given by ``shape`` if
     ``shape`` is not None, or else by ``a.shape``.
+
+  See Also:
+    loggamma : sample gamma values in log-space, which can provide improved
+      accuracy for small values of ``a``.
   """
   key, _ = _check_prng_key(key)
   if not dtypes.issubdtype(dtype, np.floating):
@@ -1003,10 +1051,52 @@ def gamma(key: KeyArray,
   dtype = dtypes.canonicalize_dtype(dtype)
   if shape is not None:
     shape = core.canonicalize_shape(shape)
-  return _gamma(key, a, shape, dtype)
+  return _gamma(key, a, shape=shape, dtype=dtype)
 
-@partial(jit, static_argnums=(2, 3), inline=True)
-def _gamma(key, a, shape, dtype):
+
+def loggamma(key: KeyArray,
+             a: RealArray,
+             shape: Optional[Sequence[int]] = None,
+             dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
+  """Sample log-gamma random values with given shape and float dtype.
+
+  This function is implemented such that the following will hold for a
+  dtype-appropriate tolerance::
+
+    np.testing.assert_allclose(jnp.exp(loggamma(*args)), gamma(*args), rtol=rtol)
+
+  The benefit of log-gamma is that for samples very close to zero (which occur frequently
+  when `a << 1`) sampling in log space provides better precision.
+
+  Args:
+    key: a PRNG key used as the random key.
+    a: a float or array of floats broadcast-compatible with ``shape``
+      representing the parameter of the distribution.
+    shape: optional, a tuple of nonnegative integers specifying the result
+      shape. Must be broadcast-compatible with ``a``. The default (None)
+      produces a result shape equal to ``a.shape``.
+    dtype: optional, a float dtype for the returned values (default float64 if
+      jax_enable_x64 is true, otherwise float32).
+
+  Returns:
+    A random array with the specified dtype and with shape given by ``shape`` if
+    ``shape`` is not None, or else by ``a.shape``.
+
+  See Also:
+    gamma : standard gamma sampler.
+  """
+  key, _ = _check_prng_key(key)
+  if not dtypes.issubdtype(dtype, np.floating):
+    raise ValueError(f"dtype argument to `gamma` must be a float "
+                     f"dtype, got {dtype}")
+  dtype = dtypes.canonicalize_dtype(dtype)
+  if shape is not None:
+    shape = core.canonicalize_shape(shape)
+  return _gamma(key, a, shape=shape, dtype=dtype, log_space=True)
+
+
+@partial(jit, static_argnames=('shape', 'dtype', 'log_space'), inline=True)
+def _gamma(key, a, shape, dtype, log_space=False):
   if shape is None:
     shape = np.shape(a)
   else:
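
A usage sketch for the new `loggamma` (not from the patch), exercising the relationship documented in its docstring; the tolerance is a loose, illustrative choice rather than the library's own:

import numpy as np
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

# The documented contract: exponentiating loggamma samples reproduces gamma
# samples drawn with the same key, up to a dtype-appropriate tolerance.
g = random.gamma(key, 0.5, shape=(1000,))
log_g = random.loggamma(key, 0.5, shape=(1000,))
np.testing.assert_allclose(jnp.exp(log_g), g, rtol=1e-4)

# For very small shape parameters, the log-space samples remain informative
# even when the linear-space samples are subnormal or flushed to zero.
print(random.loggamma(key, 1e-4, shape=(3,)))
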
@@ -1015,7 +1105,7 @@ def _gamma(key, a, shape, dtype):
   a = lax.convert_element_type(a, dtype)
   if np.shape(a) != shape:
     a = jnp.broadcast_to(a, shape)
-  return random_gamma_p.bind(key.unsafe_raw_array(), a, prng_impl=key.impl)
+  return random_gamma_p.bind(key.unsafe_raw_array(), a, prng_impl=key.impl, log_space=log_space)
 
 
 @partial(jit, static_argnums=(2, 3, 4), inline=True)
@@ -98,6 +98,7 @@ from jax._src.random import (
   gumbel as gumbel,
   laplace as laplace,
   logistic as logistic,
+  loggamma as loggamma,
   maxwell as maxwell,
   multivariate_normal as multivariate_normal,
   normal as normal,
@@ -648,6 +648,19 @@ class LaxRandomTest(jtu.JaxTestCase):
     for samples in [uncompiled_samples, compiled_samples]:
       self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.beta(a, b).cdf)
 
+  def testBetaSmallParameters(self, dtype=np.float32):
+    # Regression test for beta version of https://github.com/google/jax/issues/9896
+    key = self.seed_prng(0)
+    a, b = 0.0001, 0.0002
+    samples = random.beta(key, a, b, shape=(100,), dtype=dtype)
+
+    # With such small parameters, all samples should be exactly zero or one.
+    zeros = samples[samples < 0.5]
+    self.assertAllClose(zeros, jnp.zeros_like(zeros))
+
+    ones = samples[samples >= 0.5]
+    self.assertAllClose(ones, jnp.ones_like(ones))
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
       for dtype in float_dtypes))
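
For context on the expectation in testBetaSmallParameters: as both parameters of a Beta distribution shrink toward zero, the distribution concentrates on the endpoints, with roughly a / (a + b) of the mass near 1. A rough, illustrative check (not part of the test suite; sample count and threshold are arbitrary):

import jax.numpy as jnp
from jax import random

# For Beta(a, b) with a, b -> 0, mass concentrates on {0, 1}, and the
# fraction of draws near 1 approaches a / (a + b).
key = random.PRNGKey(0)
a, b = 1e-4, 2e-4
samples = random.beta(key, a, b, shape=(10_000,))
frac_ones = (samples >= 0.5).mean()
print(frac_ones)  # expected to be roughly a / (a + b) = 1/3
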
@@ -684,6 +697,21 @@ class LaxRandomTest(jtu.JaxTestCase):
     for i, a in enumerate(alpha):
       self._CheckKolmogorovSmirnovCDF(samples[..., i], scipy.stats.beta(a, alpha_sum - a).cdf)
 
+  def testDirichletSmallAlpha(self, dtype=np.float32):
+    # Regression test for https://github.com/google/jax/issues/9896
+    key = self.seed_prng(0)
+    alpha = 0.0001 * jnp.ones(3)
+    samples = random.dirichlet(key, alpha, shape=(100,), dtype=dtype)
+
+    # Check that results lie on the simplex.
+    self.assertAllClose(samples.sum(1), jnp.ones(samples.shape[0]),
+                        check_dtypes=False, rtol=1E-5)
+
+    # Check that results contain 1 in one of the dimensions:
+    # this is highly likely to be true when alpha is small.
+    self.assertAllClose(samples.max(1), jnp.ones(samples.shape[0]),
+                        check_dtypes=False, rtol=1E-5)
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
       for dtype in float_dtypes))
@@ -698,6 +726,22 @@ class LaxRandomTest(jtu.JaxTestCase):
     for samples in [uncompiled_samples, compiled_samples]:
       self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.expon().cdf)
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_a={}_dtype={}_prng={}".format(a, np.dtype(dtype).name,
+                                                        prng_name),
+       "a": a, "dtype": dtype, "prng_impl": prng_impl}
+      for prng_name, prng_impl in PRNG_IMPLS
+      for a in [0.1, 1., 10.]
+      for dtype in jtu.dtypes.floating))
+  def testGammaVsLogGamma(self, prng_impl, a, dtype):
+    key = prng.seed_with_impl(prng_impl, 0)
+    rand_gamma = lambda key, a: random.gamma(key, a, (10000,), dtype)
+    rand_loggamma = lambda key, a: random.loggamma(key, a, (10000,), dtype)
+    crand_loggamma = jax.jit(rand_loggamma)
+
+    self.assertAllClose(rand_gamma(key, a), jnp.exp(rand_loggamma(key, a)))
+    self.assertAllClose(rand_gamma(key, a), jnp.exp(crand_loggamma(key, a)))
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_a={}_dtype={}_prng={}".format(a, np.dtype(dtype).name,
                                                         prng_name),
@@ -722,15 +766,22 @@ class LaxRandomTest(jtu.JaxTestCase):
     assert x.shape == (3, 2)
 
   @parameterized.named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_a={}_prng={}".format(alpha, prng_name),
-       "alpha": alpha, "prng_impl": prng_impl}
+      {"testcase_name": "_a={}_prng={}_logspace={}".format(alpha, prng_name, log_space),
+       "alpha": alpha, "log_space": log_space, "prng_impl": prng_impl}
      for prng_name, prng_impl in PRNG_IMPLS
+      for log_space in [True, False]
      for alpha in [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]))
-  def testGammaGrad(self, prng_impl, alpha):
+  def testGammaGrad(self, log_space, prng_impl, alpha):
     rng = prng.seed_with_impl(prng_impl, 0)
     alphas = np.full((100,), alpha)
     z = random.gamma(rng, alphas)
-    actual_grad = jax.grad(lambda x: random.gamma(rng, x).sum())(alphas)
+    if log_space:
+      actual_grad = jax.grad(lambda x: lax.exp(random.loggamma(rng, x)).sum())(alphas)
+      # TODO(jakevdp): this NaN correction is required because we generate negative infinities
+      # in the log-space computation; see related TODO in the source of random._gamma_one().
+      actual_grad = jnp.where(jnp.isnan(actual_grad), 0.0, actual_grad)
+    else:
+      actual_grad = jax.grad(lambda x: random.gamma(rng, x).sum())(alphas)
 
     eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
     cdf_dot = (scipy.stats.gamma.cdf(z, alpha + eps)
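
The diff is truncated here, but the visible `cdf_dot` line suggests the test compares the autodiff gradient against an implicit-differentiation estimate: holding the CDF value fixed, dz/dα = -(∂F/∂α)/(∂F/∂z), where the denominator is the gamma pdf. A hedged sketch of that identity using SciPy; the central finite difference and the choice of `eps` are assumptions for illustration, not necessarily what the hidden lines do:

import numpy as np
import scipy.stats

# Implicit differentiation of the sampling path: if F(z; alpha) = U is held
# fixed, then dz/dalpha = -(dF/dalpha) / (dF/dz), and dF/dz is the gamma pdf.
# A central finite difference in alpha approximates dF/dalpha.
alpha, eps = 2.0, 1e-3
z = scipy.stats.gamma.rvs(alpha, random_state=0)
cdf_dot = (scipy.stats.gamma.cdf(z, alpha + eps)
           - scipy.stats.gamma.cdf(z, alpha - eps)) / (2 * eps)
pdf = scipy.stats.gamma.pdf(z, alpha)
expected_grad = -cdf_dot / pdf
print(expected_grad)
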
|
Loading…
x
Reference in New Issue
Block a user