mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #9186 from froystig:get-default-rng
PiperOrigin-RevId: 421454645
This commit is contained in:
commit
f0e4f0472d
@ -71,7 +71,7 @@ def _check_prng_key(key):
|
||||
'Raw arrays as random keys to jax.random functions are deprecated. '
|
||||
'Assuming valid threefry2x32 key for now.',
|
||||
FutureWarning)
|
||||
return prng.PRNGKeyArray(_get_default_prng_impl(), key), True
|
||||
return prng.PRNGKeyArray(default_prng_impl(), key), True
|
||||
else:
|
||||
raise TypeError(f'unexpected PRNG key type {type(key)}')
|
||||
|
||||
@ -94,7 +94,13 @@ PRNG_IMPLS = {
|
||||
'unsafe_rbg': prng.unsafe_rbg_prng_impl,
|
||||
}
|
||||
|
||||
def _get_default_prng_impl():
|
||||
def default_prng_impl():
|
||||
"""Get the default PRNG implementation.
|
||||
|
||||
The default implementation is determined by ``config.jax_default_prng_impl``,
|
||||
which specifies it by name. This function returns the corresponding
|
||||
``jax.prng.PRNGImpl`` instance.
|
||||
"""
|
||||
impl_name = config.jax_default_prng_impl
|
||||
assert impl_name in PRNG_IMPLS, impl_name
|
||||
return PRNG_IMPLS[impl_name]
|
||||
@ -117,13 +123,13 @@ def PRNGKey(seed: int) -> KeyArray:
|
||||
and ``fold_in``.
|
||||
|
||||
"""
|
||||
impl = _get_default_prng_impl()
|
||||
impl = default_prng_impl()
|
||||
key = prng.seed_with_impl(impl, seed)
|
||||
return _return_prng_keys(True, key)
|
||||
|
||||
# TODO(frostig): remove once we always enable_custom_prng
|
||||
def _check_default_impl_with_no_custom_prng(impl, name):
|
||||
default_impl = _get_default_prng_impl()
|
||||
default_impl = default_prng_impl()
|
||||
default_name = config.jax_default_prng_impl
|
||||
if not config.jax_enable_custom_prng and default_impl is not impl:
|
||||
raise RuntimeError('jax_enable_custom_prng must be enabled in order '
|
||||
|
@ -89,6 +89,7 @@ from jax._src.random import (
|
||||
categorical as categorical,
|
||||
cauchy as cauchy,
|
||||
choice as choice,
|
||||
default_prng_impl as default_prng_impl,
|
||||
dirichlet as dirichlet,
|
||||
double_sided_maxwell as double_sided_maxwell,
|
||||
exponential as exponential,
|
||||
|
@ -213,12 +213,28 @@ class PrngTest(jtu.JaxTestCase):
|
||||
('rbg', prng.rbg_prng_impl),
|
||||
('unsafe_rbg', prng.unsafe_rbg_prng_impl)]:
|
||||
with jax.default_prng_impl(name):
|
||||
self.assertIs(random.default_prng_impl(), impl)
|
||||
key = random.PRNGKey(42)
|
||||
self.assertIs(key.impl, impl)
|
||||
k1, k2 = random.split(key, 2)
|
||||
self.assertIs(k1.impl, impl)
|
||||
self.assertIs(k2.impl, impl)
|
||||
|
||||
def test_default_prng_selection_without_custom_prng_mode(self):
|
||||
if config.jax_enable_custom_prng:
|
||||
self.skipTest("test requires that config.jax_enable_custom_prng is False")
|
||||
for name, impl in [('threefry2x32', prng.threefry_prng_impl),
|
||||
('rbg', prng.rbg_prng_impl),
|
||||
('unsafe_rbg', prng.unsafe_rbg_prng_impl)]:
|
||||
with jax.default_prng_impl(name):
|
||||
self.assertIs(random.default_prng_impl(), impl)
|
||||
key = random.PRNGKey(42)
|
||||
self.assertEqual(key.shape, impl.key_shape)
|
||||
k1, k2 = random.split(key, 2)
|
||||
self.assertEqual(k1.shape, impl.key_shape)
|
||||
self.assertEqual(k2.shape, impl.key_shape)
|
||||
|
||||
|
||||
def test_explicit_threefry2x32_key(self):
|
||||
if not config.jax_enable_custom_prng:
|
||||
self.skipTest("test requires config.jax_enable_custom_prng")
|
||||
|
Loading…
x
Reference in New Issue
Block a user