mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
add a config setting to control the default PRNG implementation
Also add explicit seeding functions for each PRNG implementation.
This commit is contained in:
parent
b002bc178e
commit
98d245ebb4
@ -47,6 +47,7 @@ from ._src.config import (
|
||||
debug_infs as debug_infs,
|
||||
log_compiles as log_compiles,
|
||||
default_matmul_precision as default_matmul_precision,
|
||||
default_prng_impl as default_prng_impl,
|
||||
numpy_rank_promotion as numpy_rank_promotion,
|
||||
)
|
||||
from ._src.api import (
|
||||
|
@ -512,6 +512,13 @@ enable_custom_prng = config.define_bool_state(
|
||||
'disabling it will be considered deprecated. In a version '
|
||||
'after that the flag will be removed altogether.'))
|
||||
|
||||
default_prng_impl = config.define_enum_state(
|
||||
name='jax_default_prng_impl',
|
||||
enum_values=['threefry2x32', 'rbg', 'unsafe_rbg'],
|
||||
default='threefry2x32',
|
||||
help=('Select the default PRNG implementation, used when one is not '
|
||||
'explicitly provided at seeding time.'))
|
||||
|
||||
hlo_source_file_canonicalization_regex = config.define_string_state(
|
||||
name='jax_hlo_source_file_canonicalization_regex',
|
||||
default=None,
|
||||
|
@ -71,7 +71,7 @@ def _check_prng_key(key):
|
||||
'Raw arrays as random keys to jax.random functions are deprecated. '
|
||||
'Assuming valid threefry2x32 key for now.',
|
||||
FutureWarning)
|
||||
return prng.PRNGKeyArray(prng.threefry_prng_impl, key), True
|
||||
return prng.PRNGKeyArray(_get_default_prng_impl(), key), True
|
||||
else:
|
||||
raise TypeError(f'unexpected PRNG key type {type(key)}')
|
||||
|
||||
@ -88,21 +88,68 @@ def _random_bits(key: prng.PRNGKeyArray, bit_width, shape) -> jnp.ndarray:
|
||||
return key._random_bits(bit_width, shape)
|
||||
|
||||
|
||||
PRNG_IMPLS = {
|
||||
'threefry2x32': prng.threefry_prng_impl,
|
||||
'rbg': prng.rbg_prng_impl,
|
||||
'unsafe_rbg': prng.unsafe_rbg_prng_impl,
|
||||
}
|
||||
|
||||
def _get_default_prng_impl():
|
||||
impl_name = config.jax_default_prng_impl
|
||||
assert impl_name in PRNG_IMPLS, impl_name
|
||||
return PRNG_IMPLS[impl_name]
|
||||
|
||||
|
||||
### key operations
|
||||
|
||||
|
||||
def PRNGKey(seed: int) -> KeyArray:
|
||||
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
|
||||
|
||||
The resulting key carries the default PRNG implementation, as
|
||||
determined by the ``jax_default_prng_impl`` config flag.
|
||||
|
||||
Args:
|
||||
seed: a 64- or 32-bit integer used as the value of the key.
|
||||
|
||||
Returns:
|
||||
A PRNG key, consumable by random functions as well as ``split``
|
||||
and ``fold_in``.
|
||||
|
||||
"""
|
||||
return _return_prng_keys(
|
||||
True, prng.seed_with_impl(prng.threefry_prng_impl, seed))
|
||||
impl = _get_default_prng_impl()
|
||||
key = prng.seed_with_impl(impl, seed)
|
||||
return _return_prng_keys(True, key)
|
||||
|
||||
# TODO(frostig): remove once we always enable_custom_prng
|
||||
def _check_default_impl_with_no_custom_prng(impl, name):
|
||||
default_impl = _get_default_prng_impl()
|
||||
default_name = config.jax_default_prng_impl
|
||||
if not config.jax_enable_custom_prng and default_impl is not impl:
|
||||
raise RuntimeError('jax_enable_custom_prng must be enabled in order '
|
||||
f'to seed an RNG with an implementation "f{name}" '
|
||||
f'differing from the default "f{default_name}".')
|
||||
|
||||
def threefry2x32_key(seed: int) -> KeyArray:
|
||||
"""Creates a threefry2x32 PRNG key from an integer seed."""
|
||||
impl = prng.threefry_prng_impl
|
||||
_check_default_impl_with_no_custom_prng(impl, 'threefry2x32')
|
||||
key = prng.seed_with_impl(impl, seed)
|
||||
return _return_prng_keys(True, key)
|
||||
|
||||
def rbg_key(seed: int) -> KeyArray:
|
||||
"""Creates an RBG PRNG key from an integer seed."""
|
||||
impl = prng.rbg_prng_impl
|
||||
_check_default_impl_with_no_custom_prng(impl, 'rbg')
|
||||
key = prng.seed_with_impl(impl, seed)
|
||||
return _return_prng_keys(True, key)
|
||||
|
||||
def unsafe_rbg_key(seed: int) -> KeyArray:
|
||||
"""Creates an unsafe RBG PRNG key from an integer seed."""
|
||||
impl = prng.unsafe_rbg_prng_impl
|
||||
_check_default_impl_with_no_custom_prng(impl, 'unsafe_rbg')
|
||||
key = prng.seed_with_impl(impl, seed)
|
||||
return _return_prng_keys(True, key)
|
||||
|
||||
def _fold_in(key: KeyArray, data: int) -> KeyArray:
|
||||
# Alternative to fold_in() to use within random samplers.
|
||||
@ -939,7 +986,7 @@ def gamma(key: KeyArray,
|
||||
key, _ = _check_prng_key(key)
|
||||
if key.impl is not prng.threefry_prng_impl:
|
||||
raise NotImplementedError(
|
||||
f'`gamma` is only implemented for the default PRNG, not {key.impl}')
|
||||
f'`gamma` is only implemented for the threefry2x32 RNG, not {key.impl}')
|
||||
return gamma_threefry2x32(key.keys, a, shape, dtype)
|
||||
|
||||
def gamma_threefry2x32(key: jnp.ndarray, # raw ndarray form of a 2x32 key
|
||||
@ -1074,7 +1121,8 @@ def poisson(key: KeyArray,
|
||||
key, _ = _check_prng_key(key)
|
||||
if key.impl is not prng.threefry_prng_impl:
|
||||
raise NotImplementedError(
|
||||
f'`poisson` is only implemented for the default PRNG, not {key.impl}')
|
||||
'`poisson` is only implemented for the threefry2x32 RNG, '
|
||||
f'not {key.impl}')
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
shape = core.canonicalize_shape(shape)
|
||||
if np.shape(lam) != shape:
|
||||
|
@ -106,12 +106,15 @@ from jax._src.random import (
|
||||
rademacher as rademacher,
|
||||
randint as randint,
|
||||
random_gamma_p as random_gamma_p,
|
||||
rbg_key as rbg_key,
|
||||
shuffle as shuffle,
|
||||
split as split,
|
||||
t as t,
|
||||
threefry_2x32 as threefry_2x32,
|
||||
threefry2x32_key as threefry2x32_key,
|
||||
threefry2x32_p as threefry2x32_p,
|
||||
truncated_normal as truncated_normal,
|
||||
uniform as uniform,
|
||||
unsafe_rbg_key as unsafe_rbg_key,
|
||||
weibull_min as weibull_min,
|
||||
)
|
||||
|
@ -205,6 +205,36 @@ class PrngTest(jtu.JaxTestCase):
|
||||
expected = jnp.array(key, dtype=jnp.uint32)
|
||||
self.assertArraysEqual(actual, expected)
|
||||
|
||||
def test_default_prng_selection(self):
|
||||
if not config.jax_enable_custom_prng:
|
||||
self.skipTest("test requires config.jax_enable_custom_prng")
|
||||
for name, impl in [('threefry2x32', prng.threefry_prng_impl),
|
||||
('rbg', prng.rbg_prng_impl),
|
||||
('unsafe_rbg', prng.unsafe_rbg_prng_impl)]:
|
||||
with jax.default_prng_impl(name):
|
||||
key = random.PRNGKey(42)
|
||||
self.assertIs(key.impl, impl)
|
||||
k1, k2 = random.split(key, 2)
|
||||
self.assertIs(k1.impl, impl)
|
||||
self.assertIs(k2.impl, impl)
|
||||
|
||||
def test_explicit_threefry2x32_key(self):
|
||||
if not config.jax_enable_custom_prng:
|
||||
self.skipTest("test requires config.jax_enable_custom_prng")
|
||||
key = random.threefry2x32_key(42)
|
||||
self.assertIs(key.impl, prng.threefry_prng_impl)
|
||||
|
||||
def test_explicit_rbg_key(self):
|
||||
if not config.jax_enable_custom_prng:
|
||||
self.skipTest("test requires config.jax_enable_custom_prng")
|
||||
key = random.rbg_key(42)
|
||||
self.assertIs(key.impl, prng.rbg_prng_impl)
|
||||
|
||||
def test_explicit_unsafe_rbg_key(self):
|
||||
if not config.jax_enable_custom_prng:
|
||||
self.skipTest("test requires config.jax_enable_custom_prng")
|
||||
key = random.unsafe_rbg_key(42)
|
||||
self.assertIs(key.impl, prng.unsafe_rbg_prng_impl)
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class LaxRandomTest(jtu.JaxTestCase):
|
||||
@ -250,7 +280,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
f'{expected_freq}\n{actual_freq}')
|
||||
|
||||
def seed_prng(self, seed):
|
||||
return random.PRNGKey(seed)
|
||||
return random.threefry2x32_key(seed)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
|
||||
@ -1125,7 +1155,7 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
def seed_prng(self, seed):
|
||||
return prng.seed_with_impl(prng.rbg_prng_impl, seed)
|
||||
return random.rbg_key(seed)
|
||||
|
||||
def test_split_shape(self):
|
||||
key = self.seed_prng(73)
|
||||
|
Loading…
x
Reference in New Issue
Block a user