mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #15114 from JiaYaobo:add_wald_random
PiperOrigin-RevId: 518592428
This commit is contained in:
commit
00e6c73b68
@ -49,5 +49,6 @@ List of Available Functions
|
||||
t
|
||||
truncated_normal
|
||||
uniform
|
||||
wald
|
||||
weibull_min
|
||||
|
||||
|
@ -1880,3 +1880,59 @@ def _rayleigh(key, scale, shape, dtype) -> Array:
|
||||
sqrt_u = lax.sqrt(lax.mul(log_u, n_two))
|
||||
ray = lax.mul(scale, sqrt_u)
|
||||
return ray
|
||||
|
||||
def wald(key: KeyArray,
|
||||
mean: RealArray,
|
||||
scale: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample Wald random values with given shape and float dtype.
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
mean: a float or array of floats broadcast-compatible with ``shape``
|
||||
representing the mean parameter of the distribution.
|
||||
scale: a float or array of floats broadcast-compatible with ``shape``
|
||||
representing the scale parameter of the distribution.
|
||||
shape: optional, a tuple of nonnegative integers specifying the result
|
||||
shape. Must be broadcast-compatible with ``mean`` and ``scale``. The default
|
||||
(None) produces a result shape equal to ``lax.broadcast_shapes(np.shape(mean), np.shape(scale))``.
|
||||
dtype: optional, a float dtype for the returned values (default float64 if
|
||||
jax_enable_x64 is true, otherwise float32).
|
||||
|
||||
Returns:
|
||||
A random array with the specified dtype and with shape given by ``shape`` if
|
||||
``shape`` is not None, or else by ``mean.shape`` and ``scale.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
if not dtypes.issubdtype(dtype, np.floating):
|
||||
raise ValueError("dtype argument to `wald` must be a float "
|
||||
f"dtype, got {dtype}")
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
if shape is not None:
|
||||
shape = core.canonicalize_shape(shape)
|
||||
return _wald(key, mean, scale, shape, dtype)
|
||||
|
||||
@partial(jit, static_argnums=(3, 4), inline=True)
|
||||
def _wald(key, mean, scale, shape, dtype) -> Array:
|
||||
if shape is None:
|
||||
shape = lax.broadcast_shapes(np.shape(mean), np.shape(scale))
|
||||
else:
|
||||
_check_shape("wald", shape, np.shape(mean), np.shape(scale))
|
||||
k1, k2 = _split(key, 2)
|
||||
mean = mean.astype(dtype)
|
||||
scale = scale.astype(dtype)
|
||||
mean = jnp.broadcast_to(mean, shape)
|
||||
scale = jnp.broadcast_to(scale, shape)
|
||||
v = normal(k1, shape, dtype)
|
||||
z = uniform(k2, shape, dtype)
|
||||
two = _lax_const(mean, 2)
|
||||
y = lax.integer_pow(v, 2)
|
||||
y_sq = lax.integer_pow(y, 2)
|
||||
mean_sq = lax.integer_pow(mean, 2)
|
||||
mean_two = lax.mul(mean, two)
|
||||
scale_two = lax.mul(scale, two)
|
||||
sqrt_term = lax.sqrt(mean_two * scale_two * y + mean_sq * y_sq)
|
||||
x = mean + mean_sq * y / scale_two - mean / scale_two * sqrt_term
|
||||
w = lax.select(lax.le(z, mean / (mean + x)), x, mean_sq / x)
|
||||
return w
|
||||
|
@ -189,5 +189,6 @@ from jax._src.random import (
|
||||
truncated_normal as truncated_normal,
|
||||
uniform as uniform,
|
||||
unsafe_rbg_key as unsafe_rbg_key,
|
||||
wald as wald,
|
||||
weibull_min as weibull_min,
|
||||
)
|
||||
|
@ -1537,6 +1537,21 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.rayleigh(scale=scale).cdf)
|
||||
|
||||
@jtu.sample_product(
|
||||
mean= [0.2, 1., 2., 10. ,100.],
|
||||
scale= [0.2, 1., 2., 10. ,100.],
|
||||
dtype=jtu.dtypes.floating)
|
||||
def testWald(self, mean, scale, dtype):
|
||||
key = self.seed_prng(0)
|
||||
rand = lambda key: random.wald(key, mean, scale, shape = (10000, ), dtype = dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.invgauss(mu=mean / scale, scale = scale).cdf)
|
||||
|
||||
class KeyArrayTest(jtu.JaxTestCase):
|
||||
# Key arrays involve:
|
||||
# * a Python key array type, backed by an underlying uint32 "base" array,
|
||||
|
Loading…
x
Reference in New Issue
Block a user