mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #16589 from froystig:key-option-impl
PiperOrigin-RevId: 544796032
This commit is contained in:
commit
7b0596b90f
@ -116,49 +116,65 @@ def default_prng_impl():
|
||||
|
||||
### key operations
|
||||
|
||||
def key(seed: Union[int, Array]) -> PRNGKeyArray:
|
||||
def resolve_prng_impl(impl_spec: Optional[str]):
|
||||
if impl_spec is None:
|
||||
return default_prng_impl()
|
||||
if impl_spec in PRNG_IMPLS:
|
||||
return PRNG_IMPLS[impl_spec]
|
||||
|
||||
keys_fmt = ', '.join(f'"{s}"' for s in PRNG_IMPLS.keys())
|
||||
raise ValueError(f'unrecognized PRNG implementation "{impl_spec}". '
|
||||
f'Did you mean one of: {keys_fmt}?')
|
||||
|
||||
def _key(ctor_name: str, seed: Union[int, Array], impl_spec: Optional[str]
|
||||
) -> PRNGKeyArray:
|
||||
impl = resolve_prng_impl(impl_spec)
|
||||
if isinstance(seed, prng.PRNGKeyArray):
|
||||
raise TypeError(
|
||||
f"{ctor_name} accepts a scalar seed, but was given a PRNGKeyArray.")
|
||||
if np.ndim(seed):
|
||||
raise TypeError(
|
||||
f"{ctor_name} accepts a scalar seed, but was given an array of "
|
||||
f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
|
||||
return prng.seed_with_impl(impl, seed)
|
||||
|
||||
def key(seed: Union[int, Array], *,
|
||||
impl: Optional[str] = None) -> PRNGKeyArray:
|
||||
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
|
||||
|
||||
The result is a scalar array with a key that indicates the default PRNG
|
||||
implementation, as determined by the ``jax_default_prng_impl`` config flag.
|
||||
implementation, as determined by the optional ``impl`` argument or,
|
||||
otherwise, by the ``jax_default_prng_impl`` config flag.
|
||||
|
||||
Args:
|
||||
seed: a 64- or 32-bit integer used as the value of the key.
|
||||
impl: optional string specifying the PRNG implementation (e.g.
|
||||
``'threefry2x32'``)
|
||||
|
||||
Returns:
|
||||
A scalar PRNG key array, consumable by random functions as well as ``split``
|
||||
and ``fold_in``.
|
||||
"""
|
||||
# TODO(frostig): Take impl as optional argument
|
||||
impl = default_prng_impl()
|
||||
if isinstance(seed, prng.PRNGKeyArray):
|
||||
raise TypeError("key accepts a scalar seed, but was given a PRNGKeyArray.")
|
||||
if np.ndim(seed):
|
||||
raise TypeError("key accepts a scalar seed, but was given an array of "
|
||||
f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
|
||||
return prng.seed_with_impl(impl, seed)
|
||||
return _key('key', seed, impl)
|
||||
|
||||
def PRNGKey(seed: Union[int, Array]) -> KeyArray:
|
||||
def PRNGKey(seed: Union[int, Array], *,
|
||||
impl: Optional[str] = None) -> KeyArray:
|
||||
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
|
||||
|
||||
The resulting key carries the default PRNG implementation, as
|
||||
determined by the ``jax_default_prng_impl`` config flag.
|
||||
determined by the optional ``impl`` argument or, otherwise, by the
|
||||
``jax_default_prng_impl`` config flag.
|
||||
|
||||
Args:
|
||||
seed: a 64- or 32-bit integer used as the value of the key.
|
||||
impl: optional string specifying the PRNG implementation (e.g.
|
||||
``'threefry2x32'``)
|
||||
|
||||
Returns:
|
||||
A PRNG key, consumable by random functions as well as ``split``
|
||||
and ``fold_in``.
|
||||
"""
|
||||
impl = default_prng_impl()
|
||||
if isinstance(seed, prng.PRNGKeyArray):
|
||||
raise TypeError("PRNGKey accepts a scalar seed, but was given a PRNGKeyArray.")
|
||||
if np.ndim(seed):
|
||||
raise TypeError("PRNGKey accepts a scalar seed, but was given an array of "
|
||||
f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
|
||||
key = prng.seed_with_impl(impl, seed)
|
||||
return _return_prng_keys(True, key)
|
||||
return _return_prng_keys(True, _key('PRNGKey', seed, impl))
|
||||
|
||||
# TODO(frostig): remove once we always enable_custom_prng
|
||||
def _check_default_impl_with_no_custom_prng(impl, name):
|
||||
|
@ -498,7 +498,6 @@ class PrngTest(jtu.JaxTestCase):
|
||||
self.assertEqual(k1.shape, impl.key_shape)
|
||||
self.assertEqual(k2.shape, impl.key_shape)
|
||||
|
||||
|
||||
def test_explicit_threefry2x32_key(self):
|
||||
if not config.jax_enable_custom_prng:
|
||||
self.skipTest("test requires config.jax_enable_custom_prng")
|
||||
@ -517,6 +516,28 @@ class PrngTest(jtu.JaxTestCase):
|
||||
key = random.unsafe_rbg_key(42)
|
||||
self.assertIs(key.impl, prng.unsafe_rbg_prng_impl)
|
||||
|
||||
def test_key_construction_with_explicit_impl_name(self):
|
||||
def check_key_has_impl(key, impl):
|
||||
if isinstance(key, random.PRNGKeyArray):
|
||||
self.assertIs(key.impl, impl)
|
||||
else:
|
||||
self.assertEqual(key.dtype, jnp.dtype('uint32'))
|
||||
self.assertEqual(key.shape, impl.key_shape)
|
||||
|
||||
key = random.key(42, impl='threefry2x32')
|
||||
check_key_has_impl(key, prng.threefry_prng_impl)
|
||||
key = random.key(42, impl='rbg')
|
||||
check_key_has_impl(key, prng.rbg_prng_impl)
|
||||
key = random.key(42, impl='unsafe_rbg')
|
||||
check_key_has_impl(key, prng.unsafe_rbg_prng_impl)
|
||||
|
||||
key = random.PRNGKey(42, impl='threefry2x32')
|
||||
check_key_has_impl(key, prng.threefry_prng_impl)
|
||||
key = random.PRNGKey(42, impl='rbg')
|
||||
check_key_has_impl(key, prng.rbg_prng_impl)
|
||||
key = random.PRNGKey(42, impl='unsafe_rbg')
|
||||
check_key_has_impl(key, prng.unsafe_rbg_prng_impl)
|
||||
|
||||
def test_key_array_indexing_0d(self):
|
||||
if not config.jax_enable_custom_prng:
|
||||
self.skipTest("test requires config.jax_enable_custom_prng")
|
||||
|
Loading…
x
Reference in New Issue
Block a user