update package/API reference docs to new-style typed PRNG keys

This commit is contained in:
Roy Frostig 2023-08-14 12:59:09 -07:00
parent f0afc1b43d
commit 98f790f5d5
7 changed files with 39 additions and 35 deletions

View File

@ -14,8 +14,8 @@ consistent with definitions used in Keras and Sonnet.
An initializer is a function that takes three arguments:
``(key, shape, dtype)`` and returns an array with dimensions ``shape`` and
data type ``dtype``. Argument ``key`` is a :class:`jax.random.PRNGKey` random
key used when generating random numbers to initialize the array.
data type ``dtype``. Argument ``key`` is a PRNG key (e.g. from
:func:`jax.random.key`), used to generate random numbers to initialize the array.
.. autosummary::
:toctree: _autosummary

View File

@ -280,7 +280,7 @@ def jit(
... def selu(x, alpha=1.67, lmbda=1.05):
... return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
>>>
>>> key = jax.random.PRNGKey(0)
>>> key = jax.random.key(0)
>>> x = jax.random.normal(key, (10,))
>>> print(selu(x)) # doctest: +SKIP
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748

View File

@ -1277,7 +1277,7 @@ def ensure_compile_time_eval():
@jax.jit
def jax_fn(x):
with jax.ensure_compile_time_eval():
y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)
y = random.randint(random.key(0), (1000,1000), 0, 100)
y2 = y @ y
x2 = jnp.sum(y2) * x
return x2
@ -1285,7 +1285,7 @@ def ensure_compile_time_eval():
A similar behavior can often be achieved simply by 'hoisting' the constant
expression out of the corresponding staging API::
y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)
y = random.randint(random.key(0), (1000,1000), 0, 100)
@jax.jit
def jax_fn(x):

View File

@ -2101,7 +2101,7 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
Example 2: partial products of an array of matrices
>>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2))
>>> mats = jax.random.uniform(jax.random.key(0), (4, 2, 2))
>>> partial_prods = lax.associative_scan(jnp.matmul, mats)
>>> partial_prods.shape
(4, 2, 2)

View File

@ -62,7 +62,7 @@ def zeros(key: KeyArray,
The ``key`` argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)
"""
@ -77,7 +77,7 @@ def ones(key: KeyArray,
The ``key`` argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.PRNGKey(42), (3, 2), jnp.float32)
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
[1., 1.],
[1., 1.]], dtype=float32)
@ -96,7 +96,7 @@ def constant(value: ArrayLike,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.constant(-7)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)
Array([[-7., -7., -7.],
[-7., -7., -7.]], dtype=float32)
"""
@ -122,7 +122,7 @@ def uniform(scale: RealNumeric = 1e-2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.uniform(10.0)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[7.298188 , 8.691938 , 8.7230015],
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
"""
@ -148,7 +148,7 @@ def normal(stddev: RealNumeric = 1e-2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.normal(5.0)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 3.0613258 , 5.6129413 , 5.6866574 ],
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
"""
@ -376,7 +376,7 @@ def glorot_uniform(in_axis: int | Sequence[int] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 0.50350785, 0.8088631 , 0.81566876],
[-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
@ -414,7 +414,7 @@ def glorot_normal(in_axis: int | Sequence[int] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 0.41770416, 0.75262755, 0.7619329 ],
[-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
@ -452,7 +452,7 @@ def lecun_uniform(in_axis: int | Sequence[int] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 0.56293887, 0.90433645, 0.9119454 ],
[-0.71479625, -0.7676109 , 0.12302713]], dtype=float32)
@ -488,7 +488,7 @@ def lecun_normal(in_axis: int | Sequence[int] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 0.46700746, 0.8414632 , 0.8518669 ],
[-0.61677957, -0.67402434, 0.09683388]], dtype=float32)
@ -524,7 +524,7 @@ def he_uniform(in_axis: int | Sequence[int] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 0.79611576, 1.2789248 , 1.2896855 ],
[-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
@ -562,7 +562,7 @@ def he_normal(in_axis: int | Sequence[int] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 0.6604483 , 1.1900088 , 1.2047218 ],
[-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
@ -595,7 +595,7 @@ def orthogonal(scale: RealNumeric = 1.0,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.orthogonal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
[ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
"""
@ -638,7 +638,7 @@ def delta_orthogonal(
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.delta_orthogonal()
>>> initializer(jax.random.PRNGKey(42), (3, 3, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (3, 3, 3), jnp.float32) # doctest: +SKIP
Array([[[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ]],

View File

@ -244,7 +244,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
"""Folds in data to a PRNG key to form a new PRNG key.
Args:
key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``).
key: a PRNG key (from ``key``, ``split``, ``fold_in``).
data: a 32-bit integer representing data to be folded into the key.
Returns:
@ -274,7 +274,7 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray:
"""Splits a PRNG key into `num` new keys by adding a leading axis.
Args:
key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``).
key: a PRNG key (from ``key``, ``split``, ``fold_in``).
num: optional, a positive integer (or tuple of integers) indicating
the number (or shape) of keys to produce. Defaults to 2.

View File

@ -22,24 +22,25 @@ Basic usage
>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.PRNGKey(seed)
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
... key, subkey = jax.random.split(key)
... params = compiled_update(subkey, params, next(batches)) # doctest: +SKIP
PRNG Keys
PRNG keys
---------
Unlike the *stateful* pseudorandom number generators (PRNGs) that users of NumPy and
SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to
be passed as a first argument.
The random state is described by two unsigned 32-bit integers that we call a **key**,
usually generated by the :py:func:`jax.random.PRNGKey` function::
The random state is described by a special array element type that we call a **key**,
usually generated by the :py:func:`jax.random.key` function::
>>> from jax import random
>>> key = random.PRNGKey(0)
>>> key = random.key(0)
>>> key
Array([0, 0], dtype=uint32)
Array((), dtype=key<fry>) overlaying:
[0 0]
This key can then be used in any of JAX's random number generation routines::
@ -60,8 +61,8 @@ If you need a new random number, you can use :meth:`jax.random.split` to generat
Advanced
--------
Design and Context
==================
Design and background
=====================
**TLDR**: JAX PRNG = `Threefry counter PRNG <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_
+ a functional array-oriented `splitting model <https://dl.acm.org/citation.cfm?id=2503784>`_
@ -79,16 +80,19 @@ To summarize, among other requirements, the JAX PRNG aims to:
Advanced RNG configuration
==========================
JAX provides several PRNG implementations (controlled by the
`jax_default_prng_impl` flag).
JAX provides several PRNG implementations. A specific one can be
selected with the optional `impl` keyword argument to
`jax.random.key`. When no `impl` option is passed to the `key`
constructor, the implementation is determined by the global
`jax_default_prng_impl` configuration flag.
- **default**
- **default**, `"threefry2x32"`:
`A counter-based PRNG built around the Threefry hash function <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_.
- *experimental* A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See
`TF doc <https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator>`_.
- "rbg" uses ThreeFry for splitting, and XLA RBG for data generation.
- "unsafe_rbg" exists only for demonstration purposes, using RBG both for
- `"rbg"` uses Threefry for splitting, and XLA RBG for data generation.
- `"unsafe_rbg"` exists only for demonstration purposes, using RBG both for
splitting (using an untested, made-up algorithm) and generating.
The random streams generated by these experimental implementations haven't
@ -126,7 +130,7 @@ robust hash functions for `jax.random.split` and `jax.random.fold_in`. Therefore
less safe in the sense that the quality of random streams it generates from
different keys is less well understood.
For more about jax_threefry_partitionable, see
For more about `jax_threefry_partitionable`, see
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
"""