mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #10362 from JeppeKlitgaard:fix-random-reexport
PiperOrigin-RevId: 442936248
This commit is contained in:
commit
88ac4edf63
@ -102,18 +102,18 @@ The possible reasons not use the default RNG are:
|
||||
|
||||
Here is a short summary:
|
||||
|
||||
.. table::
|
||||
.. table::
|
||||
:widths: auto
|
||||
|
||||
================================= ================= === ==========
|
||||
Property ThreeFry, default rbg unsafe_rbg
|
||||
================================= ================= === ==========
|
||||
Fast on TPU ✅ ✅
|
||||
always correct w/ scan ✅ ✅
|
||||
always correct w/ remat ✅ ✅
|
||||
identical across CPU/GPU/TPU ✅ ✅
|
||||
identical across JAX/XLA versions ✅
|
||||
identical across shardings ✅
|
||||
always correct w/ scan ✅ ✅
|
||||
always correct w/ remat ✅ ✅
|
||||
identical across CPU/GPU/TPU ✅ ✅
|
||||
identical across JAX/XLA versions ✅
|
||||
identical across shardings ✅
|
||||
================================= ================= === ==========
|
||||
|
||||
"""
|
||||
@ -122,7 +122,8 @@ Here is a short summary:
|
||||
|
||||
# TODO(frostig): replace with KeyArray from jax._src.random once we
|
||||
# always enable_custom_prng
|
||||
from jax._src.prng import PRNGKeyArray as KeyArray
|
||||
from jax._src.prng import PRNGKeyArray
|
||||
KeyArray = PRNGKeyArray
|
||||
|
||||
from jax._src.random import (
|
||||
PRNGKey as PRNGKey,
|
||||
|
Loading…
x
Reference in New Issue
Block a user