From ca83a80f9596263a5639ccfc6ce82e77e72de458 Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Fri, 3 Jun 2022 15:11:29 -0400 Subject: [PATCH] Added random.generalized_normal and random.ball. --- CHANGELOG.md | 2 ++ docs/jax.random.rst | 2 ++ jax/_src/random.py | 53 ++++++++++++++++++++++++++++++++++++++++++++ jax/random.py | 2 ++ tests/random_test.py | 44 ++++++++++++++++++++++++++++++++++++ 5 files changed, 103 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 05887f5e6..2a8fc7196 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * Changed the semantics of {func}`jax.profiler.start_server(...)` to store the keepalive globally, rather than requiring the user to keep a reference to it. + * Added {func}`jax.random.generalized_normal`. + * Added {func}`jax.random.ball`. ## jaxlib 0.3.11 (Unreleased) * [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main). diff --git a/docs/jax.random.rst b/docs/jax.random.rst index dc578aee7..e1e12193f 100644 --- a/docs/jax.random.rst +++ b/docs/jax.random.rst @@ -16,6 +16,7 @@ List of Available Functions :toctree: _autosummary PRNGKey + ball bernoulli beta categorical @@ -26,6 +27,7 @@ List of Available Functions exponential fold_in gamma + generalized_normal gumbel laplace loggamma diff --git a/jax/_src/random.py b/jax/_src/random.py index 1efd55c3d..cc6322df7 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -1623,3 +1623,56 @@ def orthogonal( q, r = jnp.linalg.qr(z) d = jnp.diagonal(r, 0, -2, -1) return lax.mul(q, lax.expand_dims(lax.div(d, abs(d).astype(d.dtype)), [-2])) + +def generalized_normal( + key: KeyArray, + p: float, + shape: Sequence[int] = (), + dtype: DTypeLikeFloat = dtypes.float_ +) -> jnp.ndarray: + """Sample from the generalized normal distribution. + + Args: + key: a PRNG key used as the random key. + p: a float representing the shape parameter. + shape: optional, the batch dimensions of the result. Default (). + dtype: optional, a float dtype for the returned values (default float64 if + jax_enable_x64 is true, otherwise float32). + + Returns: + A random array with the specified shape and dtype. + """ + _check_shape("generalized_normal", shape) + keys = split(key) + g = gamma(keys[0], 1/p, shape, dtype) + r = rademacher(keys[1], shape, dtype) + return r * g ** (1 / p) + +def ball( + key: KeyArray, + d: int, + p: float = 2, + shape: Sequence[int] = (), + dtype: DTypeLikeFloat = dtypes.float_ +): + """Sample uniformly from the unit Lp ball. + + Reference: https://arxiv.org/abs/math/0503650. + + Args: + key: a PRNG key used as the random key. + d: a nonnegative int representing the dimensionality of the ball. + p: a float representing the p parameter of the Lp norm. + shape: optional, the batch dimensions of the result. Default (). + dtype: optional, a float dtype for the returned values (default float64 if + jax_enable_x64 is true, otherwise float32). + + Returns: + A random array of shape `(*shape, d)` and specified dtype. + """ + _check_shape("ball", shape) + d = core.concrete_or_error(index, d, "The error occurred in jax.random.ball()") + keys = split(key) + g = generalized_normal(keys[0], p, (*shape, d), dtype) + e = exponential(keys[1], shape, dtype) + return g / (((jnp.abs(g) ** p).sum(-1) + e) ** (1 / p))[..., None] diff --git a/jax/random.py b/jax/random.py index 3e3cd5fe8..4de9f2881 100644 --- a/jax/random.py +++ b/jax/random.py @@ -125,6 +125,7 @@ KeyArray = PRNGKeyArray from jax._src.random import ( PRNGKey as PRNGKey, + ball as ball, bernoulli as bernoulli, beta as beta, categorical as categorical, @@ -136,6 +137,7 @@ from jax._src.random import ( exponential as exponential, fold_in as fold_in, gamma as gamma, + generalized_normal as generalized_normal, gumbel as gumbel, laplace as laplace, logistic as logistic, diff --git a/tests/random_test.py b/tests/random_test.py index f4f4f2618..9281b0065 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1056,6 +1056,50 @@ class LaxRandomTest(jtu.JaxTestCase): atol=tol, rtol=tol, ) + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_p={}_shape={}"\ + .format(p, jtu.format_shape_dtype_string(shape, dtype)), + "p": p, + "shape": shape, + "dtype": dtype} + for p in [.5, 1., 1.5, 2., 2.5] + for shape in [(), (5,), (10, 5)] + for dtype in jtu.dtypes.floating)) + def testGeneralizedNormal(self, p, shape, dtype): + key = self.seed_prng(0) + rand = lambda key, p: random.generalized_normal(key, p, shape, dtype) + crand = jax.jit(rand) + uncompiled_samples = rand(key, p) + compiled_samples = crand(key, p) + for samples in [uncompiled_samples, compiled_samples]: + self.assertEqual(samples.shape, shape) + self.assertEqual(samples.dtype, dtype) + self._CheckKolmogorovSmirnovCDF(samples.ravel(), scipy.stats.gennorm(p).cdf) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_d={}_p={}_shape={}"\ + .format(d, p, jtu.format_shape_dtype_string(shape, dtype)), + "d": d, + "p": p, + "shape": shape, + "dtype": dtype} + for d in range(1, 5) + for p in [.5, 1., 1.5, 2., 2.5] + for shape in [(), (5,), (10, 5)] + for dtype in jtu.dtypes.floating)) + def testBall(self, d, p, shape, dtype): + key = self.seed_prng(0) + rand = lambda key, p: random.ball(key, d, p, shape, dtype) + crand = jax.jit(rand) + uncompiled_samples = rand(key, p) + compiled_samples = crand(key, p) + for samples in [uncompiled_samples, compiled_samples]: + self.assertEqual(samples.shape, (*shape, d)) + self.assertEqual(samples.dtype, dtype) + self.assertTrue(((jnp.abs(samples) ** p).sum(-1) <= 1).all()) + norms = (jnp.abs(samples) ** p).sum(-1) ** (d / p) + self._CheckKolmogorovSmirnovCDF(norms.ravel(), scipy.stats.uniform().cdf) + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": f"_b={b}_dtype={np.dtype(dtype).name}", "b": b, "dtype": dtype}