mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Added random.orthogonal.
This commit is contained in:
parent
823ad552d6
commit
b276c31b75
@ -44,6 +44,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
is not of an integer type, matching the behavior of
|
is not of an integer type, matching the behavior of
|
||||||
{func}`numpy.diag`. Previously non-integer `k` was silently
|
{func}`numpy.diag`. Previously non-integer `k` was silently
|
||||||
cast to integers.
|
cast to integers.
|
||||||
|
* Added {func}`jax.random.orthogonal`.
|
||||||
* Deprecations
|
* Deprecations
|
||||||
* Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a
|
* Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a
|
||||||
warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`,
|
warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`,
|
||||||
|
@ -33,6 +33,7 @@ List of Available Functions
|
|||||||
maxwell
|
maxwell
|
||||||
multivariate_normal
|
multivariate_normal
|
||||||
normal
|
normal
|
||||||
|
orthogonal
|
||||||
pareto
|
pareto
|
||||||
permutation
|
permutation
|
||||||
poisson
|
poisson
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Optional, Sequence, Union
|
from typing import Any, Optional, Sequence, Union
|
||||||
|
from operator import index
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -1595,3 +1596,30 @@ def threefry_2x32(keypair, count):
|
|||||||
warnings.warn('jax.random.threefry_2x32 has moved to jax.prng.threefry_2x32 '
|
warnings.warn('jax.random.threefry_2x32 has moved to jax.prng.threefry_2x32 '
|
||||||
'and will be removed from `random` module.', FutureWarning)
|
'and will be removed from `random` module.', FutureWarning)
|
||||||
return prng.threefry_2x32(keypair, count)
|
return prng.threefry_2x32(keypair, count)
|
||||||
|
|
||||||
|
def orthogonal(
|
||||||
|
key: KeyArray,
|
||||||
|
n: int,
|
||||||
|
shape: Sequence[int] = (),
|
||||||
|
dtype: DTypeLikeFloat = dtypes.float_
|
||||||
|
) -> jnp.ndarray:
|
||||||
|
"""Sample uniformly from the orthogonal group O(n).
|
||||||
|
|
||||||
|
If the dtype is complex, sample uniformly from the unitary group U(n).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: a PRNG key used as the random key.
|
||||||
|
n: an integer indicating the resulting dimension.
|
||||||
|
shape: optional, the batch dimensions of the result. Default ().
|
||||||
|
dtype: optional, a float dtype for the returned values (default float64 if
|
||||||
|
jax_enable_x64 is true, otherwise float32).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A random array of shape `(*shape, n, n)` and specified dtype.
|
||||||
|
"""
|
||||||
|
_check_shape("orthogonal", shape)
|
||||||
|
n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
|
||||||
|
z = normal(key, (*shape, n, n), dtype)
|
||||||
|
q, r = jnp.linalg.qr(z)
|
||||||
|
d = jnp.diagonal(r, 0, -2, -1)
|
||||||
|
return q * jnp.expand_dims(d / abs(d), -2)
|
||||||
|
@ -143,6 +143,7 @@ from jax._src.random import (
|
|||||||
maxwell as maxwell,
|
maxwell as maxwell,
|
||||||
multivariate_normal as multivariate_normal,
|
multivariate_normal as multivariate_normal,
|
||||||
normal as normal,
|
normal as normal,
|
||||||
|
orthogonal as orthogonal,
|
||||||
pareto as pareto,
|
pareto as pareto,
|
||||||
permutation as permutation,
|
permutation as permutation,
|
||||||
poisson as poisson,
|
poisson as poisson,
|
||||||
|
@ -1032,6 +1032,26 @@ class LaxRandomTest(jtu.JaxTestCase):
|
|||||||
for samples in [uncompiled_samples, compiled_samples]:
|
for samples in [uncompiled_samples, compiled_samples]:
|
||||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf)
|
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(jtu.cases_from_list(
|
||||||
|
{"testcase_name": "_n={}_shape={}"\
|
||||||
|
.format(n, jtu.format_shape_dtype_string(shape, dtype)),
|
||||||
|
"n": n,
|
||||||
|
"shape": shape,
|
||||||
|
"dtype": dtype}
|
||||||
|
for n in range(1, 5)
|
||||||
|
for shape in [(), (5,), (10, 5)]
|
||||||
|
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
|
||||||
|
def testOrthogonal(self, n, shape, dtype):
|
||||||
|
key = self.seed_prng(0)
|
||||||
|
q = random.orthogonal(key, n, shape, dtype)
|
||||||
|
self.assertEqual(q.shape, (*shape, n, n))
|
||||||
|
self.assertEqual(q.dtype, dtype)
|
||||||
|
with jax.numpy_rank_promotion('allow'):
|
||||||
|
self.assertAllClose(
|
||||||
|
jnp.einsum('...ij,...jk->...ik', q, jnp.conj(q).swapaxes(-2, -1)),
|
||||||
|
jnp.broadcast_to(jnp.eye(n, dtype=dtype), (*shape, n, n))
|
||||||
|
)
|
||||||
|
|
||||||
@parameterized.named_parameters(jtu.cases_from_list(
|
@parameterized.named_parameters(jtu.cases_from_list(
|
||||||
{"testcase_name": "_b={}_dtype={}".format(b, np.dtype(dtype).name),
|
{"testcase_name": "_b={}_dtype={}".format(b, np.dtype(dtype).name),
|
||||||
"b": b, "dtype": dtype}
|
"b": b, "dtype": dtype}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user