1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-26 06:56:08 +00:00

Merge pull request from mar-muel:improve-random-choice-performance

PiperOrigin-RevId: 737665351
This commit is contained in:
jax authors 2025-03-17 10:30:15 -07:00
commit de9ad6bad9

@ -670,8 +670,8 @@ def choice(key: ArrayLike,
ind = jnp.searchsorted(p_cuml, r).astype(int) ind = jnp.searchsorted(p_cuml, r).astype(int)
else: else:
# Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/ # Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
g = -gumbel(key, (n_inputs,), dtype=p_arr.dtype) - jnp.log(p_arr) g = gumbel(key, (n_inputs,), dtype=p_arr.dtype) + jnp.log(p_arr)
ind = jnp.argsort(g)[:n_draws] ind = lax.top_k(g, k=n_draws)[1].astype(int)
result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis) result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis)
return result.reshape(shape if arr.ndim == 0 else return result.reshape(shape if arr.ndim == 0 else