From 98f790f5d524372d45ca86bc8da07ca498d72b3b Mon Sep 17 00:00:00 2001
From: Roy Frostig <frostig@google.com>
Date: Mon, 14 Aug 2023 12:59:09 -0700
Subject: [PATCH] update package/API reference docs to new-style typed PRNG
 keys

---
 docs/jax.nn.initializers.rst       |  4 ++--
 jax/_src/api.py                    |  2 +-
 jax/_src/core.py                   |  4 ++--
 jax/_src/lax/control_flow/loops.py |  2 +-
 jax/_src/nn/initializers.py        | 26 ++++++++++++------------
 jax/_src/random.py                 |  4 ++--
 jax/random.py                      | 32 +++++++++++++++++-------------
 7 files changed, 39 insertions(+), 35 deletions(-)

diff --git a/docs/jax.nn.initializers.rst b/docs/jax.nn.initializers.rst
index d96ba43f0..246e0cdbe 100644
--- a/docs/jax.nn.initializers.rst
+++ b/docs/jax.nn.initializers.rst
@@ -14,8 +14,8 @@ consistent with definitions used in Keras and Sonnet.
 
 An initializer is a function that takes three arguments:
 ``(key, shape, dtype)`` and returns an array with dimensions ``shape`` and
-data type ``dtype``. Argument ``key`` is a :class:`jax.random.PRNGKey` random
-key used when generating random numbers to initialize the array.
+data type ``dtype``. Argument ``key`` is a PRNG key (e.g. from
+:func:`jax.random.key`), used to generate random numbers to initialize the array.
 
 .. autosummary::
   :toctree: _autosummary
diff --git a/jax/_src/api.py b/jax/_src/api.py
index 005e0ceca..2fecc4fd7 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -280,7 +280,7 @@ def jit(
     ... def selu(x, alpha=1.67, lmbda=1.05):
     ...   return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
     >>>
-    >>> key = jax.random.PRNGKey(0)
+    >>> key = jax.random.key(0)
     >>> x = jax.random.normal(key, (10,))
     >>> print(selu(x))  # doctest: +SKIP
     [-0.54485  0.27744 -0.29255 -0.91421 -0.62452 -0.24748
diff --git a/jax/_src/core.py b/jax/_src/core.py
index c351e6980..b207b39e5 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -1277,7 +1277,7 @@ def ensure_compile_time_eval():
     @jax.jit
     def jax_fn(x):
       with jax.ensure_compile_time_eval():
-        y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)
+        y = random.randint(random.key(0), (1000,1000), 0, 100)
       y2 = y @ y
       x2 = jnp.sum(y2) * x
       return x2
@@ -1285,7 +1285,7 @@ def ensure_compile_time_eval():
   A similar behavior can often be achieved simply by 'hoisting' the constant
   expression out of the corresponding staging API::
 
-    y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)
+    y = random.randint(random.key(0), (1000,1000), 0, 100)
 
     @jax.jit
     def jax_fn(x):
diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index e939f2170..0c6ae2fe0 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -2101,7 +2101,7 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
 
   Example 2: partial products of an array of matrices
 
-  >>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2))
+  >>> mats = jax.random.uniform(jax.random.key(0), (4, 2, 2))
   >>> partial_prods = lax.associative_scan(jnp.matmul, mats)
   >>> partial_prods.shape
   (4, 2, 2)
diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py
index 5468fd663..d7353c396 100644
--- a/jax/_src/nn/initializers.py
+++ b/jax/_src/nn/initializers.py
@@ -62,7 +62,7 @@ def zeros(key: KeyArray,
   The ``key`` argument is ignored.
 
   >>> import jax, jax.numpy as jnp
-  >>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
+  >>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
   Array([[0., 0., 0.],
          [0., 0., 0.]], dtype=float32)
   """
@@ -77,7 +77,7 @@ def ones(key: KeyArray,
   The ``key`` argument is ignored.
 
   >>> import jax, jax.numpy as jnp
-  >>> jax.nn.initializers.ones(jax.random.PRNGKey(42), (3, 2), jnp.float32)
+  >>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
   Array([[1., 1.],
          [1., 1.],
          [1., 1.]], dtype=float32)
@@ -96,7 +96,7 @@ def constant(value: ArrayLike,
 
   >>> import jax, jax.numpy as jnp
   >>> initializer = jax.nn.initializers.constant(-7)
-  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
+  >>> initializer(jax.random.key(42), (2, 3), jnp.float32)
   Array([[-7., -7., -7.],
          [-7., -7., -7.]], dtype=float32)
   """
@@ -122,7 +122,7 @@ def uniform(scale: RealNumeric = 1e-2,
 
   >>> import jax, jax.numpy as jnp
   >>> initializer = jax.nn.initializers.uniform(10.0)
-  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
+  >>> initializer(jax.random.key(42), (2, 3), jnp.float32)  # doctest: +SKIP
   Array([[7.298188 , 8.691938 , 8.7230015],
          [2.0818567, 1.8662417, 5.5022564]], dtype=float32)
   """
@@ -148,7 +148,7 @@ def normal(stddev: RealNumeric = 1e-2,
 
   >>> import jax, jax.numpy as jnp
   >>> initializer = jax.nn.initializers.normal(5.0)
-  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
+  >>> initializer(jax.random.key(42), (2, 3), jnp.float32)  # doctest: +SKIP
   Array([[ 3.0613258 ,  5.6129413 ,  5.6866574 ],
          [-4.063663  , -4.4520254 ,  0.63115686]], dtype=float32)
   """
@@ -376,7 +376,7 @@ def glorot_uniform(in_axis: int | Sequence[int] = -2,
 
   >>> import jax, jax.numpy as jnp
   >>> initializer = jax.nn.initializers.glorot_uniform()
-  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
+  >>> initializer(jax.random.key(42), (2, 3), jnp.float32)  # doctest: +SKIP
   Array([[ 0.50350785,  0.8088631 ,  0.81566876],
          [-0.6393332 , -0.6865721 ,  0.11003882]], dtype=float32)
 
@@ -414,7 +414,7 @@ def glorot_normal(in_axis: int | Sequence[int] = -2,
 
   >>> import jax, jax.numpy as jnp
   >>> initializer = jax.nn.initializers.glorot_normal()
-  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
+  >>> initializer(jax.random.key(42), (2, 3), jnp.float32)  # doctest: +SKIP
   Array([[ 0.41770416,  0.75262755,  0.7619329 ],
          [-0.5516644 , -0.6028657 ,  0.08661086]], dtype=float32)
 
@@ -452,7 +452,7 @@ def lecun_uniform(in_axis: int | Sequence[int] = -2,
 
   >>> import jax, jax.numpy as jnp
   >>> initializer = jax.nn.initializers.lecun_uniform()
-  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
+  >>> initializer(jax.random.key(42), (2, 3), jnp.float32)  # doctest: +SKIP
   Array([[ 0.56293887,  0.90433645,  0.9119454 ],
          [-0.71479625, -0.7676109 ,  0.12302713]], dtype=float32)
 
@@ -488,7 +488,7 @@ def lecun_normal(in_axis: int | Sequence[int] = -2,
 
   >>> import jax, jax.numpy as jnp
   >>> initializer = jax.nn.initializers.lecun_normal()
-  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
+  >>> initializer(jax.random.key(42), (2, 3), jnp.float32)  # doctest: +SKIP
   Array([[ 0.46700746,  0.8414632 ,  0.8518669 ],
          [-0.61677957, -0.67402434,  0.09683388]], dtype=float32)
 
@@ -524,7 +524,7 @@ def he_uniform(in_axis: int | Sequence[int] = -2,
 
   >>> import jax, jax.numpy as jnp
   >>> initializer = jax.nn.initializers.he_uniform()
-  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
+  >>> initializer(jax.random.key(42), (2, 3), jnp.float32)  # doctest: +SKIP
   Array([[ 0.79611576,  1.2789248 ,  1.2896855 ],
          [-1.0108745 , -1.0855657 ,  0.17398663]], dtype=float32)
 
@@ -562,7 +562,7 @@ def he_normal(in_axis: int | Sequence[int] = -2,
 
   >>> import jax, jax.numpy as jnp
   >>> initializer = jax.nn.initializers.he_normal()
-  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
+  >>> initializer(jax.random.key(42), (2, 3), jnp.float32)  # doctest: +SKIP
   Array([[ 0.6604483 ,  1.1900088 ,  1.2047218 ],
          [-0.87225807, -0.95321447,  0.1369438 ]], dtype=float32)
 
@@ -595,7 +595,7 @@ def orthogonal(scale: RealNumeric = 1.0,
 
   >>> import jax, jax.numpy as jnp
   >>> initializer = jax.nn.initializers.orthogonal()
-  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
+  >>> initializer(jax.random.key(42), (2, 3), jnp.float32)  # doctest: +SKIP
   Array([[ 3.9026976e-01,  7.2495741e-01, -5.6756169e-01],
          [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]],            dtype=float32)
   """
@@ -638,7 +638,7 @@ def delta_orthogonal(
 
   >>> import jax, jax.numpy as jnp
   >>> initializer = jax.nn.initializers.delta_orthogonal()
-  >>> initializer(jax.random.PRNGKey(42), (3, 3, 3), jnp.float32)  # doctest: +SKIP
+  >>> initializer(jax.random.key(42), (3, 3, 3), jnp.float32)  # doctest: +SKIP
   Array([[[ 0.        ,  0.        ,  0.        ],
           [ 0.        ,  0.        ,  0.        ],
           [ 0.        ,  0.        ,  0.        ]],
diff --git a/jax/_src/random.py b/jax/_src/random.py
index f9045ebf3..2696b488a 100644
--- a/jax/_src/random.py
+++ b/jax/_src/random.py
@@ -244,7 +244,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
   """Folds in data to a PRNG key to form a new PRNG key.
 
   Args:
-    key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``).
+    key: a PRNG key (from ``key``, ``split``, ``fold_in``).
     data: a 32bit integer representing data to be folded in to the key.
 
   Returns:
@@ -274,7 +274,7 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray:
   """Splits a PRNG key into `num` new keys by adding a leading axis.
 
   Args:
-    key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``).
+    key: a PRNG key (from ``key``, ``split``, ``fold_in``).
     num: optional, a positive integer (or tuple of integers) indicating
       the number (or shape) of keys to produce. Defaults to 2.
 
diff --git a/jax/random.py b/jax/random.py
index c06f48b35..f65ec5858 100644
--- a/jax/random.py
+++ b/jax/random.py
@@ -22,24 +22,25 @@ Basic usage
 
 >>> seed = 1701
 >>> num_steps = 100
->>> key = jax.random.PRNGKey(seed)
+>>> key = jax.random.key(seed)
 >>> for i in range(num_steps):
 ...   key, subkey = jax.random.split(key)
 ...   params = compiled_update(subkey, params, next(batches))  # doctest: +SKIP
 
-PRNG Keys
+PRNG keys
 ---------
 
 Unlike the *stateful* pseudorandom number generators (PRNGs) that users of NumPy and
 SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to
 be passed as a first argument.
-The random state is described by two unsigned 32-bit integers that we call a **key**,
-usually generated by the :py:func:`jax.random.PRNGKey` function::
+The random state is described by a special array element type that we call a **key**,
+usually generated by the :py:func:`jax.random.key` function::
 
     >>> from jax import random
-    >>> key = random.PRNGKey(0)
+    >>> key = random.key(0)
     >>> key
-    Array([0, 0], dtype=uint32)
+    Array((), dtype=key<fry>) overlaying:
+    [0 0]
 
 This key can then be used in any of JAX's random number generation routines::
 
@@ -60,8 +61,8 @@ If you need a new random number, you can use :meth:`jax.random.split` to generat
 Advanced
 --------
 
-Design and Context
-==================
+Design and background
+=====================
 
 **TLDR**: JAX PRNG = `Threefry counter PRNG <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_
 + a functional array-oriented `splitting model <https://dl.acm.org/citation.cfm?id=2503784>`_
@@ -79,16 +80,19 @@ To summarize, among other requirements, the JAX PRNG aims to:
 Advanced RNG configuration
 ==========================
 
-JAX provides several PRNG implementations (controlled by the
-`jax_default_prng_impl` flag).
+JAX provides several PRNG implementations. A specific one can be
+selected with the optional `impl` keyword argument to
+`jax.random.key`. When no `impl` option is passed to the `key`
+constructor, the implementation is determined by the global
+`jax_default_prng_impl` configuration flag.
 
--   **default**
+-   **default**, `"threefry2x32"`:
     `A counter-based PRNG built around the Threefry hash function <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_.
 -   *experimental* A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See
     `TF doc <https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator>`_.
 
-    -   "rbg" uses ThreeFry for splitting, and XLA RBG for data generation.
-    -   "unsafe_rbg" exists only for demonstration purposes, using RBG both for
+    -   `"rbg"` uses ThreeFry for splitting, and XLA RBG for data generation.
+    -   `"unsafe_rbg"` exists only for demonstration purposes, using RBG both for
         splitting (using an untested made up algorithm) and generating.
 
     The random streams generated by these experimental implementations haven't
@@ -126,7 +130,7 @@ robust hash functions for `jax.random.split` and `jax.random.fold_in`. Therefore
 less safe in the sense that the quality of random streams it generates from
 different keys is less well understood.
 
-For more about jax_threefry_partitionable, see
+For more about `jax_threefry_partitionable`, see
 https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
 """