Added random.generalized_normal and random.ball.

This commit is contained in:
carlosgmartin 2022-06-03 15:11:29 -04:00
parent 2765293746
commit ca83a80f95
5 changed files with 103 additions and 0 deletions

View File

@ -38,6 +38,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Changed the semantics of {func}`jax.profiler.start_server(...)` to store the
keepalive globally, rather than requiring the user to keep a reference to
it.
* Added {func}`jax.random.generalized_normal`.
* Added {func}`jax.random.ball`.
## jaxlib 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).

View File

@ -16,6 +16,7 @@ List of Available Functions
:toctree: _autosummary
PRNGKey
ball
bernoulli
beta
categorical
@ -26,6 +27,7 @@ List of Available Functions
exponential
fold_in
gamma
generalized_normal
gumbel
laplace
loggamma

View File

@ -1623,3 +1623,56 @@ def orthogonal(
q, r = jnp.linalg.qr(z)
d = jnp.diagonal(r, 0, -2, -1)
return lax.mul(q, lax.expand_dims(lax.div(d, abs(d).astype(d.dtype)), [-2]))
def generalized_normal(
key: KeyArray,
p: float,
shape: Sequence[int] = (),
dtype: DTypeLikeFloat = dtypes.float_
) -> jnp.ndarray:
"""Sample from the generalized normal distribution.
Args:
key: a PRNG key used as the random key.
p: a float representing the shape parameter.
shape: optional, the batch dimensions of the result. Default ().
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified shape and dtype.
"""
_check_shape("generalized_normal", shape)
keys = split(key)
g = gamma(keys[0], 1/p, shape, dtype)
r = rademacher(keys[1], shape, dtype)
return r * g ** (1 / p)
def ball(
key: KeyArray,
d: int,
p: float = 2,
shape: Sequence[int] = (),
dtype: DTypeLikeFloat = dtypes.float_
):
"""Sample uniformly from the unit Lp ball.
Reference: https://arxiv.org/abs/math/0503650.
Args:
key: a PRNG key used as the random key.
d: a nonnegative int representing the dimensionality of the ball.
p: a float representing the p parameter of the Lp norm.
shape: optional, the batch dimensions of the result. Default ().
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array of shape `(*shape, d)` and specified dtype.
"""
_check_shape("ball", shape)
d = core.concrete_or_error(index, d, "The error occurred in jax.random.ball()")
keys = split(key)
g = generalized_normal(keys[0], p, (*shape, d), dtype)
e = exponential(keys[1], shape, dtype)
return g / (((jnp.abs(g) ** p).sum(-1) + e) ** (1 / p))[..., None]

View File

@ -125,6 +125,7 @@ KeyArray = PRNGKeyArray
from jax._src.random import (
PRNGKey as PRNGKey,
ball as ball,
bernoulli as bernoulli,
beta as beta,
categorical as categorical,
@ -136,6 +137,7 @@ from jax._src.random import (
exponential as exponential,
fold_in as fold_in,
gamma as gamma,
generalized_normal as generalized_normal,
gumbel as gumbel,
laplace as laplace,
logistic as logistic,

View File

@ -1056,6 +1056,50 @@ class LaxRandomTest(jtu.JaxTestCase):
atol=tol, rtol=tol,
)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_p={}_shape={}"\
.format(p, jtu.format_shape_dtype_string(shape, dtype)),
"p": p,
"shape": shape,
"dtype": dtype}
for p in [.5, 1., 1.5, 2., 2.5]
for shape in [(), (5,), (10, 5)]
for dtype in jtu.dtypes.floating))
def testGeneralizedNormal(self, p, shape, dtype):
key = self.seed_prng(0)
rand = lambda key, p: random.generalized_normal(key, p, shape, dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key, p)
compiled_samples = crand(key, p)
for samples in [uncompiled_samples, compiled_samples]:
self.assertEqual(samples.shape, shape)
self.assertEqual(samples.dtype, dtype)
self._CheckKolmogorovSmirnovCDF(samples.ravel(), scipy.stats.gennorm(p).cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_d={}_p={}_shape={}"\
.format(d, p, jtu.format_shape_dtype_string(shape, dtype)),
"d": d,
"p": p,
"shape": shape,
"dtype": dtype}
for d in range(1, 5)
for p in [.5, 1., 1.5, 2., 2.5]
for shape in [(), (5,), (10, 5)]
for dtype in jtu.dtypes.floating))
def testBall(self, d, p, shape, dtype):
key = self.seed_prng(0)
rand = lambda key, p: random.ball(key, d, p, shape, dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key, p)
compiled_samples = crand(key, p)
for samples in [uncompiled_samples, compiled_samples]:
self.assertEqual(samples.shape, (*shape, d))
self.assertEqual(samples.dtype, dtype)
self.assertTrue(((jnp.abs(samples) ** p).sum(-1) <= 1).all())
norms = (jnp.abs(samples) ** p).sum(-1) ** (d / p)
self._CheckKolmogorovSmirnovCDF(norms.ravel(), scipy.stats.uniform().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_b={b}_dtype={np.dtype(dtype).name}",
"b": b, "dtype": dtype}