Merge pull request #16086 from froystig:upgraded-key-ctor

PiperOrigin-RevId: 534152508
This commit is contained in:
jax authors 2023-05-22 12:46:41 -07:00
commit 69d6c1b13c
2 changed files with 18 additions and 1 deletions

View File

@ -54,6 +54,7 @@ Shape = Sequence[int]
# TODO(frostig): simplify once we always enable_custom_prng
KeyArray = Union[Array, prng.PRNGKeyArray]
PRNGKeyArray = prng.PRNGKeyArray
UINT_DTYPES = prng.UINT_DTYPES
@ -115,6 +116,22 @@ def default_prng_impl():
### key operations
def key(seed: Union[int, Array]) -> PRNGKeyArray:
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
The result is a scalar array with a key that indicates the default PRNG
implementation, as determined by the ``jax_default_prng_impl`` config flag.
Args:
seed: a 64- or 32-bit integer used as the value of the key.
Returns:
A scalar PRNG key array, consumable by random functions as well as ``split``
and ``fold_in``.
"""
# TODO(frostig): Take impl as optional argument
impl = default_prng_impl()
return prng.seed_with_impl(impl, seed)
def PRNGKey(seed: Union[int, Array]) -> KeyArray:
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
@ -128,7 +145,6 @@ def PRNGKey(seed: Union[int, Array]) -> KeyArray:
Returns:
A PRNG key, consumable by random functions as well as ``split``
and ``fold_in``.
"""
impl = default_prng_impl()
if isinstance(seed, prng.PRNGKeyArray):

View File

@ -170,6 +170,7 @@ from jax._src.random import (
generalized_normal as generalized_normal,
geometric as geometric,
gumbel as gumbel,
key as key,
key_data as key_data,
laplace as laplace,
logistic as logistic,