mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #25655 from carlosgmartin:simplify_random_orthogonal
PiperOrigin-RevId: 722770603
This commit is contained in:
commit
2c10a65b73
@ -2070,8 +2070,8 @@ def orthogonal(
|
||||
n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
|
||||
z = normal(key, (*shape, n, n), dtype)
|
||||
q, r = jnp.linalg.qr(z)
|
||||
d = jnp.diagonal(r, 0, -2, -1)
|
||||
return lax.mul(q, lax.expand_dims(lax.div(d, abs(d).astype(d.dtype)), [-2]))
|
||||
d = jnp.linalg.diagonal(r)
|
||||
return q * jnp.expand_dims(jnp.sign(d), -2)
|
||||
|
||||
def generalized_normal(
|
||||
key: ArrayLike,
|
||||
|
Loading…
x
Reference in New Issue
Block a user