diff --git a/CHANGELOG.md b/CHANGELOG.md index 140f66c30..c30877eca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. {func}`jax.lax.dynamic_update_slice` and related functions. The default is true, matching the current behavior. If set to false, JAX does not need to emit code clamping negative indices, which improves code size. + * Added a `replace` option to {func}`jax.random.categorical` to enable sampling + without replacement. ## jax 0.5.2 (Mar 4, 2025) diff --git a/jax/_src/random.py b/jax/_src/random.py index 4c1436e3f..094268c65 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -1548,12 +1548,18 @@ def _gumbel(key, shape, dtype, mode) -> Array: _uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.))) -def categorical(key: ArrayLike, - logits: RealArray, - axis: int = -1, - shape: Shape | None = None) -> Array: +def categorical( + key: ArrayLike, + logits: RealArray, + axis: int = -1, + shape: Shape | None = None, + replace: bool = True, +) -> Array: """Sample random values from categorical distributions. + Sampling with replacement uses the Gumbel max trick. Sampling without replacement uses + the Gumbel top-k trick. See [1] for reference. + Args: key: a PRNG key used as the random key. logits: Unnormalized log probabilities of the categorical distribution(s) to sample from, @@ -1562,32 +1568,57 @@ def categorical(key: ArrayLike, shape: Optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with ``np.delete(logits.shape, axis)``. The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``. + replace: If True (default), perform sampling with replacement. If False, + perform sampling without replacement. Returns: A random array with int dtype and shape given by ``shape`` if ``shape`` is not None, or else ``np.delete(logits.shape, axis)``. + + References: + .. 
[1] Wouter Kool, Herke van Hoof, Max Welling. "Stochastic Beams and Where to Find + Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement". + Proceedings of the 36th International Conference on Machine Learning, PMLR + 97:3499-3508, 2019. https://proceedings.mlr.press/v97/kool19a.html. """ key, _ = _check_prng_key("categorical", key) check_arraylike("categorical", logits) logits_arr = jnp.asarray(logits) - - if axis >= 0: - axis -= len(logits_arr.shape) - batch_shape = tuple(np.delete(logits_arr.shape, axis)) if shape is None: shape = batch_shape else: shape = core.canonicalize_shape(shape) _check_shape("categorical", shape, batch_shape) - shape_prefix = shape[:len(shape)-len(batch_shape)] - logits_shape = list(shape[len(shape) - len(batch_shape):]) - logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis]) - return jnp.argmax( - gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) + - lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))), - axis=axis) + + if replace: + if axis >= 0: + axis -= len(logits_arr.shape) + + logits_shape = list(shape[len(shape) - len(batch_shape):]) + logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis]) + return jnp.argmax( + gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) + + lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))), + axis=axis) + else: + logits_arr += gumbel(key, logits_arr.shape, logits_arr.dtype) + k = math.prod(shape_prefix) + if k > logits_arr.shape[axis]: + raise ValueError( + f"Number of samples without replacement ({k}) cannot exceed number of " + f"categories ({logits_arr.shape[axis]})." 
+ ) + + _, indices = lax.top_k(jnp.moveaxis(logits_arr, axis, -1), k) + assert indices.shape == batch_shape + (k,) + assert shape == shape_prefix + batch_shape + + dimensions = (indices.ndim - 1, *range(indices.ndim - 1)) + indices = lax.reshape(indices, shape, dimensions) + assert indices.shape == shape + return indices def laplace(key: ArrayLike, diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index 366f5ab3c..b6f8b4f13 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -365,6 +365,38 @@ class LaxRandomTest(jtu.JaxTestCase): pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0) self._CheckChiSquared(samples, pmf=pmf) + @jtu.sample_product( + logits_shape=[(7,), (8, 9), (10, 11, 12)], + prefix_shape=[(2,), (3, 4), (5, 6)], + ) + def testCategoricalWithoutReplacement(self, logits_shape, prefix_shape): + key = random.key(0) + + key, subkey = random.split(key) + logits = random.normal(subkey, logits_shape) + + key, subkey = random.split(key) + axis = random.randint(subkey, (), -len(logits_shape), len(logits_shape)) + + dists_shape = tuple(np.delete(logits_shape, axis)) + n_categories = logits_shape[axis] + shape = prefix_shape + dists_shape + prefix_size = math.prod(prefix_shape) + + if n_categories < prefix_size: + with self.assertRaisesRegex(ValueError, "Number of samples without replacement"): + random.categorical(key, logits, axis=axis, shape=shape, replace=False) + + else: + output = random.categorical(key, logits, axis=axis, shape=shape, replace=False) + self.assertEqual(output.shape, shape) + assert (0 <= output).all() + assert (output < n_categories).all() + flat = output.reshape((prefix_size, math.prod(dists_shape))) + counts = jax.vmap(partial(jnp.bincount, length=n_categories), 1)(flat) + assert (counts <= 1).all() + + def testBernoulliShape(self): key = self.make_key(0) with jax.numpy_rank_promotion('allow'):