mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Added random.orthogonal.
This commit is contained in:
parent
823ad552d6
commit
b276c31b75
@ -44,6 +44,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
is not of an integer type, matching the behavior of
|
||||
{func}`numpy.diag`. Previously non-integer `k` was silently
|
||||
cast to integers.
|
||||
* Added {func}`jax.random.orthogonal`.
|
||||
* Deprecations
|
||||
* Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a
|
||||
warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`,
|
||||
|
@ -33,6 +33,7 @@ List of Available Functions
|
||||
maxwell
|
||||
multivariate_normal
|
||||
normal
|
||||
orthogonal
|
||||
pareto
|
||||
permutation
|
||||
poisson
|
||||
|
@ -15,6 +15,7 @@
|
||||
|
||||
from functools import partial
|
||||
from typing import Any, Optional, Sequence, Union
|
||||
from operator import index
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
@ -1595,3 +1596,30 @@ def threefry_2x32(keypair, count):
|
||||
warnings.warn('jax.random.threefry_2x32 has moved to jax.prng.threefry_2x32 '
|
||||
'and will be removed from `random` module.', FutureWarning)
|
||||
return prng.threefry_2x32(keypair, count)
|
||||
|
||||
def orthogonal(
|
||||
key: KeyArray,
|
||||
n: int,
|
||||
shape: Sequence[int] = (),
|
||||
dtype: DTypeLikeFloat = dtypes.float_
|
||||
) -> jnp.ndarray:
|
||||
"""Sample uniformly from the orthogonal group O(n).
|
||||
|
||||
If the dtype is complex, sample uniformly from the unitary group U(n).
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
n: an integer indicating the resulting dimension.
|
||||
shape: optional, the batch dimensions of the result. Default ().
|
||||
dtype: optional, a float dtype for the returned values (default float64 if
|
||||
jax_enable_x64 is true, otherwise float32).
|
||||
|
||||
Returns:
|
||||
A random array of shape `(*shape, n, n)` and specified dtype.
|
||||
"""
|
||||
_check_shape("orthogonal", shape)
|
||||
n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
|
||||
z = normal(key, (*shape, n, n), dtype)
|
||||
q, r = jnp.linalg.qr(z)
|
||||
d = jnp.diagonal(r, 0, -2, -1)
|
||||
return q * jnp.expand_dims(d / abs(d), -2)
|
||||
|
@ -143,6 +143,7 @@ from jax._src.random import (
|
||||
maxwell as maxwell,
|
||||
multivariate_normal as multivariate_normal,
|
||||
normal as normal,
|
||||
orthogonal as orthogonal,
|
||||
pareto as pareto,
|
||||
permutation as permutation,
|
||||
poisson as poisson,
|
||||
|
@ -1032,6 +1032,26 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_n={}_shape={}"\
|
||||
.format(n, jtu.format_shape_dtype_string(shape, dtype)),
|
||||
"n": n,
|
||||
"shape": shape,
|
||||
"dtype": dtype}
|
||||
for n in range(1, 5)
|
||||
for shape in [(), (5,), (10, 5)]
|
||||
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
|
||||
def testOrthogonal(self, n, shape, dtype):
|
||||
key = self.seed_prng(0)
|
||||
q = random.orthogonal(key, n, shape, dtype)
|
||||
self.assertEqual(q.shape, (*shape, n, n))
|
||||
self.assertEqual(q.dtype, dtype)
|
||||
with jax.numpy_rank_promotion('allow'):
|
||||
self.assertAllClose(
|
||||
jnp.einsum('...ij,...jk->...ik', q, jnp.conj(q).swapaxes(-2, -1)),
|
||||
jnp.broadcast_to(jnp.eye(n, dtype=dtype), (*shape, n, n))
|
||||
)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_b={}_dtype={}".format(b, np.dtype(dtype).name),
|
||||
"b": b, "dtype": dtype}
|
||||
|
Loading…
x
Reference in New Issue
Block a user