Add replace option to random.categorical to enable sampling without replacement.

This commit is contained in:
carlosgmartin 2025-03-13 19:00:19 -04:00
parent de9ad6bad9
commit 3f59fa6888
3 changed files with 80 additions and 15 deletions

View File

@ -22,6 +22,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
{func}`jax.lax.dynamic_update_slice` and related functions. The default is {func}`jax.lax.dynamic_update_slice` and related functions. The default is
true, matching the current behavior. If set to false, JAX does not need to true, matching the current behavior. If set to false, JAX does not need to
emit code clamping negative indices, which improves code size. emit code clamping negative indices, which improves code size.
* Added a `replace` option to {func}`jax.random.categorical` to enable sampling
without replacement.
## jax 0.5.2 (Mar 4, 2025) ## jax 0.5.2 (Mar 4, 2025)

View File

@ -1548,12 +1548,18 @@ def _gumbel(key, shape, dtype, mode) -> Array:
_uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.))) _uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.)))
def categorical(key: ArrayLike, def categorical(
logits: RealArray, key: ArrayLike,
axis: int = -1, logits: RealArray,
shape: Shape | None = None) -> Array: axis: int = -1,
shape: Shape | None = None,
replace: bool = True,
) -> Array:
"""Sample random values from categorical distributions. """Sample random values from categorical distributions.
Sampling with replacement uses the Gumbel max trick. Sampling without replacement uses
the Gumbel top-k trick. See [1] for reference.
Args: Args:
key: a PRNG key used as the random key. key: a PRNG key used as the random key.
logits: Unnormalized log probabilities of the categorical distribution(s) to sample from, logits: Unnormalized log probabilities of the categorical distribution(s) to sample from,
@ -1562,32 +1568,57 @@ def categorical(key: ArrayLike,
shape: Optional, a tuple of nonnegative integers representing the result shape. shape: Optional, a tuple of nonnegative integers representing the result shape.
Must be broadcast-compatible with ``np.delete(logits.shape, axis)``. Must be broadcast-compatible with ``np.delete(logits.shape, axis)``.
The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``. The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``.
replace: If True, perform sampling without replacement. Default (False) is to
perform sampling with replacement.
Returns: Returns:
A random array with int dtype and shape given by ``shape`` if ``shape`` A random array with int dtype and shape given by ``shape`` if ``shape``
is not None, or else ``np.delete(logits.shape, axis)``. is not None, or else ``np.delete(logits.shape, axis)``.
References:
.. [1] Wouter Kool, Herke van Hoof, Max Welling. "Stochastic Beams and Where to Find
Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement".
Proceedings of the 36th International Conference on Machine Learning, PMLR
97:3499-3508, 2019. https://proceedings.mlr.press/v97/kool19a.html.
""" """
key, _ = _check_prng_key("categorical", key) key, _ = _check_prng_key("categorical", key)
check_arraylike("categorical", logits) check_arraylike("categorical", logits)
logits_arr = jnp.asarray(logits) logits_arr = jnp.asarray(logits)
if axis >= 0:
axis -= len(logits_arr.shape)
batch_shape = tuple(np.delete(logits_arr.shape, axis)) batch_shape = tuple(np.delete(logits_arr.shape, axis))
if shape is None: if shape is None:
shape = batch_shape shape = batch_shape
else: else:
shape = core.canonicalize_shape(shape) shape = core.canonicalize_shape(shape)
_check_shape("categorical", shape, batch_shape) _check_shape("categorical", shape, batch_shape)
shape_prefix = shape[:len(shape)-len(batch_shape)] shape_prefix = shape[:len(shape)-len(batch_shape)]
logits_shape = list(shape[len(shape) - len(batch_shape):])
logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis]) if replace:
return jnp.argmax( if axis >= 0:
gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) + axis -= len(logits_arr.shape)
lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
axis=axis) logits_shape = list(shape[len(shape) - len(batch_shape):])
logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
return jnp.argmax(
gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
axis=axis)
else:
logits_arr += gumbel(key, logits_arr.shape, logits_arr.dtype)
k = math.prod(shape_prefix)
if k > logits_arr.shape[axis]:
raise ValueError(
f"Number of samples without replacement ({k}) cannot exceed number of "
f"categories ({logits_arr.shape[axis]})."
)
_, indices = lax.top_k(jnp.moveaxis(logits_arr, axis, -1), k)
assert indices.shape == batch_shape + (k,)
assert shape == shape_prefix + batch_shape
dimensions = (indices.ndim - 1, *range(indices.ndim - 1))
indices = lax.reshape(indices, shape, dimensions)
assert indices.shape == shape
return indices
def laplace(key: ArrayLike, def laplace(key: ArrayLike,

View File

@ -365,6 +365,38 @@ class LaxRandomTest(jtu.JaxTestCase):
pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0) pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0)
self._CheckChiSquared(samples, pmf=pmf) self._CheckChiSquared(samples, pmf=pmf)
@jtu.sample_product(
logits_shape=[(7,), (8, 9), (10, 11, 12)],
prefix_shape=[(2,), (3, 4), (5, 6)],
)
def testCategoricalWithoutReplacement(self, logits_shape, prefix_shape):
key = random.key(0)
key, subkey = random.split(key)
logits = random.normal(subkey, logits_shape)
key, subkey = random.split(key)
axis = random.randint(subkey, (), -len(logits_shape), len(logits_shape))
dists_shape = tuple(np.delete(logits_shape, axis))
n_categories = logits_shape[axis]
shape = prefix_shape + dists_shape
prefix_size = math.prod(prefix_shape)
if n_categories < prefix_size:
with self.assertRaisesRegex(ValueError, "Number of samples without replacement"):
random.categorical(key, logits, axis=axis, shape=shape, replace=False)
else:
output = random.categorical(key, logits, axis=axis, shape=shape, replace=False)
self.assertEqual(output.shape, shape)
assert (0 <= output).all()
assert (output < n_categories).all()
flat = output.reshape((prefix_size, math.prod(dists_shape)))
counts = jax.vmap(partial(jnp.bincount, length=n_categories), 1)(flat)
assert (counts <= 1).all()
def testBernoulliShape(self): def testBernoulliShape(self):
key = self.make_key(0) key = self.make_key(0)
with jax.numpy_rank_promotion('allow'): with jax.numpy_rank_promotion('allow'):