Merge pull request #15139 from jakevdp:fix-wald

PiperOrigin-RevId: 518634544
This commit is contained in:
jax authors 2023-03-22 11:59:58 -07:00
commit f106d45371
2 changed files with 12 additions and 21 deletions

View File

@@ -1883,7 +1883,6 @@ def _rayleigh(key, scale, shape, dtype) -> Array:
def wald(key: KeyArray,
         mean: RealArray,
         shape: Optional[Shape] = None,
         dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  """Sample Wald random values with given shape and float dtype.

  Args:
    key: a PRNG key used as the random key.
    mean: a float or array of floats broadcast-compatible with ``shape``
      representing the mean parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``mean``. The default
      (None) produces a result shape equal to ``np.shape(mean)``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape``
    if ``shape`` is not None, or else by ``np.shape(mean)``.
  """
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    # Canonicalize eagerly so `shape` is hashable for the jitted helper's
    # static argument.
    shape = core.canonicalize_shape(shape)
  return _wald(key, mean, shape, dtype)
# `shape` and `dtype` are static: they determine the trace, not the values.
@partial(jit, static_argnums=(2, 3), inline=True)
def _wald(key, mean, shape, dtype) -> Array:
  """Jitted implementation of wald(): mean-parameterized inverse Gaussian.

  Samples via the transformation method (Michael/Schucany/Haas-style):
  draw y = v**2 from a squared standard normal, solve the quadratic for a
  candidate x, then pick x or mean**2 / x based on a uniform draw z.
  """
  if shape is None:
    # Default result shape follows the mean argument.
    shape = np.shape(mean)
  else:
    _check_shape("wald", shape, np.shape(mean))
  k1, k2 = _split(key, 2)
  mean = mean.astype(dtype)
  mean = jnp.broadcast_to(mean, shape)
  v = normal(k1, shape, dtype)
  z = uniform(k2, shape, dtype)
  y = lax.integer_pow(v, 2)
  y_sq = lax.integer_pow(y, 2)
  mean_sq = lax.integer_pow(mean, 2)
  # Candidate root of the transformation quadratic (unit scale).
  sqrt_term = lax.sqrt(4 * mean * y + mean_sq * y_sq)
  x = mean + mean_sq * y / 2 - mean / 2 * sqrt_term
  # Accept x with probability mean / (mean + x); otherwise take mean^2 / x.
  w = lax.select(lax.le(z, mean / (mean + x)), x, mean_sq / x)
  return w

View File

@@ -1539,18 +1539,17 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.sample_product(
  mean=[0.2, 1., 2., 10., 100.],
  dtype=jtu.dtypes.floating)
def testWald(self, mean, dtype):
  """Checks random.wald samples against scipy's invgauss CDF (KS test)."""
  key = self.seed_prng(0)
  rand = lambda key: random.wald(key, mean, shape=(10000,), dtype=dtype)
  crand = jax.jit(rand)
  uncompiled_samples = rand(key)
  compiled_samples = crand(key)
  # Both eager and jit-compiled sampling must match the reference
  # distribution: invgauss with mu=mean (unit scale).
  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.invgauss(mu=mean).cdf)
class KeyArrayTest(jtu.JaxTestCase):
# Key arrays involve: