add a config setting to control the default PRNG implementation

Also add explicit seeding functions for each PRNG implementation.
This commit is contained in:
Roy Frostig 2021-10-07 19:15:43 -07:00
parent b002bc178e
commit 98d245ebb4
5 changed files with 96 additions and 7 deletions

View File

@ -47,6 +47,7 @@ from ._src.config import (
debug_infs as debug_infs,
log_compiles as log_compiles,
default_matmul_precision as default_matmul_precision,
default_prng_impl as default_prng_impl,
numpy_rank_promotion as numpy_rank_promotion,
)
from ._src.api import (

View File

@ -512,6 +512,13 @@ enable_custom_prng = config.define_bool_state(
'disabling it will be considered deprecated. In a version '
'after that the flag will be removed altogether.'))
default_prng_impl = config.define_enum_state(
name='jax_default_prng_impl',
enum_values=['threefry2x32', 'rbg', 'unsafe_rbg'],
default='threefry2x32',
help=('Select the default PRNG implementation, used when one is not '
'explicitly provided at seeding time.'))
hlo_source_file_canonicalization_regex = config.define_string_state(
name='jax_hlo_source_file_canonicalization_regex',
default=None,

View File

@ -71,7 +71,7 @@ def _check_prng_key(key):
'Raw arrays as random keys to jax.random functions are deprecated. '
'Assuming valid threefry2x32 key for now.',
FutureWarning)
return prng.PRNGKeyArray(prng.threefry_prng_impl, key), True
return prng.PRNGKeyArray(_get_default_prng_impl(), key), True
else:
raise TypeError(f'unexpected PRNG key type {type(key)}')
@ -88,21 +88,68 @@ def _random_bits(key: prng.PRNGKeyArray, bit_width, shape) -> jnp.ndarray:
return key._random_bits(bit_width, shape)
PRNG_IMPLS = {
'threefry2x32': prng.threefry_prng_impl,
'rbg': prng.rbg_prng_impl,
'unsafe_rbg': prng.unsafe_rbg_prng_impl,
}
def _get_default_prng_impl():
impl_name = config.jax_default_prng_impl
assert impl_name in PRNG_IMPLS, impl_name
return PRNG_IMPLS[impl_name]
### key operations
def PRNGKey(seed: int) -> KeyArray:
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
The resulting key carries the default PRNG implementation, as
determined by the ``jax_default_prng_impl`` config flag.
Args:
seed: a 64- or 32-bit integer used as the value of the key.
Returns:
A PRNG key, consumable by random functions as well as ``split``
and ``fold_in``.
"""
return _return_prng_keys(
True, prng.seed_with_impl(prng.threefry_prng_impl, seed))
impl = _get_default_prng_impl()
key = prng.seed_with_impl(impl, seed)
return _return_prng_keys(True, key)
# TODO(frostig): remove once we always enable_custom_prng
def _check_default_impl_with_no_custom_prng(impl, name):
default_impl = _get_default_prng_impl()
default_name = config.jax_default_prng_impl
if not config.jax_enable_custom_prng and default_impl is not impl:
raise RuntimeError('jax_enable_custom_prng must be enabled in order '
f'to seed an RNG with an implementation "f{name}" '
f'differing from the default "f{default_name}".')
def threefry2x32_key(seed: int) -> KeyArray:
"""Creates a threefry2x32 PRNG key from an integer seed."""
impl = prng.threefry_prng_impl
_check_default_impl_with_no_custom_prng(impl, 'threefry2x32')
key = prng.seed_with_impl(impl, seed)
return _return_prng_keys(True, key)
def rbg_key(seed: int) -> KeyArray:
"""Creates an RBG PRNG key from an integer seed."""
impl = prng.rbg_prng_impl
_check_default_impl_with_no_custom_prng(impl, 'rbg')
key = prng.seed_with_impl(impl, seed)
return _return_prng_keys(True, key)
def unsafe_rbg_key(seed: int) -> KeyArray:
"""Creates an unsafe RBG PRNG key from an integer seed."""
impl = prng.unsafe_rbg_prng_impl
_check_default_impl_with_no_custom_prng(impl, 'unsafe_rbg')
key = prng.seed_with_impl(impl, seed)
return _return_prng_keys(True, key)
def _fold_in(key: KeyArray, data: int) -> KeyArray:
# Alternative to fold_in() to use within random samplers.
@ -939,7 +986,7 @@ def gamma(key: KeyArray,
key, _ = _check_prng_key(key)
if key.impl is not prng.threefry_prng_impl:
raise NotImplementedError(
f'`gamma` is only implemented for the default PRNG, not {key.impl}')
f'`gamma` is only implemented for the threefry2x32 RNG, not {key.impl}')
return gamma_threefry2x32(key.keys, a, shape, dtype)
def gamma_threefry2x32(key: jnp.ndarray, # raw ndarray form of a 2x32 key
@ -1074,7 +1121,8 @@ def poisson(key: KeyArray,
key, _ = _check_prng_key(key)
if key.impl is not prng.threefry_prng_impl:
raise NotImplementedError(
f'`poisson` is only implemented for the default PRNG, not {key.impl}')
'`poisson` is only implemented for the threefry2x32 RNG, '
f'not {key.impl}')
dtype = dtypes.canonicalize_dtype(dtype)
shape = core.canonicalize_shape(shape)
if np.shape(lam) != shape:

View File

@ -106,12 +106,15 @@ from jax._src.random import (
rademacher as rademacher,
randint as randint,
random_gamma_p as random_gamma_p,
rbg_key as rbg_key,
shuffle as shuffle,
split as split,
t as t,
threefry_2x32 as threefry_2x32,
threefry2x32_key as threefry2x32_key,
threefry2x32_p as threefry2x32_p,
truncated_normal as truncated_normal,
uniform as uniform,
unsafe_rbg_key as unsafe_rbg_key,
weibull_min as weibull_min,
)

View File

@ -205,6 +205,36 @@ class PrngTest(jtu.JaxTestCase):
expected = jnp.array(key, dtype=jnp.uint32)
self.assertArraysEqual(actual, expected)
def test_default_prng_selection(self):
if not config.jax_enable_custom_prng:
self.skipTest("test requires config.jax_enable_custom_prng")
for name, impl in [('threefry2x32', prng.threefry_prng_impl),
('rbg', prng.rbg_prng_impl),
('unsafe_rbg', prng.unsafe_rbg_prng_impl)]:
with jax.default_prng_impl(name):
key = random.PRNGKey(42)
self.assertIs(key.impl, impl)
k1, k2 = random.split(key, 2)
self.assertIs(k1.impl, impl)
self.assertIs(k2.impl, impl)
def test_explicit_threefry2x32_key(self):
if not config.jax_enable_custom_prng:
self.skipTest("test requires config.jax_enable_custom_prng")
key = random.threefry2x32_key(42)
self.assertIs(key.impl, prng.threefry_prng_impl)
def test_explicit_rbg_key(self):
if not config.jax_enable_custom_prng:
self.skipTest("test requires config.jax_enable_custom_prng")
key = random.rbg_key(42)
self.assertIs(key.impl, prng.rbg_prng_impl)
def test_explicit_unsafe_rbg_key(self):
if not config.jax_enable_custom_prng:
self.skipTest("test requires config.jax_enable_custom_prng")
key = random.unsafe_rbg_key(42)
self.assertIs(key.impl, prng.unsafe_rbg_prng_impl)
@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxRandomTest(jtu.JaxTestCase):
@ -250,7 +280,7 @@ class LaxRandomTest(jtu.JaxTestCase):
f'{expected_freq}\n{actual_freq}')
def seed_prng(self, seed):
return random.PRNGKey(seed)
return random.threefry2x32_key(seed)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
@ -1125,7 +1155,7 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxRandomWithRBGPRNGTest(LaxRandomTest):
def seed_prng(self, seed):
return prng.seed_with_impl(prng.rbg_prng_impl, seed)
return random.rbg_key(seed)
def test_split_shape(self):
key = self.seed_prng(73)