mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #15282 from JiaYaobo:geom_random
PiperOrigin-RevId: 520635974
This commit is contained in:
commit
dedfc8df75
@ -30,6 +30,7 @@ List of Available Functions
|
||||
fold_in
|
||||
gamma
|
||||
generalized_normal
|
||||
geometric
|
||||
gumbel
|
||||
laplace
|
||||
loggamma
|
||||
|
@ -2098,3 +2098,53 @@ def _wald(key, mean, shape, dtype) -> Array:
|
||||
x = mean + mean_sq * y / 2 - mean / 2 * sqrt_term
|
||||
w = lax.select(lax.le(z, mean / (mean + x)), x, mean_sq / x)
|
||||
return w
|
||||
|
||||
def geometric(key: KeyArray,
|
||||
p: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeInt = dtypes.int_) -> Array:
|
||||
r"""Sample Geometric random values with given shape and float dtype.
|
||||
|
||||
The values are returned according to the probability mass function:
|
||||
|
||||
.. math::
|
||||
f(k;p) = p(1-p)^{k-1}
|
||||
|
||||
on the domain :math:`0 < p < 1`.
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
p: a float or array of floats broadcast-compatible with ``shape``
|
||||
representing the the probability of success of an individual trial.
|
||||
shape: optional, a tuple of nonnegative integers specifying the result
|
||||
shape. Must be broadcast-compatible with ``p``. The default
|
||||
(None) produces a result shape equal to ``np.shape(p)``.
|
||||
dtype: optional, a int dtype for the returned values (default int64 if
|
||||
jax_enable_x64 is true, otherwise int32).
|
||||
|
||||
Returns:
|
||||
A random array with the specified dtype and with shape given by ``shape`` if
|
||||
``shape`` is not None, or else by ``p.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
if not dtypes.issubdtype(dtype, np.integer):
|
||||
raise ValueError("dtype argument to `geometric` must be an int "
|
||||
f"dtype, got {dtype}")
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
if shape is not None:
|
||||
shape = core.canonicalize_shape(shape)
|
||||
return _geometric(key, p, shape, dtype)
|
||||
|
||||
@partial(jit, static_argnums=(2, 3), inline=True)
|
||||
def _geometric(key, p, shape, dtype) -> Array:
|
||||
if shape is None:
|
||||
shape = np.shape(p)
|
||||
else:
|
||||
_check_shape("geometric", shape, np.shape(p))
|
||||
check_arraylike("geometric", p)
|
||||
p, = promote_dtypes_inexact(p)
|
||||
u = uniform(key, shape, p.dtype)
|
||||
log_u = lax.log(u)
|
||||
log_one_minus_p = lax.log1p(-p)
|
||||
g = lax.floor(lax.div(log_u, log_one_minus_p)) + 1
|
||||
return g.astype(dtype)
|
||||
|
@ -163,6 +163,7 @@ from jax._src.random import (
|
||||
fold_in as fold_in,
|
||||
gamma as gamma,
|
||||
generalized_normal as generalized_normal,
|
||||
geometric as geometric,
|
||||
gumbel as gumbel,
|
||||
key_data as key_data,
|
||||
laplace as laplace,
|
||||
|
@ -1551,6 +1551,22 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.invgauss(mu=mean).cdf)
|
||||
|
||||
@jtu.sample_product(
|
||||
p= [0.2, 0.3, 0.4, 0.5 ,0.6],
|
||||
dtype= [np.int16, np.int32, np.int64])
|
||||
def testGeometric(self, p, dtype):
|
||||
key = self.seed_prng(1)
|
||||
rand = lambda key: random.geometric(key, p, shape=(10000, ), dtype=dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckChiSquared(samples, scipy.stats.geom(p).pmf)
|
||||
self.assertAllClose(samples.mean(), 1 / p, rtol=0.02, check_dtypes=False)
|
||||
self.assertAllClose(samples.var(), (1 - p) / (p * p) , rtol=0.05, check_dtypes=False)
|
||||
|
||||
class KeyArrayTest(jtu.JaxTestCase):
|
||||
# Key arrays involve:
|
||||
# * a Python key array type, backed by an underlying uint32 "base" array,
|
||||
|
Loading…
x
Reference in New Issue
Block a user