Added random.generalized_normal and random.ball.

parent 2765293746
commit ca83a80f95
CHANGELOG.md
@@ -38,6 +38,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
 * Changed the semantics of {func}`jax.profiler.start_server(...)` to store the
   keepalive globally, rather than requiring the user to keep a reference to
   it.
+* Added {func}`jax.random.generalized_normal`.
+* Added {func}`jax.random.ball`.
 
 ## jaxlib 0.3.11 (Unreleased)
 * [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
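For illustration, here is a minimal usage sketch of the two new samplers, assuming a jax build that includes this commit. The parameter values, shapes, and sample counts are arbitrary choices for the example, not taken from the change itself.

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
k1, k2 = random.split(key)

# Generalized normal samples with shape parameter p = 1.5.
gn = random.generalized_normal(k1, p=1.5, shape=(1000,))

# Points drawn uniformly from the unit L2 ball in R^3; output has shape (1000, 3).
pts = random.ball(k2, d=3, p=2, shape=(1000,))

print(gn.shape)   # (1000,)
print(pts.shape)  # (1000, 3)
# Each point satisfies sum_i |x_i|**p <= 1, i.e. it lies inside the ball.
print(bool(((jnp.abs(pts) ** 2).sum(-1) <= 1).all()))  # True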
docs/jax.random.rst
@@ -16,6 +16,7 @@ List of Available Functions
   :toctree: _autosummary
 
   PRNGKey
+  ball
   bernoulli
   beta
   categorical
@@ -26,6 +27,7 @@ List of Available Functions
   exponential
   fold_in
   gamma
+  generalized_normal
   gumbel
   laplace
   loggamma
jax/_src/random.py
@@ -1623,3 +1623,56 @@ def orthogonal(
   q, r = jnp.linalg.qr(z)
   d = jnp.diagonal(r, 0, -2, -1)
   return lax.mul(q, lax.expand_dims(lax.div(d, abs(d).astype(d.dtype)), [-2]))
+
+def generalized_normal(
+  key: KeyArray,
+  p: float,
+  shape: Sequence[int] = (),
+  dtype: DTypeLikeFloat = dtypes.float_
+) -> jnp.ndarray:
+  """Sample from the generalized normal distribution.
+
+  Args:
+    key: a PRNG key used as the random key.
+    p: a float representing the shape parameter.
+    shape: optional, the batch dimensions of the result. Default ().
+    dtype: optional, a float dtype for the returned values (default float64 if
+      jax_enable_x64 is true, otherwise float32).
+
+  Returns:
+    A random array with the specified shape and dtype.
+  """
+  _check_shape("generalized_normal", shape)
+  keys = split(key)
+  g = gamma(keys[0], 1/p, shape, dtype)
+  r = rademacher(keys[1], shape, dtype)
+  return r * g ** (1 / p)
+
+def ball(
+  key: KeyArray,
+  d: int,
+  p: float = 2,
+  shape: Sequence[int] = (),
+  dtype: DTypeLikeFloat = dtypes.float_
+):
+  """Sample uniformly from the unit Lp ball.
+
+  Reference: https://arxiv.org/abs/math/0503650.
+
+  Args:
+    key: a PRNG key used as the random key.
+    d: a nonnegative int representing the dimensionality of the ball.
+    p: a float representing the p parameter of the Lp norm.
+    shape: optional, the batch dimensions of the result. Default ().
+    dtype: optional, a float dtype for the returned values (default float64 if
+      jax_enable_x64 is true, otherwise float32).
+
+  Returns:
+    A random array of shape `(*shape, d)` and specified dtype.
+  """
+  _check_shape("ball", shape)
+  d = core.concrete_or_error(index, d, "The error occurred in jax.random.ball()")
+  keys = split(key)
+  g = generalized_normal(keys[0], p, (*shape, d), dtype)
+  e = exponential(keys[1], shape, dtype)
+  return g / (((jnp.abs(g) ** p).sum(-1) + e) ** (1 / p))[..., None]
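A brief note on why these constructions work; this is a summary of standard facts and of the cited reference, added here for context rather than taken from the commit.

% If G ~ Gamma(1/p, 1) and R is an independent Rademacher sign, then
% X = R * G^(1/p) has the standard generalized normal density, which is the
% distribution generalized_normal() above draws from (cf. scipy.stats.gennorm):
f_X(x) \;=\; \frac{p}{2\,\Gamma(1/p)} \, e^{-|x|^p}, \qquad x \in \mathbb{R}.

% Theorem 1 of Barthe, Guédon, Mendelson, Naor (arXiv:math/0503650): if
% g_1, ..., g_d are i.i.d. with the density above and e ~ Exp(1) is independent
% of them, then the normalized vector below is uniformly distributed on the
% unit Lp ball B_p^d = { x in R^d : sum_i |x_i|^p <= 1 }. This is exactly the
% quantity ball() returns.
\frac{(g_1, \dots, g_d)}{\bigl( \sum_{i=1}^{d} |g_i|^p + e \bigr)^{1/p}}
\;\sim\; \mathrm{Uniform}\bigl( B_p^d \bigr).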
jax/random.py
@@ -125,6 +125,7 @@ KeyArray = PRNGKeyArray
 
 from jax._src.random import (
   PRNGKey as PRNGKey,
+  ball as ball,
   bernoulli as bernoulli,
   beta as beta,
   categorical as categorical,
@@ -136,6 +137,7 @@ from jax._src.random import (
   exponential as exponential,
   fold_in as fold_in,
   gamma as gamma,
+  generalized_normal as generalized_normal,
   gumbel as gumbel,
   laplace as laplace,
   logistic as logistic,
tests/random_test.py
@@ -1056,6 +1056,50 @@ class LaxRandomTest(jtu.JaxTestCase):
         atol=tol, rtol=tol,
     )
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_p={}_shape={}"\
+        .format(p, jtu.format_shape_dtype_string(shape, dtype)),
+       "p": p,
+       "shape": shape,
+       "dtype": dtype}
+      for p in [.5, 1., 1.5, 2., 2.5]
+      for shape in [(), (5,), (10, 5)]
+      for dtype in jtu.dtypes.floating))
+  def testGeneralizedNormal(self, p, shape, dtype):
+    key = self.seed_prng(0)
+    rand = lambda key, p: random.generalized_normal(key, p, shape, dtype)
+    crand = jax.jit(rand)
+    uncompiled_samples = rand(key, p)
+    compiled_samples = crand(key, p)
+    for samples in [uncompiled_samples, compiled_samples]:
+      self.assertEqual(samples.shape, shape)
+      self.assertEqual(samples.dtype, dtype)
+      self._CheckKolmogorovSmirnovCDF(samples.ravel(), scipy.stats.gennorm(p).cdf)
+
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_d={}_p={}_shape={}"\
+        .format(d, p, jtu.format_shape_dtype_string(shape, dtype)),
+       "d": d,
+       "p": p,
+       "shape": shape,
+       "dtype": dtype}
+      for d in range(1, 5)
+      for p in [.5, 1., 1.5, 2., 2.5]
+      for shape in [(), (5,), (10, 5)]
+      for dtype in jtu.dtypes.floating))
+  def testBall(self, d, p, shape, dtype):
+    key = self.seed_prng(0)
+    rand = lambda key, p: random.ball(key, d, p, shape, dtype)
+    crand = jax.jit(rand)
+    uncompiled_samples = rand(key, p)
+    compiled_samples = crand(key, p)
+    for samples in [uncompiled_samples, compiled_samples]:
+      self.assertEqual(samples.shape, (*shape, d))
+      self.assertEqual(samples.dtype, dtype)
+      self.assertTrue(((jnp.abs(samples) ** p).sum(-1) <= 1).all())
+      norms = (jnp.abs(samples) ** p).sum(-1) ** (d / p)
+      self._CheckKolmogorovSmirnovCDF(norms.ravel(), scipy.stats.uniform().cdf)
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": f"_b={b}_dtype={np.dtype(dtype).name}",
        "b": b, "dtype": dtype}
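The distributional check in testBall uses the fact that, for X uniform on the unit Lp ball in R^d, P(||X||_p <= r) = r^d, so ||X||_p^d is uniform on [0, 1]; the expression (jnp.abs(samples) ** p).sum(-1) ** (d / p) is exactly ||samples||_p ** d. The snippet below is a standalone sketch of the same check using scipy.stats.kstest directly; it assumes a jax build that includes this commit and is illustrative only, not part of the test suite.

import jax.numpy as jnp
import scipy.stats
from jax import random

d, p = 3, 2.0
samples = random.ball(random.PRNGKey(0), d=d, p=p, shape=(5000,))

# For a correct sampler, ||x||_p ** d should look Uniform(0, 1).
norms = (jnp.abs(samples) ** p).sum(-1) ** (d / p)
statistic, pvalue = scipy.stats.kstest(norms, scipy.stats.uniform().cdf)
print(statistic, pvalue)  # a large p-value is consistent with uniformity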