mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #15139 from jakevdp:fix-wald
PiperOrigin-RevId: 518634544
This commit is contained in:
commit
f106d45371
@ -1883,7 +1883,6 @@ def _rayleigh(key, scale, shape, dtype) -> Array:
|
||||
|
||||
def wald(key: KeyArray,
|
||||
mean: RealArray,
|
||||
scale: RealArray,
|
||||
shape: Optional[Shape] = None,
|
||||
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
|
||||
"""Sample Wald random values with given shape and float dtype.
|
||||
@ -1892,11 +1891,9 @@ def wald(key: KeyArray,
|
||||
key: a PRNG key used as the random key.
|
||||
mean: a float or array of floats broadcast-compatible with ``shape``
|
||||
representing the mean parameter of the distribution.
|
||||
scale: a float or array of floats broadcast-compatible with ``shape``
|
||||
representing the scale parameter of the distribution.
|
||||
shape: optional, a tuple of nonnegative integers specifying the result
|
||||
shape. Must be broadcast-compatible with ``mean`` and ``scale``. The default
|
||||
(None) produces a result shape equal to ``lax.broadcast_shapes(np.shape(mean), np.shape(scale))``.
|
||||
shape. Must be broadcast-compatible with ``mean``. The default
|
||||
(None) produces a result shape equal to ``np.shape(mean)``.
|
||||
dtype: optional, a float dtype for the returned values (default float64 if
|
||||
jax_enable_x64 is true, otherwise float32).
|
||||
|
||||
@ -1911,28 +1908,23 @@ def wald(key: KeyArray,
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
if shape is not None:
|
||||
shape = core.canonicalize_shape(shape)
|
||||
return _wald(key, mean, scale, shape, dtype)
|
||||
return _wald(key, mean, shape, dtype)
|
||||
|
||||
@partial(jit, static_argnums=(3, 4), inline=True)
|
||||
def _wald(key, mean, scale, shape, dtype) -> Array:
|
||||
@partial(jit, static_argnums=(2, 3), inline=True)
|
||||
def _wald(key, mean, shape, dtype) -> Array:
|
||||
if shape is None:
|
||||
shape = lax.broadcast_shapes(np.shape(mean), np.shape(scale))
|
||||
shape = np.shape(mean)
|
||||
else:
|
||||
_check_shape("wald", shape, np.shape(mean), np.shape(scale))
|
||||
_check_shape("wald", shape, np.shape(mean))
|
||||
k1, k2 = _split(key, 2)
|
||||
mean = mean.astype(dtype)
|
||||
scale = scale.astype(dtype)
|
||||
mean = jnp.broadcast_to(mean, shape)
|
||||
scale = jnp.broadcast_to(scale, shape)
|
||||
v = normal(k1, shape, dtype)
|
||||
z = uniform(k2, shape, dtype)
|
||||
two = _lax_const(mean, 2)
|
||||
y = lax.integer_pow(v, 2)
|
||||
y_sq = lax.integer_pow(y, 2)
|
||||
mean_sq = lax.integer_pow(mean, 2)
|
||||
mean_two = lax.mul(mean, two)
|
||||
scale_two = lax.mul(scale, two)
|
||||
sqrt_term = lax.sqrt(mean_two * scale_two * y + mean_sq * y_sq)
|
||||
x = mean + mean_sq * y / scale_two - mean / scale_two * sqrt_term
|
||||
sqrt_term = lax.sqrt(4 * mean * y + mean_sq * y_sq)
|
||||
x = mean + mean_sq * y / 2 - mean / 2 * sqrt_term
|
||||
w = lax.select(lax.le(z, mean / (mean + x)), x, mean_sq / x)
|
||||
return w
|
||||
|
@ -1539,18 +1539,17 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(
|
||||
mean= [0.2, 1., 2., 10. ,100.],
|
||||
scale= [0.2, 1., 2., 10. ,100.],
|
||||
dtype=jtu.dtypes.floating)
|
||||
def testWald(self, mean, scale, dtype):
|
||||
def testWald(self, mean, dtype):
|
||||
key = self.seed_prng(0)
|
||||
rand = lambda key: random.wald(key, mean, scale, shape = (10000, ), dtype = dtype)
|
||||
rand = lambda key: random.wald(key, mean, shape=(10000, ), dtype=dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key)
|
||||
compiled_samples = crand(key)
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.invgauss(mu=mean / scale, scale = scale).cdf)
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.invgauss(mu=mean).cdf)
|
||||
|
||||
class KeyArrayTest(jtu.JaxTestCase):
|
||||
# Key arrays involve:
|
||||
|
Loading…
x
Reference in New Issue
Block a user