mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add jax.random.multinomial.
This commit is contained in:
parent
b6d4fe5387
commit
6b69a136aa
@ -42,6 +42,7 @@ Patch release of 0.5.1
|
||||
{func}`jax.lax.reduce_and`, {func}`jax.lax.reduce_or`, and {func}`jax.lax.reduce_xor`.
|
||||
* {func}`jax.lax.linalg.qr`, and {func}`jax.scipy.linalg.qr`, now support
|
||||
column-pivoting on CPU and GPU. See {jax-issue}`#20282` and
|
||||
{jax-issue}`#25955` for more details.
|
||||
* Added {func}`jax.random.multinomial`.
|
||||
|
||||
* Changes
|
||||
|
@ -53,6 +53,7 @@ Random Samplers
|
||||
logistic
|
||||
lognormal
|
||||
maxwell
|
||||
multinomial
|
||||
multivariate_normal
|
||||
normal
|
||||
orthogonal
|
||||
|
@ -2661,6 +2661,67 @@ random_clone_p.def_abstract_eval(lambda x: x)
|
||||
batching.defvectorized(random_clone_p)
|
||||
mlir.register_lowering(random_clone_p, lambda _, k: [k])
|
||||
|
||||
|
||||
def multinomial(
    key: Array,
    n: RealArray,
    p: RealArray,
    *,
    shape: Shape | None = None,
    dtype: DTypeLikeFloat = float,
    unroll: int | bool = 1,
):
  r"""Sample from a multinomial distribution.

  The probability mass function is

  .. math::
      f(x;n,p) = \frac{n!}{x_1! \ldots x_k!} p_1^{x_1} \ldots p_k^{x_k}

  Args:
    key: PRNG key.
    n: number of trials. Must be broadcastable to the result shape without its
      last axis, i.e. ``p.shape[:-1]`` by default or ``shape[:-1]`` when
      ``shape`` is given.
    p: probability of each outcome, with outcomes along the last axis.
    shape: optional, a tuple of nonnegative integers specifying the full result
      shape, including the trailing outcomes axis. ``p`` must be broadcastable
      to ``shape`` and ``n`` to ``shape[:-1]``. The default (None) produces a
      result shape equal to ``p.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    unroll: optional, unroll parameter passed to :func:`jax.lax.scan` inside the
      implementation of this function.

  Returns:
    An array of counts for each outcome with the specified dtype and with shape
    ``p.shape`` if ``shape`` is None, otherwise ``shape``.
  """
  key, _ = _check_prng_key("multinomial", key)
  check_arraylike("multinomial", n, p)
  n, p = promote_dtypes_inexact(n, p)

  if shape is None:
    shape = p.shape
  n = jnp.broadcast_to(n, shape[:-1])
  p = jnp.broadcast_to(p, shape)

  def sample_outcome(remainder, ratio_and_key):
    # One scan step: draw the count for the current outcome from the trials
    # that are still unassigned, then carry the leftover trials forward.
    ratio, subkey = ratio_and_key
    count = binomial(subkey, remainder, ratio.clip(0, 1), dtype=remainder.dtype)
    return remainder - count, count

  # Scan iterates over the leading axis, so put the outcomes axis first.
  p = jnp.moveaxis(p, -1, 0)

  # ratios[i] = p[i] / sum(p[i:]): the probability of outcome i given that none
  # of the earlier outcomes occurred. A zero tail sum is replaced by 1 so the
  # division is well defined (the numerator is then zero as well for
  # non-negative p).
  remaining_probs = lax.cumsum(p, 0, reverse=True)
  ratios = p / jnp.where(remaining_probs == 0, 1, remaining_probs)

  keys = split(key, ratios.shape[0])
  # The final carry (leftover trials) is discarded; it should be zero.
  _, counts = lax.scan(sample_outcome, n, (ratios, keys), unroll=unroll)

  return jnp.moveaxis(counts, 0, -1).astype(dtype)
