Merge pull request #15282 from JiaYaobo:geom_random

PiperOrigin-RevId: 520635974
This commit is contained in:
jax authors 2023-03-30 07:45:19 -07:00
commit dedfc8df75
4 changed files with 68 additions and 0 deletions

View File

@ -30,6 +30,7 @@ List of Available Functions
fold_in
gamma
generalized_normal
geometric
gumbel
laplace
loggamma

View File

@ -2098,3 +2098,53 @@ def _wald(key, mean, shape, dtype) -> Array:
x = mean + mean_sq * y / 2 - mean / 2 * sqrt_term
w = lax.select(lax.le(z, mean / (mean + x)), x, mean_sq / x)
return w
def geometric(key: KeyArray,
p: RealArray,
shape: Optional[Shape] = None,
dtype: DTypeLikeInt = dtypes.int_) -> Array:
r"""Sample Geometric random values with given shape and float dtype.
The values are returned according to the probability mass function:
.. math::
f(k;p) = p(1-p)^{k-1}
on the domain :math:`0 < p < 1`.
Args:
key: a PRNG key used as the random key.
p: a float or array of floats broadcast-compatible with ``shape``
representing the the probability of success of an individual trial.
shape: optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with ``p``. The default
(None) produces a result shape equal to ``np.shape(p)``.
dtype: optional, a int dtype for the returned values (default int64 if
jax_enable_x64 is true, otherwise int32).
Returns:
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``p.shape``.
"""
key, _ = _check_prng_key(key)
if not dtypes.issubdtype(dtype, np.integer):
raise ValueError("dtype argument to `geometric` must be an int "
f"dtype, got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = core.canonicalize_shape(shape)
return _geometric(key, p, shape, dtype)
@partial(jit, static_argnums=(2, 3), inline=True)
def _geometric(key, p, shape, dtype) -> Array:
if shape is None:
shape = np.shape(p)
else:
_check_shape("geometric", shape, np.shape(p))
check_arraylike("geometric", p)
p, = promote_dtypes_inexact(p)
u = uniform(key, shape, p.dtype)
log_u = lax.log(u)
log_one_minus_p = lax.log1p(-p)
g = lax.floor(lax.div(log_u, log_one_minus_p)) + 1
return g.astype(dtype)

View File

@ -163,6 +163,7 @@ from jax._src.random import (
fold_in as fold_in,
gamma as gamma,
generalized_normal as generalized_normal,
geometric as geometric,
gumbel as gumbel,
key_data as key_data,
laplace as laplace,

View File

@ -1551,6 +1551,22 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.invgauss(mu=mean).cdf)
@jtu.sample_product(
p= [0.2, 0.3, 0.4, 0.5 ,0.6],
dtype= [np.int16, np.int32, np.int64])
def testGeometric(self, p, dtype):
key = self.seed_prng(1)
rand = lambda key: random.geometric(key, p, shape=(10000, ), dtype=dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key)
compiled_samples = crand(key)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckChiSquared(samples, scipy.stats.geom(p).pmf)
self.assertAllClose(samples.mean(), 1 / p, rtol=0.02, check_dtypes=False)
self.assertAllClose(samples.var(), (1 - p) / (p * p) , rtol=0.05, check_dtypes=False)
class KeyArrayTest(jtu.JaxTestCase):
# Key arrays involve:
# * a Python key array type, backed by an underlying uint32 "base" array,