Merge pull request #10362 from JeppeKlitgaard:fix-random-reexport

PiperOrigin-RevId: 442936248
This commit is contained in:
jax authors 2022-04-19 16:04:31 -07:00
commit 88ac4edf63

View File

@ -102,18 +102,18 @@ The possible reasons not use the default RNG are:
Here is a short summary:
.. table::
.. table::
:widths: auto
================================= ================= === ==========
Property ThreeFry, default rbg unsafe_rbg
================================= ================= === ==========
Fast on TPU
always correct w/ scan
always correct w/ remat
identical across CPU/GPU/TPU
identical across JAX/XLA versions
identical across shardings
always correct w/ scan
always correct w/ remat
identical across CPU/GPU/TPU
identical across JAX/XLA versions
identical across shardings
================================= ================= === ==========
"""
@ -122,7 +122,8 @@ Here is a short summary:
# TODO(frostig): replace with KeyArray from jax._src.random once we
# always enable_custom_prng
from jax._src.prng import PRNGKeyArray as KeyArray
from jax._src.prng import PRNGKeyArray
KeyArray = PRNGKeyArray
from jax._src.random import (
PRNGKey as PRNGKey,