|
||||
|
||||
|
||||
def clone(key):
|
||||
"""Clone a key for reuse
|
||||
|
||||
|
@ -232,6 +232,7 @@ from jax._src.random import (
|
||||
loggamma as loggamma,
|
||||
lognormal as lognormal,
|
||||
maxwell as maxwell,
|
||||
multinomial as multinomial,
|
||||
multivariate_normal as multivariate_normal,
|
||||
normal as normal,
|
||||
orthogonal as orthogonal,
|
||||
|
@ -1279,6 +1279,78 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
p = jax.numpy.float16(0.5)
|
||||
jax.random.binomial(key, n, p) # doesn't error
|
||||
|
||||
def testMultinomialExample(self):
  # With many trials, empirical outcome frequencies should approach the
  # requested probabilities, including rows with zero-probability outcomes.
  num_trials = 1e5
  expected_freqs = jnp.array([
      [0.5, 0.2, 0.3],
      [0.1, 0.2, 0.7],
      [1.0, 0.0, 0.0],
      [0.0, 1.0, 0.0],
      [0.0, 0.0, 1.0],
      [0.5, 0.0, 0.5],
  ])
  samples = random.multinomial(random.key(0), num_trials, expected_freqs)
  self.assertAllClose(samples / num_trials, expected_freqs, atol=1e-2)
|
||||
|
||||
@jtu.sample_product(
    categories=[1, 2, 3, 5, 7, 11],
    trials=[1, 2, 3, 5, 7, 11],
    dtype=[jnp.float32],
)
def testMultinomialNumpy(
    self,
    categories,
    trials,
    dtype,
    test_samples=10**6,
    tolerance=1e-1,
):
  # Compare the distribution of JAX multinomial samples against NumPy's
  # sampler using an energy-distance estimate between the two sample sets.
  probs = jnp.linspace(-1, 2, categories)[::-1] ** 2
  probs /= probs.sum(-1, keepdims=True)

  rng = np.random.default_rng(0)
  counts_numpy = jnp.array(rng.multinomial(trials, probs, size=test_samples), dtype)

  shape = (test_samples,) + probs.shape
  key = random.key(0)
  counts_jax = random.multinomial(key, trials, probs, shape=shape, dtype=dtype)
  # Use unittest assertions rather than bare `assert`: they are not stripped
  # under `python -O` and report both operands on failure.
  self.assertEqual(counts_jax.shape, shape)

  energy_distance = get_energy_distance(counts_numpy, counts_jax)
  self.assertLess(energy_distance, tolerance)
|
||||
|
||||
@jtu.sample_product([
    dict(shape=shape, outcomes=outcomes)
    for shape in [(5,), (2, 3), (2, 3, 5)]
    for outcomes in [2, 3, 4]
])
def testMultinomialShape(self, shape, outcomes):
  # Draw one probability vector, request samples with extra leading batch
  # dimensions, and check every batch entry matches the probabilities.
  sample_key, probs_key = random.split(random.key(0))
  probs = random.dirichlet(probs_key, jnp.ones(outcomes))

  num_trials = 1e5
  result_shape = (*shape, *probs.shape)
  counts = random.multinomial(sample_key, num_trials, probs, shape=result_shape)

  empirical = counts / num_trials
  expected = jnp.broadcast_to(probs, empirical.shape)
  self.assertAllClose(empirical, expected, atol=1e-2)
|
||||
|
||||
@jtu.sample_product([
    dict(n_dtype=n_dtype, p_dtype=p_dtype, dtype=dtype)
    for n_dtype in jtu.dtypes.all_floating
    for p_dtype in jtu.dtypes.all_floating
    for dtype in jtu.dtypes.all_floating
])
@jax.numpy_dtype_promotion('standard')
def testMultinomialDtype(self, n_dtype, p_dtype, dtype):
  # Mixed input dtypes must be accepted, and the requested output dtype must
  # be honoured. The `dtype` parameter was previously generated but never
  # passed to `multinomial`, so the output dtype was not exercised at all.
  key = random.key(0)
  n = jnp.astype(10, n_dtype)
  p = jnp.astype(jnp.ones(3) / 3, p_dtype)
  counts = random.multinomial(key, n, p, dtype=dtype)
  self.assertEqual(counts.dtype, jnp.dtype(dtype))
  self.assertEqual(counts.shape, p.shape)
|
||||
|
||||
def test_batched_key_errors(self):
|
||||
keys = lambda: jax.random.split(self.make_key(0))
|
||||
msg = "{} accepts a single key, but was given a key array of shape.*"
|
||||
@ -1305,6 +1377,21 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
jax.random.key_impl(keys())
|
||||
|
||||
|
||||
def get_energy_distance(samples_1, samples_2):
  """Estimate the energy distance between two sampled distributions.

  Each batch of independent samples is split in half along the leading axis.
  The halves are paired up to estimate the cross-distribution and
  within-distribution expected Euclidean distances (norm over the last axis),
  and ``2 * cross - within_1 - within_2`` is averaged over the leading axis.

  For more information, see https://en.wikipedia.org/wiki/Energy_distance.
  """
  first_a, first_b = jnp.split(samples_1, 2)
  second_a, second_b = jnp.split(samples_2, 2)

  def pair_distance(u, v):
    return jnp.linalg.norm(u - v, axis=-1)

  cross = pair_distance(first_a, second_a)
  within_first = pair_distance(first_a, first_b)
  within_second = pair_distance(second_a, second_b)

  per_pair = 2 * cross - within_first - within_second
  return per_pair.mean(0)
|
||||
|
||||
|
||||
threefry_seed = prng_internal.threefry_seed
|
||||
threefry_split = prng_internal.threefry_split
|
||||
threefry_random_bits = prng_internal.threefry_random_bits
|
||||
|
Loading…
x
Reference in New Issue
Block a user