diff --git a/jax/_src/random.py b/jax/_src/random.py index c91d2f786..4c1436e3f 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -670,8 +670,8 @@ def choice(key: ArrayLike, ind = jnp.searchsorted(p_cuml, r).astype(int) else: # Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/ - g = -gumbel(key, (n_inputs,), dtype=p_arr.dtype) - jnp.log(p_arr) - ind = jnp.argsort(g)[:n_draws] + g = gumbel(key, (n_inputs,), dtype=p_arr.dtype) + jnp.log(p_arr) + ind = lax.top_k(g, k=n_draws)[1].astype(int) result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis) return result.reshape(shape if arr.ndim == 0 else