Added random.orthogonal.

This commit is contained in:
Carlos Martin 2022-04-29 14:20:50 -04:00
parent 823ad552d6
commit b276c31b75
5 changed files with 51 additions and 0 deletions

View File

@ -44,6 +44,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
is not of an integer type, matching the behavior of is not of an integer type, matching the behavior of
{func}`numpy.diag`. Previously non-integer `k` was silently {func}`numpy.diag`. Previously non-integer `k` was silently
cast to integers. cast to integers.
* Added {func}`jax.random.orthogonal`.
* Deprecations * Deprecations
* Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a * Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a
warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`, warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`,

View File

@ -33,6 +33,7 @@ List of Available Functions
maxwell maxwell
multivariate_normal multivariate_normal
normal normal
orthogonal
pareto pareto
permutation permutation
poisson poisson

View File

@ -15,6 +15,7 @@
from functools import partial from functools import partial
from typing import Any, Optional, Sequence, Union from typing import Any, Optional, Sequence, Union
from operator import index
import warnings import warnings
import numpy as np import numpy as np
@ -1595,3 +1596,30 @@ def threefry_2x32(keypair, count):
warnings.warn('jax.random.threefry_2x32 has moved to jax.prng.threefry_2x32 ' warnings.warn('jax.random.threefry_2x32 has moved to jax.prng.threefry_2x32 '
'and will be removed from `random` module.', FutureWarning) 'and will be removed from `random` module.', FutureWarning)
return prng.threefry_2x32(keypair, count) return prng.threefry_2x32(keypair, count)
def orthogonal(
key: KeyArray,
n: int,
shape: Sequence[int] = (),
dtype: DTypeLikeFloat = dtypes.float_
) -> jnp.ndarray:
"""Sample uniformly from the orthogonal group O(n).
If the dtype is complex, sample uniformly from the unitary group U(n).
Args:
key: a PRNG key used as the random key.
n: an integer indicating the resulting dimension.
shape: optional, the batch dimensions of the result. Default ().
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array of shape `(*shape, n, n)` and specified dtype.
"""
_check_shape("orthogonal", shape)
n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
z = normal(key, (*shape, n, n), dtype)
q, r = jnp.linalg.qr(z)
d = jnp.diagonal(r, 0, -2, -1)
return q * jnp.expand_dims(d / abs(d), -2)

View File

@ -143,6 +143,7 @@ from jax._src.random import (
maxwell as maxwell, maxwell as maxwell,
multivariate_normal as multivariate_normal, multivariate_normal as multivariate_normal,
normal as normal, normal as normal,
orthogonal as orthogonal,
pareto as pareto, pareto as pareto,
permutation as permutation, permutation as permutation,
poisson as poisson, poisson as poisson,

View File

@ -1032,6 +1032,26 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]: for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf) self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_n={}_shape={}"\
.format(n, jtu.format_shape_dtype_string(shape, dtype)),
"n": n,
"shape": shape,
"dtype": dtype}
for n in range(1, 5)
for shape in [(), (5,), (10, 5)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
def testOrthogonal(self, n, shape, dtype):
key = self.seed_prng(0)
q = random.orthogonal(key, n, shape, dtype)
self.assertEqual(q.shape, (*shape, n, n))
self.assertEqual(q.dtype, dtype)
with jax.numpy_rank_promotion('allow'):
self.assertAllClose(
jnp.einsum('...ij,...jk->...ik', q, jnp.conj(q).swapaxes(-2, -1)),
jnp.broadcast_to(jnp.eye(n, dtype=dtype), (*shape, n, n))
)
@parameterized.named_parameters(jtu.cases_from_list( @parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_b={}_dtype={}".format(b, np.dtype(dtype).name), {"testcase_name": "_b={}_dtype={}".format(b, np.dtype(dtype).name),
"b": b, "dtype": dtype} "b": b, "dtype": dtype}