Add jax.random.multinomial.

This commit is contained in:
carlosgmartin 2025-01-31 18:45:55 -05:00
parent a9f4dd7182
commit 32411a430f
5 changed files with 111 additions and 0 deletions

View File

@ -21,6 +21,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
decorator to support customizing the behavior of opaque functions under
JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more
details.
* Added {func}`jax.random.multinomial`.
## jax 0.5.0 (Jan 17, 2025)

View File

@ -53,6 +53,7 @@ Random Samplers
logistic
lognormal
maxwell
multinomial
multivariate_normal
normal
orthogonal

View File

@ -2627,6 +2627,69 @@ random_clone_p.def_abstract_eval(lambda x: x)
batching.defvectorized(random_clone_p)
mlir.register_lowering(random_clone_p, lambda _, k: [k])
def multinomial(
key: Array,
n: RealArray,
p: RealArray,
*,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float,
):
r"""Sample from a multinomial distribution.
The probability mass function is
.. math::
f(x;n,p) = \frac{n!}{x_1! \ldots x_k!} p_1^{x_1} \ldots p_k^{x_k}
Args:
key: PRNG key.
n: number of trials. Should have shape broadcastable to ``p.shape[:-1]``.
p: probability of each outcome, with outcomes along the last axis.
shape: optional, a tuple of nonnegative integers specifying the result batch
shape, that is, the prefix of the result shape excluding the last axis.
Must be broadcast-compatible with ``p.shape[:-1]``. The default (None)
produces a result shape equal to ``p.shape``.
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
An array of counts for each outcome with the specified dtype and with shape
``p.shape`` if ``shape`` is None, otherwise ``shape + (p.shape[-1],)``.
"""
key, _ = _check_prng_key("multinomial", key)
check_arraylike("multinomial", n, p)
n, p = promote_dtypes_inexact(n, p)
def f(remainder, ratio_key):
ratio, key = ratio_key
count = binomial(key, remainder, ratio, dtype=remainder.dtype)
return remainder - count, count
p_shape = jnp.shape(p)
if shape is None:
shape = p_shape[:-1]
n = jnp.broadcast_to(n, shape)
p = jnp.broadcast_to(p, (*shape, p_shape[-1]))
p = jnp.moveaxis(p, -1, 0)
remaining_probs = lax.cumsum(p, 0, reverse=True)
ratios = p / jnp.where(remaining_probs == 0, 1, remaining_probs)
keys = split(key, ratios.shape[0])
remainder, counts = lax.scan(f, n, (ratios, keys), unroll=True)
# final remainder should be zero
counts = jnp.moveaxis(counts, 0, -1)
return counts.astype(dtype)
def clone(key):
"""Clone a key for reuse

View File

@ -232,6 +232,7 @@ from jax._src.random import (
loggamma as loggamma,
lognormal as lognormal,
maxwell as maxwell,
multinomial as multinomial,
multivariate_normal as multivariate_normal,
normal as normal,
orthogonal as orthogonal,

View File

@ -1250,6 +1250,51 @@ class LaxRandomTest(jtu.JaxTestCase):
p = jax.numpy.float16(0.5)
jax.random.binomial(key, n, p) # doesn't error
def testMultinomialExample(self):
key = random.key(0)
probs = jnp.array([
[0.5, 0.2, 0.3],
[0.1, 0.2, 0.7],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
[0.5, 0.0, 0.5],
])
trials = 1e8
counts = random.multinomial(key, trials, probs)
freqs = counts / trials
self.assertAllClose(freqs, probs, atol=1e-3)
@jtu.sample_product([
dict(shape=shape, outcomes=outcomes)
for shape in [(5,), (2, 3), (2, 3, 5)]
for outcomes in [2, 3, 4]
])
def testMultinomialShape(self, shape, outcomes):
key = random.key(0)
key, subkey = random.split(key)
probs = random.dirichlet(subkey, jnp.ones(outcomes))
trials = 1e8
counts = random.multinomial(key, trials, probs, shape=shape)
freqs = counts / trials
self.assertAllClose(freqs, jnp.broadcast_to(probs, freqs.shape), atol=1e-3)
@jtu.sample_product([
dict(n_dtype=n_dtype, p_dtype=p_dtype, dtype=dtype)
for n_dtype in jtu.dtypes.all_floating
for p_dtype in jtu.dtypes.all_floating
for dtype in jtu.dtypes.all_floating
])
@jax.numpy_dtype_promotion('standard')
def testMultinomialDtype(self, n_dtype, p_dtype, dtype):
key = random.key(0)
n = jnp.astype(10, n_dtype)
p = jnp.astype(jnp.ones(3) / 3, p_dtype)
random.multinomial(key, n, p)
def test_batched_key_errors(self):
keys = lambda: jax.random.split(self.make_key(0))
msg = "{} accepts a single key, but was given a key array of shape.*"