Merge pull request #16589 from froystig:key-option-impl

PiperOrigin-RevId: 544796032
This commit is contained in:
jax authors 2023-06-30 18:55:22 -07:00
commit 7b0596b90f
2 changed files with 58 additions and 21 deletions

View File

@ -116,49 +116,65 @@ def default_prng_impl():
### key operations
def key(seed: Union[int, Array]) -> PRNGKeyArray:
def resolve_prng_impl(impl_spec: Optional[str]):
if impl_spec is None:
return default_prng_impl()
if impl_spec in PRNG_IMPLS:
return PRNG_IMPLS[impl_spec]
keys_fmt = ', '.join(f'"{s}"' for s in PRNG_IMPLS.keys())
raise ValueError(f'unrecognized PRNG implementation "{impl_spec}". '
f'Did you mean one of: {keys_fmt}?')
def _key(ctor_name: str, seed: Union[int, Array], impl_spec: Optional[str]
) -> PRNGKeyArray:
impl = resolve_prng_impl(impl_spec)
if isinstance(seed, prng.PRNGKeyArray):
raise TypeError(
f"{ctor_name} accepts a scalar seed, but was given a PRNGKeyArray.")
if np.ndim(seed):
raise TypeError(
f"{ctor_name} accepts a scalar seed, but was given an array of "
f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
return prng.seed_with_impl(impl, seed)
def key(seed: Union[int, Array], *,
impl: Optional[str] = None) -> PRNGKeyArray:
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
The result is a scalar array with a key that indicates the default PRNG
implementation, as determined by the ``jax_default_prng_impl`` config flag.
implementation, as determined by the optional ``impl`` argument or,
otherwise, by the ``jax_default_prng_impl`` config flag.
Args:
seed: a 64- or 32-bit integer used as the value of the key.
impl: optional string specifying the PRNG implementation (e.g.
``'threefry2x32'``)
Returns:
A scalar PRNG key array, consumable by random functions as well as ``split``
and ``fold_in``.
"""
# TODO(frostig): Take impl as optional argument
impl = default_prng_impl()
if isinstance(seed, prng.PRNGKeyArray):
raise TypeError("key accepts a scalar seed, but was given a PRNGKeyArray.")
if np.ndim(seed):
raise TypeError("key accepts a scalar seed, but was given an array of "
f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
return prng.seed_with_impl(impl, seed)
return _key('key', seed, impl)
def PRNGKey(seed: Union[int, Array]) -> KeyArray:
def PRNGKey(seed: Union[int, Array], *,
impl: Optional[str] = None) -> KeyArray:
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
The resulting key carries the default PRNG implementation, as
determined by the ``jax_default_prng_impl`` config flag.
determined by the optional ``impl`` argument or, otherwise, by the
``jax_default_prng_impl`` config flag.
Args:
seed: a 64- or 32-bit integer used as the value of the key.
impl: optional string specifying the PRNG implementation (e.g.
``'threefry2x32'``)
Returns:
A PRNG key, consumable by random functions as well as ``split``
and ``fold_in``.
"""
impl = default_prng_impl()
if isinstance(seed, prng.PRNGKeyArray):
raise TypeError("PRNGKey accepts a scalar seed, but was given a PRNGKeyArray.")
if np.ndim(seed):
raise TypeError("PRNGKey accepts a scalar seed, but was given an array of "
f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
key = prng.seed_with_impl(impl, seed)
return _return_prng_keys(True, key)
return _return_prng_keys(True, _key('PRNGKey', seed, impl))
# TODO(frostig): remove once we always enable_custom_prng
def _check_default_impl_with_no_custom_prng(impl, name):

View File

@ -498,7 +498,6 @@ class PrngTest(jtu.JaxTestCase):
self.assertEqual(k1.shape, impl.key_shape)
self.assertEqual(k2.shape, impl.key_shape)
def test_explicit_threefry2x32_key(self):
if not config.jax_enable_custom_prng:
self.skipTest("test requires config.jax_enable_custom_prng")
@ -517,6 +516,28 @@ class PrngTest(jtu.JaxTestCase):
key = random.unsafe_rbg_key(42)
self.assertIs(key.impl, prng.unsafe_rbg_prng_impl)
def test_key_construction_with_explicit_impl_name(self):
def check_key_has_impl(key, impl):
if isinstance(key, random.PRNGKeyArray):
self.assertIs(key.impl, impl)
else:
self.assertEqual(key.dtype, jnp.dtype('uint32'))
self.assertEqual(key.shape, impl.key_shape)
key = random.key(42, impl='threefry2x32')
check_key_has_impl(key, prng.threefry_prng_impl)
key = random.key(42, impl='rbg')
check_key_has_impl(key, prng.rbg_prng_impl)
key = random.key(42, impl='unsafe_rbg')
check_key_has_impl(key, prng.unsafe_rbg_prng_impl)
key = random.PRNGKey(42, impl='threefry2x32')
check_key_has_impl(key, prng.threefry_prng_impl)
key = random.PRNGKey(42, impl='rbg')
check_key_has_impl(key, prng.rbg_prng_impl)
key = random.PRNGKey(42, impl='unsafe_rbg')
check_key_has_impl(key, prng.unsafe_rbg_prng_impl)
def test_key_array_indexing_0d(self):
if not config.jax_enable_custom_prng:
self.skipTest("test requires config.jax_enable_custom_prng")