mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Merge pull request #27157 from mar-muel:improve-random-choice-performance
PiperOrigin-RevId: 737665351
This commit is contained in:
commit
de9ad6bad9
@ -670,8 +670,8 @@ def choice(key: ArrayLike,
|
||||
ind = jnp.searchsorted(p_cuml, r).astype(int)
|
||||
else:
|
||||
# Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
|
||||
g = -gumbel(key, (n_inputs,), dtype=p_arr.dtype) - jnp.log(p_arr)
|
||||
ind = jnp.argsort(g)[:n_draws]
|
||||
g = gumbel(key, (n_inputs,), dtype=p_arr.dtype) + jnp.log(p_arr)
|
||||
ind = lax.top_k(g, k=n_draws)[1].astype(int)
|
||||
result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis)
|
||||
|
||||
return result.reshape(shape if arr.ndim == 0 else
|
||||
|
Loading…
x
Reference in New Issue
Block a user