mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
jax.random.choice: make return dtype consistent
This commit is contained in:
parent
ec7a939d18
commit
5e14744c2c
@ -567,7 +567,7 @@ def choice(key: KeyArray,
|
||||
if replace:
|
||||
p_cuml = jnp.cumsum(p_arr)
|
||||
r = p_cuml[-1] * (1 - uniform(key, shape, dtype=p_cuml.dtype))
|
||||
ind = jnp.searchsorted(p_cuml, r)
|
||||
ind = jnp.searchsorted(p_cuml, r).astype(int)
|
||||
else:
|
||||
# Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
|
||||
g = -gumbel(key, (n_inputs,), dtype=p_arr.dtype) - jnp.log(p_arr)
|
||||
|
@ -754,8 +754,10 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
sample = rand(key, x)
|
||||
if not is_range:
|
||||
self.assertEqual(dtype, sample.dtype)
|
||||
np_shape = np.shape(np_choice(x, shape or None, replace, p, axis))
|
||||
self.assertEqual(np_shape, sample.shape)
|
||||
expected_shape = np.shape(np_choice(x, shape or None, replace, p, axis))
|
||||
self.assertEqual(expected_shape, sample.shape)
|
||||
expected_dtype = dtypes.result_type(int if is_range else x)
|
||||
self.assertEqual(expected_dtype, sample.dtype)
|
||||
if not replace and shape:
|
||||
def lsort(x):
|
||||
if not math.prod(x.shape): return x
|
||||
|
Loading…
x
Reference in New Issue
Block a user