Merge pull request #25655 from carlosgmartin:simplify_random_orthogonal

PiperOrigin-RevId: 722770603
This commit is contained in:
jax authors 2025-02-03 13:12:09 -08:00
commit 2c10a65b73

View File

@ -2070,8 +2070,8 @@ def orthogonal(
n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
z = normal(key, (*shape, n, n), dtype)
q, r = jnp.linalg.qr(z)
d = jnp.diagonal(r, 0, -2, -1)
return lax.mul(q, lax.expand_dims(lax.div(d, abs(d).astype(d.dtype)), [-2]))
d = jnp.linalg.diagonal(r)
return q * jnp.expand_dims(jnp.sign(d), -2)
def generalized_normal(
key: ArrayLike,