Fix broadcasting in jax.random.rayleigh

This commit is contained in:
Neil Girdhar 2023-03-30 16:38:08 -04:00
parent 67a28ce30f
commit 1d1b131f4b

View File

@ -2036,6 +2036,7 @@ def _rayleigh(key, scale, shape, dtype) -> Array:
_check_shape("rayleigh", shape, np.shape(scale))
u = uniform(key, shape, dtype)
scale = scale.astype(dtype)
scale = jnp.broadcast_to(scale, shape)
log_u = lax.log(u)
n_two = _lax_const(scale, -2)
sqrt_u = lax.sqrt(lax.mul(log_u, n_two))