Merge pull request #9186 from froystig:get-default-rng

PiperOrigin-RevId: 421454645
This commit is contained in:
jax authors 2022-01-12 19:48:04 -08:00
commit f0e4f0472d
3 changed files with 27 additions and 4 deletions

View File

@ -71,7 +71,7 @@ def _check_prng_key(key):
'Raw arrays as random keys to jax.random functions are deprecated. '
'Assuming valid threefry2x32 key for now.',
FutureWarning)
return prng.PRNGKeyArray(_get_default_prng_impl(), key), True
return prng.PRNGKeyArray(default_prng_impl(), key), True
else:
raise TypeError(f'unexpected PRNG key type {type(key)}')
@ -94,7 +94,13 @@ PRNG_IMPLS = {
'unsafe_rbg': prng.unsafe_rbg_prng_impl,
}
def _get_default_prng_impl():
def default_prng_impl():
"""Get the default PRNG implementation.
The default implementation is determined by ``config.jax_default_prng_impl``,
which specifies it by name. This function returns the corresponding
``jax.prng.PRNGImpl`` instance.
"""
impl_name = config.jax_default_prng_impl
assert impl_name in PRNG_IMPLS, impl_name
return PRNG_IMPLS[impl_name]
@ -117,13 +123,13 @@ def PRNGKey(seed: int) -> KeyArray:
and ``fold_in``.
"""
impl = _get_default_prng_impl()
impl = default_prng_impl()
key = prng.seed_with_impl(impl, seed)
return _return_prng_keys(True, key)
# TODO(frostig): remove once we always enable_custom_prng
def _check_default_impl_with_no_custom_prng(impl, name):
default_impl = _get_default_prng_impl()
default_impl = default_prng_impl()
default_name = config.jax_default_prng_impl
if not config.jax_enable_custom_prng and default_impl is not impl:
raise RuntimeError('jax_enable_custom_prng must be enabled in order '

View File

@ -89,6 +89,7 @@ from jax._src.random import (
categorical as categorical,
cauchy as cauchy,
choice as choice,
default_prng_impl as default_prng_impl,
dirichlet as dirichlet,
double_sided_maxwell as double_sided_maxwell,
exponential as exponential,

View File

@ -213,12 +213,28 @@ class PrngTest(jtu.JaxTestCase):
('rbg', prng.rbg_prng_impl),
('unsafe_rbg', prng.unsafe_rbg_prng_impl)]:
with jax.default_prng_impl(name):
self.assertIs(random.default_prng_impl(), impl)
key = random.PRNGKey(42)
self.assertIs(key.impl, impl)
k1, k2 = random.split(key, 2)
self.assertIs(k1.impl, impl)
self.assertIs(k2.impl, impl)
def test_default_prng_selection_without_custom_prng_mode(self):
if config.jax_enable_custom_prng:
self.skipTest("test requires that config.jax_enable_custom_prng is False")
for name, impl in [('threefry2x32', prng.threefry_prng_impl),
('rbg', prng.rbg_prng_impl),
('unsafe_rbg', prng.unsafe_rbg_prng_impl)]:
with jax.default_prng_impl(name):
self.assertIs(random.default_prng_impl(), impl)
key = random.PRNGKey(42)
self.assertEqual(key.shape, impl.key_shape)
k1, k2 = random.split(key, 2)
self.assertEqual(k1.shape, impl.key_shape)
self.assertEqual(k2.shape, impl.key_shape)
def test_explicit_threefry2x32_key(self):
if not config.jax_enable_custom_prng:
self.skipTest("test requires config.jax_enable_custom_prng")