diff --git a/docs/jax.random.rst b/docs/jax.random.rst index 1d7f68e39..64873b727 100644 --- a/docs/jax.random.rst +++ b/docs/jax.random.rst @@ -30,6 +30,7 @@ List of Available Functions fold_in gamma generalized_normal + geometric gumbel laplace loggamma diff --git a/jax/_src/random.py b/jax/_src/random.py index 1f2a46a9d..aa50abdce 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -2098,3 +2098,53 @@ def _wald(key, mean, shape, dtype) -> Array: x = mean + mean_sq * y / 2 - mean / 2 * sqrt_term w = lax.select(lax.le(z, mean / (mean + x)), x, mean_sq / x) return w + +def geometric(key: KeyArray, + p: RealArray, + shape: Optional[Shape] = None, + dtype: DTypeLikeInt = dtypes.int_) -> Array: + r"""Sample Geometric random values with given shape and float dtype. + + The values are returned according to the probability mass function: + + .. math:: + f(k;p) = p(1-p)^{k-1} + + on the domain :math:`0 < p < 1`. + + Args: + key: a PRNG key used as the random key. + p: a float or array of floats broadcast-compatible with ``shape`` + representing the the probability of success of an individual trial. + shape: optional, a tuple of nonnegative integers specifying the result + shape. Must be broadcast-compatible with ``p``. The default + (None) produces a result shape equal to ``np.shape(p)``. + dtype: optional, a int dtype for the returned values (default int64 if + jax_enable_x64 is true, otherwise int32). + + Returns: + A random array with the specified dtype and with shape given by ``shape`` if + ``shape`` is not None, or else by ``p.shape``. + """ + key, _ = _check_prng_key(key) + if not dtypes.issubdtype(dtype, np.integer): + raise ValueError("dtype argument to `geometric` must be an int " + f"dtype, got {dtype}") + dtype = dtypes.canonicalize_dtype(dtype) + if shape is not None: + shape = core.canonicalize_shape(shape) + return _geometric(key, p, shape, dtype) + +@partial(jit, static_argnums=(2, 3), inline=True) +def _geometric(key, p, shape, dtype) -> Array: + if shape is None: + shape = np.shape(p) + else: + _check_shape("geometric", shape, np.shape(p)) + check_arraylike("geometric", p) + p, = promote_dtypes_inexact(p) + u = uniform(key, shape, p.dtype) + log_u = lax.log(u) + log_one_minus_p = lax.log1p(-p) + g = lax.floor(lax.div(log_u, log_one_minus_p)) + 1 + return g.astype(dtype) diff --git a/jax/random.py b/jax/random.py index 8a84098f7..6cbcd9739 100644 --- a/jax/random.py +++ b/jax/random.py @@ -163,6 +163,7 @@ from jax._src.random import ( fold_in as fold_in, gamma as gamma, generalized_normal as generalized_normal, + geometric as geometric, gumbel as gumbel, key_data as key_data, laplace as laplace, diff --git a/tests/random_test.py b/tests/random_test.py index cdf46e6ad..78732c9fe 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1551,6 +1551,22 @@ class LaxRandomTest(jtu.JaxTestCase): for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.invgauss(mu=mean).cdf) + @jtu.sample_product( + p= [0.2, 0.3, 0.4, 0.5 ,0.6], + dtype= [np.int16, np.int32, np.int64]) + def testGeometric(self, p, dtype): + key = self.seed_prng(1) + rand = lambda key: random.geometric(key, p, shape=(10000, ), dtype=dtype) + crand = jax.jit(rand) + + uncompiled_samples = rand(key) + compiled_samples = crand(key) + + for samples in [uncompiled_samples, compiled_samples]: + self._CheckChiSquared(samples, scipy.stats.geom(p).pmf) + self.assertAllClose(samples.mean(), 1 / p, rtol=0.02, check_dtypes=False) + self.assertAllClose(samples.var(), (1 - p) / (p * p) , rtol=0.05, check_dtypes=False) + class KeyArrayTest(jtu.JaxTestCase): # Key arrays involve: # * a Python key array type, backed by an underlying uint32 "base" array,