mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add replace option to random.categorical to enable sampling without replacement.
This commit is contained in:
parent
de9ad6bad9
commit
3f59fa6888
@ -22,6 +22,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
|||||||
{func}`jax.lax.dynamic_update_slice` and related functions. The default is
|
{func}`jax.lax.dynamic_update_slice` and related functions. The default is
|
||||||
true, matching the current behavior. If set to false, JAX does not need to
|
true, matching the current behavior. If set to false, JAX does not need to
|
||||||
emit code clamping negative indices, which improves code size.
|
emit code clamping negative indices, which improves code size.
|
||||||
|
* Added a `replace` option to {func}`jax.random.categorical` to enable sampling
|
||||||
|
without replacement.
|
||||||
|
|
||||||
## jax 0.5.2 (Mar 4, 2025)
|
## jax 0.5.2 (Mar 4, 2025)
|
||||||
|
|
||||||
|
@ -1548,12 +1548,18 @@ def _gumbel(key, shape, dtype, mode) -> Array:
|
|||||||
_uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.)))
|
_uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.)))
|
||||||
|
|
||||||
|
|
||||||
def categorical(key: ArrayLike,
|
def categorical(
|
||||||
logits: RealArray,
|
key: ArrayLike,
|
||||||
axis: int = -1,
|
logits: RealArray,
|
||||||
shape: Shape | None = None) -> Array:
|
axis: int = -1,
|
||||||
|
shape: Shape | None = None,
|
||||||
|
replace: bool = True,
|
||||||
|
) -> Array:
|
||||||
"""Sample random values from categorical distributions.
|
"""Sample random values from categorical distributions.
|
||||||
|
|
||||||
|
Sampling with replacement uses the Gumbel max trick. Sampling without replacement uses
|
||||||
|
the Gumbel top-k trick. See [1] for reference.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: a PRNG key used as the random key.
|
key: a PRNG key used as the random key.
|
||||||
logits: Unnormalized log probabilities of the categorical distribution(s) to sample from,
|
logits: Unnormalized log probabilities of the categorical distribution(s) to sample from,
|
||||||
@ -1562,32 +1568,57 @@ def categorical(key: ArrayLike,
|
|||||||
shape: Optional, a tuple of nonnegative integers representing the result shape.
|
shape: Optional, a tuple of nonnegative integers representing the result shape.
|
||||||
Must be broadcast-compatible with ``np.delete(logits.shape, axis)``.
|
Must be broadcast-compatible with ``np.delete(logits.shape, axis)``.
|
||||||
The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``.
|
The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``.
|
||||||
|
replace: If True (default), perform sampling with replacement. If False, perform
|
||||||
|
sampling without replacement.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A random array with int dtype and shape given by ``shape`` if ``shape``
|
A random array with int dtype and shape given by ``shape`` if ``shape``
|
||||||
is not None, or else ``np.delete(logits.shape, axis)``.
|
is not None, or else ``np.delete(logits.shape, axis)``.
|
||||||
|
|
||||||
|
References:
|
||||||
|
.. [1] Wouter Kool, Herke van Hoof, Max Welling. "Stochastic Beams and Where to Find
|
||||||
|
Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement".
|
||||||
|
Proceedings of the 36th International Conference on Machine Learning, PMLR
|
||||||
|
97:3499-3508, 2019. https://proceedings.mlr.press/v97/kool19a.html.
|
||||||
"""
|
"""
|
||||||
key, _ = _check_prng_key("categorical", key)
|
key, _ = _check_prng_key("categorical", key)
|
||||||
check_arraylike("categorical", logits)
|
check_arraylike("categorical", logits)
|
||||||
logits_arr = jnp.asarray(logits)
|
logits_arr = jnp.asarray(logits)
|
||||||
|
|
||||||
if axis >= 0:
|
|
||||||
axis -= len(logits_arr.shape)
|
|
||||||
|
|
||||||
batch_shape = tuple(np.delete(logits_arr.shape, axis))
|
batch_shape = tuple(np.delete(logits_arr.shape, axis))
|
||||||
if shape is None:
|
if shape is None:
|
||||||
shape = batch_shape
|
shape = batch_shape
|
||||||
else:
|
else:
|
||||||
shape = core.canonicalize_shape(shape)
|
shape = core.canonicalize_shape(shape)
|
||||||
_check_shape("categorical", shape, batch_shape)
|
_check_shape("categorical", shape, batch_shape)
|
||||||
|
|
||||||
shape_prefix = shape[:len(shape)-len(batch_shape)]
|
shape_prefix = shape[:len(shape)-len(batch_shape)]
|
||||||
logits_shape = list(shape[len(shape) - len(batch_shape):])
|
|
||||||
logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
|
if replace:
|
||||||
return jnp.argmax(
|
if axis >= 0:
|
||||||
gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
|
axis -= len(logits_arr.shape)
|
||||||
lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
|
|
||||||
axis=axis)
|
logits_shape = list(shape[len(shape) - len(batch_shape):])
|
||||||
|
logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
|
||||||
|
return jnp.argmax(
|
||||||
|
gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
|
||||||
|
lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
|
||||||
|
axis=axis)
|
||||||
|
else:
|
||||||
|
logits_arr += gumbel(key, logits_arr.shape, logits_arr.dtype)
|
||||||
|
k = math.prod(shape_prefix)
|
||||||
|
if k > logits_arr.shape[axis]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of samples without replacement ({k}) cannot exceed number of "
|
||||||
|
f"categories ({logits_arr.shape[axis]})."
|
||||||
|
)
|
||||||
|
|
||||||
|
_, indices = lax.top_k(jnp.moveaxis(logits_arr, axis, -1), k)
|
||||||
|
assert indices.shape == batch_shape + (k,)
|
||||||
|
assert shape == shape_prefix + batch_shape
|
||||||
|
|
||||||
|
dimensions = (indices.ndim - 1, *range(indices.ndim - 1))
|
||||||
|
indices = lax.reshape(indices, shape, dimensions)
|
||||||
|
assert indices.shape == shape
|
||||||
|
return indices
|
||||||
|
|
||||||
|
|
||||||
def laplace(key: ArrayLike,
|
def laplace(key: ArrayLike,
|
||||||
|
@ -365,6 +365,38 @@ class LaxRandomTest(jtu.JaxTestCase):
|
|||||||
pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0)
|
pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0)
|
||||||
self._CheckChiSquared(samples, pmf=pmf)
|
self._CheckChiSquared(samples, pmf=pmf)
|
||||||
|
|
||||||
|
@jtu.sample_product(
|
||||||
|
logits_shape=[(7,), (8, 9), (10, 11, 12)],
|
||||||
|
prefix_shape=[(2,), (3, 4), (5, 6)],
|
||||||
|
)
|
||||||
|
def testCategoricalWithoutReplacement(self, logits_shape, prefix_shape):
|
||||||
|
key = random.key(0)
|
||||||
|
|
||||||
|
key, subkey = random.split(key)
|
||||||
|
logits = random.normal(subkey, logits_shape)
|
||||||
|
|
||||||
|
key, subkey = random.split(key)
|
||||||
|
axis = random.randint(subkey, (), -len(logits_shape), len(logits_shape))
|
||||||
|
|
||||||
|
dists_shape = tuple(np.delete(logits_shape, axis))
|
||||||
|
n_categories = logits_shape[axis]
|
||||||
|
shape = prefix_shape + dists_shape
|
||||||
|
prefix_size = math.prod(prefix_shape)
|
||||||
|
|
||||||
|
if n_categories < prefix_size:
|
||||||
|
with self.assertRaisesRegex(ValueError, "Number of samples without replacement"):
|
||||||
|
random.categorical(key, logits, axis=axis, shape=shape, replace=False)
|
||||||
|
|
||||||
|
else:
|
||||||
|
output = random.categorical(key, logits, axis=axis, shape=shape, replace=False)
|
||||||
|
self.assertEqual(output.shape, shape)
|
||||||
|
assert (0 <= output).all()
|
||||||
|
assert (output < n_categories).all()
|
||||||
|
flat = output.reshape((prefix_size, math.prod(dists_shape)))
|
||||||
|
counts = jax.vmap(partial(jnp.bincount, length=n_categories), 1)(flat)
|
||||||
|
assert (counts <= 1).all()
|
||||||
|
|
||||||
|
|
||||||
def testBernoulliShape(self):
|
def testBernoulliShape(self):
|
||||||
key = self.make_key(0)
|
key = self.make_key(0)
|
||||||
with jax.numpy_rank_promotion('allow'):
|
with jax.numpy_rank_promotion('allow'):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user