Merge pull request #27157 from mar-muel:improve-random-choice-performance

PiperOrigin-RevId: 737665351
This commit is contained in:
jax authors 2025-03-17 10:30:15 -07:00
commit de9ad6bad9

View File

@ -670,8 +670,8 @@ def choice(key: ArrayLike,
ind = jnp.searchsorted(p_cuml, r).astype(int)
else:
# Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
g = -gumbel(key, (n_inputs,), dtype=p_arr.dtype) - jnp.log(p_arr)
ind = jnp.argsort(g)[:n_draws]
g = gumbel(key, (n_inputs,), dtype=p_arr.dtype) + jnp.log(p_arr)
ind = lax.top_k(g, k=n_draws)[1].astype(int)
result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis)
return result.reshape(shape if arr.ndim == 0 else