mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix broadcasting in jax.random.rayleigh
This commit is contained in:
parent
67a28ce30f
commit
1d1b131f4b
@ -2036,6 +2036,7 @@ def _rayleigh(key, scale, shape, dtype) -> Array:
|
||||
_check_shape("rayleigh", shape, np.shape(scale))
|
||||
u = uniform(key, shape, dtype)
|
||||
scale = scale.astype(dtype)
|
||||
scale = jnp.broadcast_to(scale, shape)
|
||||
log_u = lax.log(u)
|
||||
n_two = _lax_const(scale, -2)
|
||||
sqrt_u = lax.sqrt(lax.mul(log_u, n_two))
|
||||
|
Loading…
x
Reference in New Issue
Block a user