From 6b69a136aaabdb9b81c39873799f116ca79fcce7 Mon Sep 17 00:00:00 2001
From: carlosgmartin
Date: Wed, 12 Mar 2025 18:15:14 -0400
Subject: [PATCH] Add jax.random.multinomial.

---
 CHANGELOG.md             |  1 +
 docs/jax.random.rst      |  1 +
 jax/_src/random.py       | 70 +++++++++++++++++++++++++++++++++
 jax/random.py            |  1 +
 tests/random_lax_test.py | 87 ++++++++++++++++++++++++++++++++++++++++
 5 files changed, 160 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index faa3a3150..c166b2e80 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -42,6 +42,7 @@ Patch release of 0.5.1
     {func}`jax.lax.reduce_and`, {func}`jax.lax.reduce_or`, and
     {func}`jax.lax.reduce_xor`.
   * {func}`jax.lax.linalg.qr`, and {func}`jax.scipy.linalg.qr`, now support
     column-pivoting on CPU and GPU. See {jax-issue}`#20282` and
     {jax-issue}`#25955` for more details.
+  * Added {func}`jax.random.multinomial`.
 * Changes
diff --git a/docs/jax.random.rst b/docs/jax.random.rst
index 6c5427c05..837037a49 100644
--- a/docs/jax.random.rst
+++ b/docs/jax.random.rst
@@ -53,6 +53,7 @@ Random Samplers
   logistic
   lognormal
   maxwell
+  multinomial
   multivariate_normal
   normal
   orthogonal
diff --git a/jax/_src/random.py b/jax/_src/random.py
index 51edc1dc8..285fb643c 100644
--- a/jax/_src/random.py
+++ b/jax/_src/random.py
@@ -2661,6 +2661,76 @@ random_clone_p.def_abstract_eval(lambda x: x)
 batching.defvectorized(random_clone_p)
 mlir.register_lowering(random_clone_p, lambda _, k: [k])
 
+
+def multinomial(
+    key: Array,
+    n: RealArray,
+    p: RealArray,
+    *,
+    shape: Shape | None = None,
+    dtype: DTypeLikeFloat = float,
+    unroll: int | bool = 1,
+):
+  r"""Sample from a multinomial distribution.
+
+  The probability mass function is
+
+  .. math::
+    f(x;n,p) = \frac{n!}{x_1! \ldots x_k!} p_1^{x_1} \ldots p_k^{x_k}
+
+  Args:
+    key: PRNG key.
+    n: number of trials. Should have shape broadcastable to ``p.shape[:-1]``.
+    p: probability of each outcome, with outcomes along the last axis.
+    shape: optional, a tuple of nonnegative integers specifying the full result
+      shape, including the outcome axis as its last dimension. Must be
+      broadcast-compatible with ``p.shape``. The default (None) produces a
+      result of shape ``p.shape``.
+    dtype: optional, a float dtype for the returned values (default float64 if
+      jax_enable_x64 is true, otherwise float32).
+    unroll: optional, unroll parameter passed to :func:`jax.lax.scan` inside the
+      implementation of this function.
+
+  Returns:
+    An array of counts for each outcome with the specified dtype and with shape
+    ``p.shape`` if ``shape`` is None, otherwise ``shape``.
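+
+  Example:
+    Drawing counts for three outcomes over ``100_000`` trials:
+
+    >>> key = jax.random.key(0)
+    >>> p = jnp.array([0.5, 0.2, 0.3])
+    >>> counts = jax.random.multinomial(key, 100_000, p)
+    >>> counts.shape
+    (3,)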
+  """
+
+  key, _ = _check_prng_key("multinomial", key)
+  check_arraylike("multinomial", n, p)
+  n, p = promote_dtypes_inexact(n, p)
+
+  if shape is None:
+    shape = p.shape
+  n = jnp.broadcast_to(n, shape[:-1])
+  p = jnp.broadcast_to(p, shape)
+
+  def f(remainder, ratio_key):
+    ratio, key = ratio_key
+    count = binomial(key, remainder, ratio.clip(0, 1), dtype=remainder.dtype)
+    return remainder - count, count
+
+  p = jnp.moveaxis(p, -1, 0)
+
+  remaining_probs = lax.cumsum(p, 0, reverse=True)
+  ratios = p / jnp.where(remaining_probs == 0, 1, remaining_probs)
+
+  keys = split(key, ratios.shape[0])
+  remainder, counts = lax.scan(f, n, (ratios, keys), unroll=unroll)
+  # final remainder should be zero
+
+  return jnp.moveaxis(counts, 0, -1).astype(dtype)
+
+
 def clone(key):
   """Clone a key for reuse
 
diff --git a/jax/random.py b/jax/random.py
index 7722611fa..9db584895 100644
--- a/jax/random.py
+++ b/jax/random.py
@@ -232,6 +232,7 @@ from jax._src.random import (
   loggamma as loggamma,
   lognormal as lognormal,
   maxwell as maxwell,
+  multinomial as multinomial,
   multivariate_normal as multivariate_normal,
   normal as normal,
   orthogonal as orthogonal,
diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py
index aa2abc10e..366f5ab3c 100644
--- a/tests/random_lax_test.py
+++ b/tests/random_lax_test.py
@@ -1279,6 +1279,78 @@ class LaxRandomTest(jtu.JaxTestCase):
     p = jax.numpy.float16(0.5)
     jax.random.binomial(key, n, p)  # doesn't error
 
+  def testMultinomialExample(self):
+    key = random.key(0)
+    probs = jnp.array([
+        [0.5, 0.2, 0.3],
+        [0.1, 0.2, 0.7],
+        [1.0, 0.0, 0.0],
+        [0.0, 1.0, 0.0],
+        [0.0, 0.0, 1.0],
+        [0.5, 0.0, 0.5],
+    ])
+    trials = 1e5
+    counts = random.multinomial(key, trials, probs)
+    freqs = counts / trials
+    self.assertAllClose(freqs, probs, atol=1e-2)
+
+  @jtu.sample_product(
+      categories=[1, 2, 3, 5, 7, 11],
+      trials=[1, 2, 3, 5, 7, 11],
+      dtype=[jnp.float32],
+  )
+  def testMultinomialNumpy(
+      self,
+      categories,
+      trials,
+      dtype,
+      test_samples=10**6,
+      tolerance=1e-1,
+  ):
+    probs = jnp.linspace(-1, 2, categories)[::-1] ** 2
+    probs /= probs.sum(-1, keepdims=True)
+
+    rng = np.random.default_rng(0)
+    counts_numpy = jnp.array(rng.multinomial(trials, probs, size=test_samples), dtype)
+
+    shape = (test_samples,) + probs.shape
+    key = random.key(0)
+    counts_jax = random.multinomial(key, trials, probs, shape=shape, dtype=dtype)
+    assert counts_jax.shape == shape
+
+    energy_distance = get_energy_distance(counts_numpy, counts_jax)
+    assert energy_distance < tolerance
+
+  @jtu.sample_product([
+      dict(shape=shape, outcomes=outcomes)
+      for shape in [(5,), (2, 3), (2, 3, 5)]
+      for outcomes in [2, 3, 4]
+  ])
+  def testMultinomialShape(self, shape, outcomes):
+    key = random.key(0)
+
+    key, subkey = random.split(key)
+    probs = random.dirichlet(subkey, jnp.ones(outcomes))
+
+    trials = 1e5
+    counts = random.multinomial(key, trials, probs, shape=(*shape, *probs.shape))
+    freqs = counts / trials
+
+    self.assertAllClose(freqs, jnp.broadcast_to(probs, freqs.shape), atol=1e-2)
+
+  @jtu.sample_product([
+      dict(n_dtype=n_dtype, p_dtype=p_dtype, dtype=dtype)
+      for n_dtype in jtu.dtypes.all_floating
+      for p_dtype in jtu.dtypes.all_floating
+      for dtype in jtu.dtypes.all_floating
+  ])
+  @jax.numpy_dtype_promotion('standard')
+  def testMultinomialDtype(self, n_dtype, p_dtype, dtype):
+    key = random.key(0)
+    n = jnp.astype(10, n_dtype)
+    p = jnp.astype(jnp.ones(3) / 3, p_dtype)
+    random.multinomial(key, n, p)
+
   def test_batched_key_errors(self):
     keys = lambda: jax.random.split(self.make_key(0))
     msg = "{} accepts a single key, but was given a key array of shape.*"
@@ -1305,6 +1377,21 @@ class LaxRandomTest(jtu.JaxTestCase):
     jax.random.key_impl(keys())
 
 
+def get_energy_distance(samples_1, samples_2):
+  """Estimates the energy distance between two distributions, given
+  batches of independent samples from each.
+
+  For more information, see https://en.wikipedia.org/wiki/Energy_distance.
+  """
+  x, xp = jnp.split(samples_1, 2)
+  y, yp = jnp.split(samples_2, 2)
+  return (
+      2 * jnp.linalg.norm(x - y, axis=-1)
+      - jnp.linalg.norm(x - xp, axis=-1)
+      - jnp.linalg.norm(y - yp, axis=-1)
+  ).mean(0)
+
+
 threefry_seed = prng_internal.threefry_seed
 threefry_split = prng_internal.threefry_split
 threefry_random_bits = prng_internal.threefry_random_bits
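
Note on the sampling strategy: multinomial above reduces a multinomial draw to a
chain of binomial draws. Outcome i receives a binomial sample out of the trials
left over from outcomes 0..i-1, with conditional success probability
p[i] / (p[i] + ... + p[k-1]). The sketch below replays that decomposition with a
plain Python loop in place of lax.scan; it is illustrative only, assumes a 1-D
probability vector, and multinomial_reference is a made-up name, not part of the
patch:

    import jax
    import jax.numpy as jnp

    def multinomial_reference(key, n, p):
      # Trailing sums p[i] + ... + p[k-1], via reverse-cumsum-reverse.
      remaining = jnp.cumsum(p[::-1])[::-1]
      # Conditional probability of outcome i among outcomes i..k-1,
      # guarding against division by zero where no mass remains.
      ratios = p / jnp.where(remaining == 0, 1, remaining)
      counts = []
      remainder = jnp.asarray(n, dtype=float)  # trials not yet assigned
      for ratio, subkey in zip(ratios, jax.random.split(key, len(ratios))):
        count = jax.random.binomial(subkey, remainder, jnp.clip(ratio, 0, 1))
        counts.append(count)
        remainder = remainder - count
      # The last ratio is 1 wherever mass remains, so remainder ends at zero.
      return jnp.stack(counts)

    key = jax.random.key(0)
    counts = multinomial_reference(key, 1000.0, jnp.array([0.5, 0.2, 0.3]))
    print(counts, counts.sum())  # the counts sum to 1000.0

The scan in the patch is this loop with the outcome axis moved to the front, so
each step draws one outcome's count for the whole batch at once. On the testing
side, get_energy_distance forms the plug-in estimate of
2*E||X - Y|| - E||X - X'|| - E||Y - Y'||, a quantity that vanishes exactly when
the two distributions coincide, so small values indicate that the JAX and NumPy
samplers agree.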