Added random.orthogonal.

This commit is contained in:
Carlos Martin 2022-04-29 14:20:50 -04:00
parent 823ad552d6
commit b276c31b75
5 changed files with 51 additions and 0 deletions

View File

@ -44,6 +44,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
is not of an integer type, matching the behavior of
{func}`numpy.diag`. Previously non-integer `k` was silently
cast to integers.
* Added {func}`jax.random.orthogonal`.
* Deprecations
* Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a
warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`,

View File

@ -33,6 +33,7 @@ List of Available Functions
maxwell
multivariate_normal
normal
orthogonal
pareto
permutation
poisson

View File

@ -15,6 +15,7 @@
from functools import partial
from typing import Any, Optional, Sequence, Union
from operator import index
import warnings
import numpy as np
@ -1595,3 +1596,30 @@ def threefry_2x32(keypair, count):
warnings.warn('jax.random.threefry_2x32 has moved to jax.prng.threefry_2x32 '
'and will be removed from `random` module.', FutureWarning)
return prng.threefry_2x32(keypair, count)
def orthogonal(
key: KeyArray,
n: int,
shape: Sequence[int] = (),
dtype: DTypeLikeFloat = dtypes.float_
) -> jnp.ndarray:
"""Sample uniformly from the orthogonal group O(n).
If the dtype is complex, sample uniformly from the unitary group U(n).
Args:
key: a PRNG key used as the random key.
n: an integer indicating the resulting dimension.
shape: optional, the batch dimensions of the result. Default ().
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array of shape `(*shape, n, n)` and specified dtype.
"""
_check_shape("orthogonal", shape)
n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
z = normal(key, (*shape, n, n), dtype)
q, r = jnp.linalg.qr(z)
d = jnp.diagonal(r, 0, -2, -1)
return q * jnp.expand_dims(d / abs(d), -2)

View File

@ -143,6 +143,7 @@ from jax._src.random import (
maxwell as maxwell,
multivariate_normal as multivariate_normal,
normal as normal,
orthogonal as orthogonal,
pareto as pareto,
permutation as permutation,
poisson as poisson,

View File

@ -1032,6 +1032,26 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_n={}_shape={}"\
.format(n, jtu.format_shape_dtype_string(shape, dtype)),
"n": n,
"shape": shape,
"dtype": dtype}
for n in range(1, 5)
for shape in [(), (5,), (10, 5)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
def testOrthogonal(self, n, shape, dtype):
key = self.seed_prng(0)
q = random.orthogonal(key, n, shape, dtype)
self.assertEqual(q.shape, (*shape, n, n))
self.assertEqual(q.dtype, dtype)
with jax.numpy_rank_promotion('allow'):
self.assertAllClose(
jnp.einsum('...ij,...jk->...ik', q, jnp.conj(q).swapaxes(-2, -1)),
jnp.broadcast_to(jnp.eye(n, dtype=dtype), (*shape, n, n))
)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_b={}_dtype={}".format(b, np.dtype(dtype).name),
"b": b, "dtype": dtype}