mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #16086 from froystig:upgraded-key-ctor
PiperOrigin-RevId: 534152508
This commit is contained in:
commit
69d6c1b13c
@ -54,6 +54,7 @@ Shape = Sequence[int]
|
||||
|
||||
# TODO(frostig): simplify once we always enable_custom_prng
|
||||
KeyArray = Union[Array, prng.PRNGKeyArray]
|
||||
PRNGKeyArray = prng.PRNGKeyArray
|
||||
|
||||
UINT_DTYPES = prng.UINT_DTYPES
|
||||
|
||||
@ -115,6 +116,22 @@ def default_prng_impl():
|
||||
|
||||
### key operations
|
||||
|
||||
def key(seed: Union[int, Array]) -> PRNGKeyArray:
|
||||
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
|
||||
|
||||
The result is a scalar array with a key that indicates the default PRNG
|
||||
implementation, as determined by the ``jax_default_prng_impl`` config flag.
|
||||
|
||||
Args:
|
||||
seed: a 64- or 32-bit integer used as the value of the key.
|
||||
|
||||
Returns:
|
||||
A scalar PRNG key array, consumable by random functions as well as ``split``
|
||||
and ``fold_in``.
|
||||
"""
|
||||
# TODO(frostig): Take impl as optional argument
|
||||
impl = default_prng_impl()
|
||||
return prng.seed_with_impl(impl, seed)
|
||||
|
||||
def PRNGKey(seed: Union[int, Array]) -> KeyArray:
|
||||
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
|
||||
@ -128,7 +145,6 @@ def PRNGKey(seed: Union[int, Array]) -> KeyArray:
|
||||
Returns:
|
||||
A PRNG key, consumable by random functions as well as ``split``
|
||||
and ``fold_in``.
|
||||
|
||||
"""
|
||||
impl = default_prng_impl()
|
||||
if isinstance(seed, prng.PRNGKeyArray):
|
||||
|
@ -170,6 +170,7 @@ from jax._src.random import (
|
||||
generalized_normal as generalized_normal,
|
||||
geometric as geometric,
|
||||
gumbel as gumbel,
|
||||
key as key,
|
||||
key_data as key_data,
|
||||
laplace as laplace,
|
||||
logistic as logistic,
|
||||
|
Loading…
x
Reference in New Issue
Block a user