Add jax.random.multinomial.

carlosgmartin 2025-03-12 18:15:14 -04:00
parent b6d4fe5387
commit 6b69a136aa
5 changed files with 151 additions and 0 deletions


@@ -42,6 +42,7 @@ Patch release of 0.5.1
{func}`jax.lax.reduce_and`, {func}`jax.lax.reduce_or`, and {func}`jax.lax.reduce_xor`.
* {func}`jax.lax.linalg.qr` and {func}`jax.scipy.linalg.qr` now support
  column-pivoting on CPU and GPU. See {jax-issue}`#20282` and
  {jax-issue}`#25955` for more details.
* Added {func}`jax.random.multinomial`.
* Changes

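A minimal usage sketch of the new sampler (illustrative only; the exact counts depend on the key):

import jax
import jax.numpy as jnp

key = jax.random.key(0)
p = jnp.array([0.2, 0.3, 0.5])  # outcome probabilities along the last axis
counts = jax.random.multinomial(key, 100, p)
# `counts` has shape (3,), a float dtype, and its entries sum to 100.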

@@ -53,6 +53,7 @@ Random Samplers
logistic
lognormal
maxwell
multinomial
multivariate_normal
normal
orthogonal


@@ -2661,6 +2661,67 @@ random_clone_p.def_abstract_eval(lambda x: x)
batching.defvectorized(random_clone_p)
mlir.register_lowering(random_clone_p, lambda _, k: [k])


def multinomial(
    key: Array,
    n: RealArray,
    p: RealArray,
    *,
    shape: Shape | None = None,
    dtype: DTypeLikeFloat = float,
    unroll: int | bool = 1,
):
r"""Sample from a multinomial distribution.
The probability mass function is
.. math::
f(x;n,p) = \frac{n!}{x_1! \ldots x_k!} p_1^{x_1} \ldots p_k^{x_k}
Args:
key: PRNG key.
n: number of trials. Should have shape broadcastable to ``p.shape[:-1]``.
p: probability of each outcome, with outcomes along the last axis.
shape: optional, a tuple of nonnegative integers specifying the result batch
shape, that is, the prefix of the result shape excluding the last axis.
Must be broadcast-compatible with ``p.shape[:-1]``. The default (None)
produces a result shape equal to ``p.shape``.
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
unroll: optional, unroll parameter passed to :func:`jax.lax.scan` inside the
implementation of this function.
Returns:
An array of counts for each outcome with the specified dtype and with shape
``p.shape`` if ``shape`` is None, otherwise ``shape + (p.shape[-1],)``.
"""
  key, _ = _check_prng_key("multinomial", key)
  check_arraylike("multinomial", n, p)
  n, p = promote_dtypes_inexact(n, p)

  if shape is None:
    shape = p.shape
  n = jnp.broadcast_to(n, shape[:-1])
  p = jnp.broadcast_to(p, shape)

  def f(remainder, ratio_key):
    ratio, key = ratio_key
    count = binomial(key, remainder, ratio.clip(0, 1), dtype=remainder.dtype)
    return remainder - count, count

  p = jnp.moveaxis(p, -1, 0)
  remaining_probs = lax.cumsum(p, 0, reverse=True)
  ratios = p / jnp.where(remaining_probs == 0, 1, remaining_probs)

  keys = split(key, ratios.shape[0])
  remainder, counts = lax.scan(f, n, (ratios, keys), unroll=unroll)
  # final remainder should be zero
  return jnp.moveaxis(counts, 0, -1).astype(dtype)


def clone(key):
  """Clone a key for reuse

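For intuition: the scan above is the standard conditional-binomial decomposition of the multinomial. Each outcome's count is a binomial draw over the trials not yet assigned to earlier outcomes, with success probability equal to that outcome's share of the remaining probability mass. A rough loop-based sketch of the same idea (``multinomial_via_binomials`` is a hypothetical helper for illustration, not part of this commit):

import jax
import jax.numpy as jnp

def multinomial_via_binomials(key, n, p):
  # Illustrative only: mirrors the scan inside jax.random.multinomial above.
  tail = jnp.cumsum(p[::-1])[::-1]   # probability mass not yet assigned
  remaining = jnp.asarray(n, float)  # trials not yet assigned
  counts = []
  for i, subkey in enumerate(jax.random.split(key, p.shape[0])):
    # This outcome's share of the remaining mass (guard against division by 0).
    ratio = p[i] / jnp.where(tail[i] == 0, 1, tail[i])
    count = jax.random.binomial(subkey, remaining, jnp.clip(ratio, 0, 1))
    counts.append(count)
    remaining = remaining - count
  return jnp.stack(counts)           # entries sum to n

The production version expresses this loop as a :func:`jax.lax.scan`, which stages to a single compiled loop and handles batched ``n`` and ``p``.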

@@ -232,6 +232,7 @@ from jax._src.random import (
loggamma as loggamma,
lognormal as lognormal,
maxwell as maxwell,
multinomial as multinomial,
multivariate_normal as multivariate_normal,
normal as normal,
orthogonal as orthogonal,


@@ -1279,6 +1279,78 @@ class LaxRandomTest(jtu.JaxTestCase):
    p = jax.numpy.float16(0.5)
    jax.random.binomial(key, n, p)  # doesn't error

  def testMultinomialExample(self):
    key = random.key(0)
    probs = jnp.array([
        [0.5, 0.2, 0.3],
        [0.1, 0.2, 0.7],
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.5, 0.0, 0.5],
    ])
    trials = 1e5
    counts = random.multinomial(key, trials, probs)
    freqs = counts / trials
    self.assertAllClose(freqs, probs, atol=1e-2)
  @jtu.sample_product(
      categories=[1, 2, 3, 5, 7, 11],
      trials=[1, 2, 3, 5, 7, 11],
      dtype=[jnp.float32],
  )
  def testMultinomialNumpy(
      self,
      categories,
      trials,
      dtype,
      test_samples=10**6,
      tolerance=1e-1,
  ):
    probs = jnp.linspace(-1, 2, categories)[::-1] ** 2
    probs /= probs.sum(-1, keepdims=True)
    rng = np.random.default_rng(0)
    counts_numpy = jnp.array(rng.multinomial(trials, probs, size=test_samples), dtype)
    shape = (test_samples,) + probs.shape
    key = random.key(0)
    counts_jax = random.multinomial(key, trials, probs, shape=shape, dtype=dtype)
    assert counts_jax.shape == shape
    energy_distance = get_energy_distance(counts_numpy, counts_jax)
    assert energy_distance < tolerance
  @jtu.sample_product([
      dict(shape=shape, outcomes=outcomes)
      for shape in [(5,), (2, 3), (2, 3, 5)]
      for outcomes in [2, 3, 4]
  ])
  def testMultinomialShape(self, shape, outcomes):
    key = random.key(0)
    key, subkey = random.split(key)
    probs = random.dirichlet(subkey, jnp.ones(outcomes))
    trials = 1e5
    counts = random.multinomial(key, trials, probs, shape=(*shape, *probs.shape))
    freqs = counts / trials
    self.assertAllClose(freqs, jnp.broadcast_to(probs, freqs.shape), atol=1e-2)
  @jtu.sample_product([
      dict(n_dtype=n_dtype, p_dtype=p_dtype, dtype=dtype)
      for n_dtype in jtu.dtypes.all_floating
      for p_dtype in jtu.dtypes.all_floating
      for dtype in jtu.dtypes.all_floating
  ])
  @jax.numpy_dtype_promotion('standard')
  def testMultinomialDtype(self, n_dtype, p_dtype, dtype):
    key = random.key(0)
    n = jnp.astype(10, n_dtype)
    p = jnp.astype(jnp.ones(3) / 3, p_dtype)
    random.multinomial(key, n, p)

  def test_batched_key_errors(self):
    keys = lambda: jax.random.split(self.make_key(0))
    msg = "{} accepts a single key, but was given a key array of shape.*"
@@ -1305,6 +1377,21 @@ class LaxRandomTest(jtu.JaxTestCase):
      jax.random.key_impl(keys())


def get_energy_distance(samples_1, samples_2):
  """Estimates the energy distance between two distributions, given batches
  of independent samples from each.

  For more information, see https://en.wikipedia.org/wiki/Energy_distance.
  """
  x, xp = jnp.split(samples_1, 2)
  y, yp = jnp.split(samples_2, 2)
  return (
      2 * jnp.linalg.norm(x - y, axis=-1)
      - jnp.linalg.norm(x - xp, axis=-1)
      - jnp.linalg.norm(y - yp, axis=-1)
  ).mean(0)

threefry_seed = prng_internal.threefry_seed
threefry_split = prng_internal.threefry_split
threefry_random_bits = prng_internal.threefry_random_bits
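A quick, hedged sanity check of the ``get_energy_distance`` helper defined above (values are illustrative and seed-dependent): samples from the same distribution should score near zero, while well-separated distributions score far above it:

import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(0)
a = jnp.asarray(rng.normal(size=(1000, 3)))
b = jnp.asarray(rng.normal(size=(1000, 3)))
c = jnp.asarray(rng.normal(loc=5.0, size=(1000, 3)))
print(get_energy_distance(a, b))  # near 0
print(get_energy_distance(a, c))  # large and positive, roughly the separation scale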