mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add jax.random.multinomial.
This commit is contained in:
parent
a9f4dd7182
commit
32411a430f
@ -21,6 +21,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
decorator to support customizing the behavior of opaque functions under
|
||||
JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more
|
||||
details.
|
||||
* Added {func}`jax.random.multinomial`.
|
||||
|
||||
## jax 0.5.0 (Jan 17, 2025)
|
||||
|
||||
|
@ -53,6 +53,7 @@ Random Samplers
|
||||
logistic
|
||||
lognormal
|
||||
maxwell
|
||||
multinomial
|
||||
multivariate_normal
|
||||
normal
|
||||
orthogonal
|
||||
|
@ -2627,6 +2627,69 @@ random_clone_p.def_abstract_eval(lambda x: x)
|
||||
batching.defvectorized(random_clone_p)
|
||||
mlir.register_lowering(random_clone_p, lambda _, k: [k])
|
||||
|
||||
|
||||
def multinomial(
|
||||
key: Array,
|
||||
n: RealArray,
|
||||
p: RealArray,
|
||||
*,
|
||||
shape: Shape | None = None,
|
||||
dtype: DTypeLikeFloat = float,
|
||||
):
|
||||
r"""Sample from a multinomial distribution.
|
||||
|
||||
The probability mass function is
|
||||
|
||||
.. math::
|
||||
f(x;n,p) = \frac{n!}{x_1! \ldots x_k!} p_1^{x_1} \ldots p_k^{x_k}
|
||||
|
||||
Args:
|
||||
key: PRNG key.
|
||||
n: number of trials. Should have shape broadcastable to ``p.shape[:-1]``.
|
||||
p: probability of each outcome, with outcomes along the last axis.
|
||||
shape: optional, a tuple of nonnegative integers specifying the result batch
|
||||
shape, that is, the prefix of the result shape excluding the last axis.
|
||||
Must be broadcast-compatible with ``p.shape[:-1]``. The default (None)
|
||||
produces a result shape equal to ``p.shape``.
|
||||
dtype: optional, a float dtype for the returned values (default float64 if
|
||||
jax_enable_x64 is true, otherwise float32).
|
||||
|
||||
Returns:
|
||||
An array of counts for each outcome with the specified dtype and with shape
|
||||
``p.shape`` if ``shape`` is None, otherwise ``shape + (p.shape[-1],)``.
|
||||
"""
|
||||
|
||||
key, _ = _check_prng_key("multinomial", key)
|
||||
check_arraylike("multinomial", n, p)
|
||||
n, p = promote_dtypes_inexact(n, p)
|
||||
|
||||
def f(remainder, ratio_key):
|
||||
ratio, key = ratio_key
|
||||
count = binomial(key, remainder, ratio, dtype=remainder.dtype)
|
||||
return remainder - count, count
|
||||
|
||||
p_shape = jnp.shape(p)
|
||||
|
||||
if shape is None:
|
||||
shape = p_shape[:-1]
|
||||
|
||||
n = jnp.broadcast_to(n, shape)
|
||||
p = jnp.broadcast_to(p, (*shape, p_shape[-1]))
|
||||
|
||||
p = jnp.moveaxis(p, -1, 0)
|
||||
|
||||
remaining_probs = lax.cumsum(p, 0, reverse=True)
|
||||
ratios = p / jnp.where(remaining_probs == 0, 1, remaining_probs)
|
||||
|
||||
keys = split(key, ratios.shape[0])
|
||||
remainder, counts = lax.scan(f, n, (ratios, keys), unroll=True)
|
||||
# final remainder should be zero
|
||||
|
||||
counts = jnp.moveaxis(counts, 0, -1)
|
||||
|
||||
return counts.astype(dtype)
|
||||
|
||||
|
||||
def clone(key):
|
||||
"""Clone a key for reuse
|
||||
|
||||
|
@ -232,6 +232,7 @@ from jax._src.random import (
|
||||
loggamma as loggamma,
|
||||
lognormal as lognormal,
|
||||
maxwell as maxwell,
|
||||
multinomial as multinomial,
|
||||
multivariate_normal as multivariate_normal,
|
||||
normal as normal,
|
||||
orthogonal as orthogonal,
|
||||
|
@ -1250,6 +1250,51 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
p = jax.numpy.float16(0.5)
|
||||
jax.random.binomial(key, n, p) # doesn't error
|
||||
|
||||
def testMultinomialExample(self):
|
||||
key = random.key(0)
|
||||
probs = jnp.array([
|
||||
[0.5, 0.2, 0.3],
|
||||
[0.1, 0.2, 0.7],
|
||||
[1.0, 0.0, 0.0],
|
||||
[0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 1.0],
|
||||
[0.5, 0.0, 0.5],
|
||||
])
|
||||
trials = 1e8
|
||||
counts = random.multinomial(key, trials, probs)
|
||||
freqs = counts / trials
|
||||
self.assertAllClose(freqs, probs, atol=1e-3)
|
||||
|
||||
@jtu.sample_product([
|
||||
dict(shape=shape, outcomes=outcomes)
|
||||
for shape in [(5,), (2, 3), (2, 3, 5)]
|
||||
for outcomes in [2, 3, 4]
|
||||
])
|
||||
def testMultinomialShape(self, shape, outcomes):
|
||||
key = random.key(0)
|
||||
|
||||
key, subkey = random.split(key)
|
||||
probs = random.dirichlet(subkey, jnp.ones(outcomes))
|
||||
|
||||
trials = 1e8
|
||||
counts = random.multinomial(key, trials, probs, shape=shape)
|
||||
freqs = counts / trials
|
||||
|
||||
self.assertAllClose(freqs, jnp.broadcast_to(probs, freqs.shape), atol=1e-3)
|
||||
|
||||
@jtu.sample_product([
|
||||
dict(n_dtype=n_dtype, p_dtype=p_dtype, dtype=dtype)
|
||||
for n_dtype in jtu.dtypes.all_floating
|
||||
for p_dtype in jtu.dtypes.all_floating
|
||||
for dtype in jtu.dtypes.all_floating
|
||||
])
|
||||
@jax.numpy_dtype_promotion('standard')
|
||||
def testMultinomialDtype(self, n_dtype, p_dtype, dtype):
|
||||
key = random.key(0)
|
||||
n = jnp.astype(10, n_dtype)
|
||||
p = jnp.astype(jnp.ones(3) / 3, p_dtype)
|
||||
random.multinomial(key, n, p)
|
||||
|
||||
def test_batched_key_errors(self):
|
||||
keys = lambda: jax.random.split(self.make_key(0))
|
||||
msg = "{} accepts a single key, but was given a key array of shape.*"
|
||||
|
Loading…
x
Reference in New Issue
Block a user