diff --git a/CHANGELOG.md b/CHANGELOG.md index 76278a264..b76ae14c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. is not of an integer type, matching the behavior of {func}`numpy.diag`. Previously non-integer `k` was silently cast to integers. + * Added {func}`jax.random.orthogonal`. * Deprecations * Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`, diff --git a/docs/jax.random.rst b/docs/jax.random.rst index f0660782a..dc578aee7 100644 --- a/docs/jax.random.rst +++ b/docs/jax.random.rst @@ -33,6 +33,7 @@ List of Available Functions maxwell multivariate_normal normal + orthogonal pareto permutation poisson diff --git a/jax/_src/random.py b/jax/_src/random.py index 847171230..36f6cb1f1 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -15,6 +15,7 @@ from functools import partial from typing import Any, Optional, Sequence, Union +from operator import index import warnings import numpy as np @@ -1595,3 +1596,30 @@ def threefry_2x32(keypair, count): warnings.warn('jax.random.threefry_2x32 has moved to jax.prng.threefry_2x32 ' 'and will be removed from `random` module.', FutureWarning) return prng.threefry_2x32(keypair, count) + +def orthogonal( + key: KeyArray, + n: int, + shape: Sequence[int] = (), + dtype: DTypeLikeFloat = dtypes.float_ +) -> jnp.ndarray: + """Sample uniformly from the orthogonal group O(n). + + If the dtype is complex, sample uniformly from the unitary group U(n). + + Args: + key: a PRNG key used as the random key. + n: an integer indicating the resulting dimension. + shape: optional, the batch dimensions of the result. Default (). + dtype: optional, a float dtype for the returned values (default float64 if + jax_enable_x64 is true, otherwise float32). + + Returns: + A random array of shape `(*shape, n, n)` and specified dtype. + """ + _check_shape("orthogonal", shape) + n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()") + z = normal(key, (*shape, n, n), dtype) + q, r = jnp.linalg.qr(z) + d = jnp.diagonal(r, 0, -2, -1) + return q * jnp.expand_dims(d / abs(d), -2) diff --git a/jax/random.py b/jax/random.py index a349e2752..3e3cd5fe8 100644 --- a/jax/random.py +++ b/jax/random.py @@ -143,6 +143,7 @@ from jax._src.random import ( maxwell as maxwell, multivariate_normal as multivariate_normal, normal as normal, + orthogonal as orthogonal, pareto as pareto, permutation as permutation, poisson as poisson, diff --git a/tests/random_test.py b/tests/random_test.py index 55a580130..b243a31e2 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1032,6 +1032,26 @@ class LaxRandomTest(jtu.JaxTestCase): for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf) + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_n={}_shape={}"\ + .format(n, jtu.format_shape_dtype_string(shape, dtype)), + "n": n, + "shape": shape, + "dtype": dtype} + for n in range(1, 5) + for shape in [(), (5,), (10, 5)] + for dtype in jtu.dtypes.floating + jtu.dtypes.complex)) + def testOrthogonal(self, n, shape, dtype): + key = self.seed_prng(0) + q = random.orthogonal(key, n, shape, dtype) + self.assertEqual(q.shape, (*shape, n, n)) + self.assertEqual(q.dtype, dtype) + with jax.numpy_rank_promotion('allow'): + self.assertAllClose( + jnp.einsum('...ij,...jk->...ik', q, jnp.conj(q).swapaxes(-2, -1)), + jnp.broadcast_to(jnp.eye(n, dtype=dtype), (*shape, n, n)) + ) + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_b={}_dtype={}".format(b, np.dtype(dtype).name), "b": b, "dtype": dtype}