Add replace option to random.categorical to enable sampling without replacement.

This commit is contained in:
carlosgmartin 2025-03-13 19:00:19 -04:00
parent de9ad6bad9
commit 3f59fa6888
3 changed files with 80 additions and 15 deletions

View File

@ -22,6 +22,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
{func}`jax.lax.dynamic_update_slice` and related functions. The default is
true, matching the current behavior. If set to false, JAX does not need to
emit code clamping negative indices, which improves code size.
* Added a `replace` option to {func}`jax.random.categorical` to enable sampling
without replacement.
## jax 0.5.2 (Mar 4, 2025)

View File

@ -1548,12 +1548,18 @@ def _gumbel(key, shape, dtype, mode) -> Array:
_uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.)))
def categorical(key: ArrayLike,
logits: RealArray,
axis: int = -1,
shape: Shape | None = None) -> Array:
def categorical(
key: ArrayLike,
logits: RealArray,
axis: int = -1,
shape: Shape | None = None,
replace: bool = True,
) -> Array:
"""Sample random values from categorical distributions.
Sampling with replacement uses the Gumbel max trick. Sampling without replacement uses
the Gumbel top-k trick. See [1] for reference.
Args:
key: a PRNG key used as the random key.
logits: Unnormalized log probabilities of the categorical distribution(s) to sample from,
@ -1562,32 +1568,57 @@ def categorical(key: ArrayLike,
shape: Optional, a tuple of nonnegative integers representing the result shape.
Must be broadcast-compatible with ``np.delete(logits.shape, axis)``.
The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``.
replace: If True (default), perform sampling with replacement. If False, perform
sampling without replacement.
Returns:
A random array with int dtype and shape given by ``shape`` if ``shape``
is not None, or else ``np.delete(logits.shape, axis)``.
References:
.. [1] Wouter Kool, Herke van Hoof, Max Welling. "Stochastic Beams and Where to Find
Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement".
Proceedings of the 36th International Conference on Machine Learning, PMLR
97:3499-3508, 2019. https://proceedings.mlr.press/v97/kool19a.html.
"""
key, _ = _check_prng_key("categorical", key)
check_arraylike("categorical", logits)
logits_arr = jnp.asarray(logits)
if axis >= 0:
axis -= len(logits_arr.shape)
batch_shape = tuple(np.delete(logits_arr.shape, axis))
if shape is None:
shape = batch_shape
else:
shape = core.canonicalize_shape(shape)
_check_shape("categorical", shape, batch_shape)
shape_prefix = shape[:len(shape)-len(batch_shape)]
logits_shape = list(shape[len(shape) - len(batch_shape):])
logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
return jnp.argmax(
gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
axis=axis)
if replace:
if axis >= 0:
axis -= len(logits_arr.shape)
logits_shape = list(shape[len(shape) - len(batch_shape):])
logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
return jnp.argmax(
gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
axis=axis)
else:
logits_arr += gumbel(key, logits_arr.shape, logits_arr.dtype)
k = math.prod(shape_prefix)
if k > logits_arr.shape[axis]:
raise ValueError(
f"Number of samples without replacement ({k}) cannot exceed number of "
f"categories ({logits_arr.shape[axis]})."
)
_, indices = lax.top_k(jnp.moveaxis(logits_arr, axis, -1), k)
assert indices.shape == batch_shape + (k,)
assert shape == shape_prefix + batch_shape
dimensions = (indices.ndim - 1, *range(indices.ndim - 1))
indices = lax.reshape(indices, shape, dimensions)
assert indices.shape == shape
return indices
def laplace(key: ArrayLike,

View File

@ -365,6 +365,38 @@ class LaxRandomTest(jtu.JaxTestCase):
pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0)
self._CheckChiSquared(samples, pmf=pmf)
@jtu.sample_product(
  logits_shape=[(7,), (8, 9), (10, 11, 12)],
  prefix_shape=[(2,), (3, 4), (5, 6)],
)
def testCategoricalWithoutReplacement(self, logits_shape, prefix_shape):
  """Exercise ``random.categorical(..., replace=False)``.

  Random logits of shape ``logits_shape`` are sampled along a randomly chosen
  axis, requesting ``prefix_shape`` samples per distribution.

  * When more samples are requested than there are categories along the
    chosen axis, the call must raise ``ValueError``.
  * Otherwise the result must have the requested shape, contain only valid
    category indices, and never repeat a category within one distribution.

  Args:
    logits_shape: shape of the logits array passed to ``categorical``.
    prefix_shape: leading sample dimensions prepended to the batch shape.
  """
  key = random.key(0)

  key, subkey = random.split(key)
  logits = random.normal(subkey, logits_shape)

  # The category axis is itself drawn at random, covering both negative and
  # non-negative axis values.
  key, subkey = random.split(key)
  axis = random.randint(subkey, (), -len(logits_shape), len(logits_shape))

  dists_shape = tuple(np.delete(logits_shape, axis))
  n_categories = logits_shape[axis]
  shape = prefix_shape + dists_shape
  prefix_size = math.prod(prefix_shape)

  if n_categories < prefix_size:
    # Requesting more distinct samples than categories must be rejected.
    with self.assertRaisesRegex(ValueError, "Number of samples without replacement"):
      random.categorical(key, logits, axis=axis, shape=shape, replace=False)
    return

  output = random.categorical(key, logits, axis=axis, shape=shape, replace=False)
  self.assertEqual(output.shape, shape)

  # Use TestCase assertions rather than bare `assert`, which is stripped when
  # Python runs with -O and would silently disable these checks.
  self.assertTrue((0 <= output).all())
  self.assertTrue((output < n_categories).all())

  # One column per distribution; count how often each category was drawn in
  # that column. Without replacement, no count may exceed one.
  flat = output.reshape((prefix_size, math.prod(dists_shape)))
  counts = jax.vmap(partial(jnp.bincount, length=n_categories), 1)(flat)
  self.assertTrue((counts <= 1).all())
def testBernoulliShape(self):
key = self.make_key(0)
with jax.numpy_rank_promotion('allow'):