jax.random.choice: make return dtype consistent

This commit is contained in:
Jake VanderPlas 2023-05-16 08:52:11 -07:00
parent ec7a939d18
commit 5e14744c2c
2 changed files with 5 additions and 3 deletions

View File

@ -567,7 +567,7 @@ def choice(key: KeyArray,
if replace:
p_cuml = jnp.cumsum(p_arr)
r = p_cuml[-1] * (1 - uniform(key, shape, dtype=p_cuml.dtype))
ind = jnp.searchsorted(p_cuml, r)
ind = jnp.searchsorted(p_cuml, r).astype(int)
else:
# Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
g = -gumbel(key, (n_inputs,), dtype=p_arr.dtype) - jnp.log(p_arr)

View File

@ -754,8 +754,10 @@ class LaxRandomTest(jtu.JaxTestCase):
sample = rand(key, x)
if not is_range:
self.assertEqual(dtype, sample.dtype)
np_shape = np.shape(np_choice(x, shape or None, replace, p, axis))
self.assertEqual(np_shape, sample.shape)
expected_shape = np.shape(np_choice(x, shape or None, replace, p, axis))
self.assertEqual(expected_shape, sample.shape)
expected_dtype = dtypes.result_type(int if is_range else x)
self.assertEqual(expected_dtype, sample.dtype)
if not replace and shape:
def lsort(x):
if not math.prod(x.shape): return x