mirror of
https://github.com/ROCm/jax.git
synced 2025-04-13 02:16:06 +00:00
Add replace option to random.categorical to enable sampling without replacement.
This commit is contained in:
parent
de9ad6bad9
commit
3f59fa6888
@ -22,6 +22,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
{func}`jax.lax.dynamic_update_slice` and related functions. The default is
|
||||
true, matching the current behavior. If set to false, JAX does not need to
|
||||
emit code clamping negative indices, which improves code size.
|
||||
* Added a `replace` option to {func}`jax.random.categorical` to enable sampling
|
||||
without replacement.
|
||||
|
||||
## jax 0.5.2 (Mar 4, 2025)
|
||||
|
||||
|
@ -1548,12 +1548,18 @@ def _gumbel(key, shape, dtype, mode) -> Array:
|
||||
_uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.)))
|
||||
|
||||
|
||||
def categorical(key: ArrayLike,
|
||||
logits: RealArray,
|
||||
axis: int = -1,
|
||||
shape: Shape | None = None) -> Array:
|
||||
def categorical(
|
||||
key: ArrayLike,
|
||||
logits: RealArray,
|
||||
axis: int = -1,
|
||||
shape: Shape | None = None,
|
||||
replace: bool = True,
|
||||
) -> Array:
|
||||
"""Sample random values from categorical distributions.
|
||||
|
||||
Sampling with replacement uses the Gumbel max trick. Sampling without replacement uses
|
||||
the Gumbel top-k trick. See [1] for reference.
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
logits: Unnormalized log probabilities of the categorical distribution(s) to sample from,
|
||||
@ -1562,32 +1568,57 @@ def categorical(key: ArrayLike,
|
||||
shape: Optional, a tuple of nonnegative integers representing the result shape.
|
||||
Must be broadcast-compatible with ``np.delete(logits.shape, axis)``.
|
||||
The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``.
|
||||
replace: If True (default), perform sampling with replacement. If False,
|
||||
perform sampling without replacement.
|
||||
|
||||
Returns:
|
||||
A random array with int dtype and shape given by ``shape`` if ``shape``
|
||||
is not None, or else ``np.delete(logits.shape, axis)``.
|
||||
|
||||
References:
|
||||
.. [1] Wouter Kool, Herke van Hoof, Max Welling. "Stochastic Beams and Where to Find
|
||||
Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement".
|
||||
Proceedings of the 36th International Conference on Machine Learning, PMLR
|
||||
97:3499-3508, 2019. https://proceedings.mlr.press/v97/kool19a.html.
|
||||
"""
|
||||
key, _ = _check_prng_key("categorical", key)
|
||||
check_arraylike("categorical", logits)
|
||||
logits_arr = jnp.asarray(logits)
|
||||
|
||||
if axis >= 0:
|
||||
axis -= len(logits_arr.shape)
|
||||
|
||||
batch_shape = tuple(np.delete(logits_arr.shape, axis))
|
||||
if shape is None:
|
||||
shape = batch_shape
|
||||
else:
|
||||
shape = core.canonicalize_shape(shape)
|
||||
_check_shape("categorical", shape, batch_shape)
|
||||
|
||||
shape_prefix = shape[:len(shape)-len(batch_shape)]
|
||||
logits_shape = list(shape[len(shape) - len(batch_shape):])
|
||||
logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
|
||||
return jnp.argmax(
|
||||
gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
|
||||
lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
|
||||
axis=axis)
|
||||
|
||||
if replace:
|
||||
if axis >= 0:
|
||||
axis -= len(logits_arr.shape)
|
||||
|
||||
logits_shape = list(shape[len(shape) - len(batch_shape):])
|
||||
logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
|
||||
return jnp.argmax(
|
||||
gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
|
||||
lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
|
||||
axis=axis)
|
||||
else:
|
||||
logits_arr += gumbel(key, logits_arr.shape, logits_arr.dtype)
|
||||
k = math.prod(shape_prefix)
|
||||
if k > logits_arr.shape[axis]:
|
||||
raise ValueError(
|
||||
f"Number of samples without replacement ({k}) cannot exceed number of "
|
||||
f"categories ({logits_arr.shape[axis]})."
|
||||
)
|
||||
|
||||
_, indices = lax.top_k(jnp.moveaxis(logits_arr, axis, -1), k)
|
||||
assert indices.shape == batch_shape + (k,)
|
||||
assert shape == shape_prefix + batch_shape
|
||||
|
||||
dimensions = (indices.ndim - 1, *range(indices.ndim - 1))
|
||||
indices = lax.reshape(indices, shape, dimensions)
|
||||
assert indices.shape == shape
|
||||
return indices
|
||||
|
||||
|
||||
def laplace(key: ArrayLike,
|
||||
|
@ -365,6 +365,38 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0)
|
||||
self._CheckChiSquared(samples, pmf=pmf)
|
||||
|
||||
@jtu.sample_product(
  logits_shape=[(7,), (8, 9), (10, 11, 12)],
  prefix_shape=[(2,), (3, 4), (5, 6)],
)
def testCategoricalWithoutReplacement(self, logits_shape, prefix_shape):
  """Check ``random.categorical(..., replace=False)`` along a random axis.

  Random normal logits of shape ``logits_shape`` are sampled with a requested
  result shape of ``prefix_shape + dists_shape``, where ``dists_shape`` is
  ``logits_shape`` with the sampled axis removed.

  * If the number of categories is smaller than ``prod(prefix_shape)``, the
    call is expected to raise ``ValueError``.
  * Otherwise the output must have the requested shape, every index must lie
    in ``[0, n_categories)``, and no category may appear more than once among
    the ``prod(prefix_shape)`` draws of any single distribution.

  Args:
    logits_shape: shape of the logits array passed to ``random.categorical``.
    prefix_shape: leading dimensions of the requested result shape.
  """
  key = random.key(0)

  key, subkey = random.split(key)
  logits = random.normal(subkey, logits_shape)

  key, subkey = random.split(key)
  # Draw the axis at random, then convert it to a plain Python int: it is
  # used to index a Python tuple below and `categorical` declares `axis: int`.
  axis = int(random.randint(subkey, (), -len(logits_shape), len(logits_shape)))

  # np.delete yields NumPy integers; normalise to Python ints so the expected
  # shape is an ordinary tuple of ints.
  dists_shape = tuple(int(d) for d in np.delete(logits_shape, axis))
  n_categories = logits_shape[axis]
  shape = prefix_shape + dists_shape
  prefix_size = math.prod(prefix_shape)

  if n_categories < prefix_size:
    # More draws requested than there are categories: must be rejected.
    with self.assertRaisesRegex(ValueError, "Number of samples without replacement"):
      random.categorical(key, logits, axis=axis, shape=shape, replace=False)
    return

  output = random.categorical(key, logits, axis=axis, shape=shape, replace=False)
  self.assertEqual(output.shape, shape)

  # Use unittest assertions rather than bare `assert`, which is stripped
  # under `python -O` and bypasses the test case's failure reporting.
  self.assertTrue(bool((0 <= output).all()))
  self.assertTrue(bool((output < n_categories).all()))

  # Rows index the draws and columns index the distributions; count how often
  # each category occurs within every column.
  flat = output.reshape((prefix_size, math.prod(dists_shape)))
  counts = jax.vmap(partial(jnp.bincount, length=n_categories), 1)(flat)
  self.assertTrue(bool((counts <= 1).all()))
|
||||
|
||||
|
||||
def testBernoulliShape(self):
|
||||
key = self.make_key(0)
|
||||
with jax.numpy_rank_promotion('allow'):
|
||||
|
Loading…
x
Reference in New Issue
Block a user