update package/API reference docs to new-style typed PRNG keys

This commit is contained in:
Roy Frostig 2023-08-14 12:59:09 -07:00
parent f0afc1b43d
commit 98f790f5d5
7 changed files with 39 additions and 35 deletions

View File

@ -14,8 +14,8 @@ consistent with definitions used in Keras and Sonnet.
An initializer is a function that takes three arguments: An initializer is a function that takes three arguments:
``(key, shape, dtype)`` and returns an array with dimensions ``shape`` and ``(key, shape, dtype)`` and returns an array with dimensions ``shape`` and
data type ``dtype``. Argument ``key`` is a :class:`jax.random.PRNGKey` random data type ``dtype``. Argument ``key`` is a PRNG key (e.g. from
key used when generating random numbers to initialize the array. :func:`jax.random.key`), used to generate random numbers to initialize the array.
.. autosummary:: .. autosummary::
:toctree: _autosummary :toctree: _autosummary

View File

@ -280,7 +280,7 @@ def jit(
... def selu(x, alpha=1.67, lmbda=1.05): ... def selu(x, alpha=1.67, lmbda=1.05):
... return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha) ... return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
>>> >>>
>>> key = jax.random.PRNGKey(0) >>> key = jax.random.key(0)
>>> x = jax.random.normal(key, (10,)) >>> x = jax.random.normal(key, (10,))
>>> print(selu(x)) # doctest: +SKIP >>> print(selu(x)) # doctest: +SKIP
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748 [-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748

View File

@ -1277,7 +1277,7 @@ def ensure_compile_time_eval():
@jax.jit @jax.jit
def jax_fn(x): def jax_fn(x):
with jax.ensure_compile_time_eval(): with jax.ensure_compile_time_eval():
y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100) y = random.randint(random.key(0), (1000,1000), 0, 100)
y2 = y @ y y2 = y @ y
x2 = jnp.sum(y2) * x x2 = jnp.sum(y2) * x
return x2 return x2
@ -1285,7 +1285,7 @@ def ensure_compile_time_eval():
A similar behavior can often be achieved simply by 'hoisting' the constant A similar behavior can often be achieved simply by 'hoisting' the constant
expression out of the corresponding staging API:: expression out of the corresponding staging API::
y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100) y = random.randint(random.key(0), (1000,1000), 0, 100)
@jax.jit @jax.jit
def jax_fn(x): def jax_fn(x):

View File

@ -2101,7 +2101,7 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
Example 2: partial products of an array of matrices Example 2: partial products of an array of matrices
>>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2)) >>> mats = jax.random.uniform(jax.random.key(0), (4, 2, 2))
>>> partial_prods = lax.associative_scan(jnp.matmul, mats) >>> partial_prods = lax.associative_scan(jnp.matmul, mats)
>>> partial_prods.shape >>> partial_prods.shape
(4, 2, 2) (4, 2, 2)

View File

@ -62,7 +62,7 @@ def zeros(key: KeyArray,
The ``key`` argument is ignored. The ``key`` argument is ignored.
>>> import jax, jax.numpy as jnp >>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32) >>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.], Array([[0., 0., 0.],
[0., 0., 0.]], dtype=float32) [0., 0., 0.]], dtype=float32)
""" """
@ -77,7 +77,7 @@ def ones(key: KeyArray,
The ``key`` argument is ignored. The ``key`` argument is ignored.
>>> import jax, jax.numpy as jnp >>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.PRNGKey(42), (3, 2), jnp.float32) >>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.], Array([[1., 1.],
[1., 1.], [1., 1.],
[1., 1.]], dtype=float32) [1., 1.]], dtype=float32)
@ -96,7 +96,7 @@ def constant(value: ArrayLike,
>>> import jax, jax.numpy as jnp >>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.constant(-7) >>> initializer = jax.nn.initializers.constant(-7)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) >>> initializer(jax.random.key(42), (2, 3), jnp.float32)
Array([[-7., -7., -7.], Array([[-7., -7., -7.],
[-7., -7., -7.]], dtype=float32) [-7., -7., -7.]], dtype=float32)
""" """
@ -122,7 +122,7 @@ def uniform(scale: RealNumeric = 1e-2,
>>> import jax, jax.numpy as jnp >>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.uniform(10.0) >>> initializer = jax.nn.initializers.uniform(10.0)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[7.298188 , 8.691938 , 8.7230015], Array([[7.298188 , 8.691938 , 8.7230015],
[2.0818567, 1.8662417, 5.5022564]], dtype=float32) [2.0818567, 1.8662417, 5.5022564]], dtype=float32)
""" """
@ -148,7 +148,7 @@ def normal(stddev: RealNumeric = 1e-2,
>>> import jax, jax.numpy as jnp >>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.normal(5.0) >>> initializer = jax.nn.initializers.normal(5.0)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 3.0613258 , 5.6129413 , 5.6866574 ], Array([[ 3.0613258 , 5.6129413 , 5.6866574 ],
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32) [-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
""" """
@ -376,7 +376,7 @@ def glorot_uniform(in_axis: int | Sequence[int] = -2,
>>> import jax, jax.numpy as jnp >>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_uniform() >>> initializer = jax.nn.initializers.glorot_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 0.50350785, 0.8088631 , 0.81566876], Array([[ 0.50350785, 0.8088631 , 0.81566876],
[-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32) [-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
@ -414,7 +414,7 @@ def glorot_normal(in_axis: int | Sequence[int] = -2,
>>> import jax, jax.numpy as jnp >>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_normal() >>> initializer = jax.nn.initializers.glorot_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 0.41770416, 0.75262755, 0.7619329 ], Array([[ 0.41770416, 0.75262755, 0.7619329 ],
[-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32) [-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
@ -452,7 +452,7 @@ def lecun_uniform(in_axis: int | Sequence[int] = -2,
>>> import jax, jax.numpy as jnp >>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_uniform() >>> initializer = jax.nn.initializers.lecun_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 0.56293887, 0.90433645, 0.9119454 ], Array([[ 0.56293887, 0.90433645, 0.9119454 ],
[-0.71479625, -0.7676109 , 0.12302713]], dtype=float32) [-0.71479625, -0.7676109 , 0.12302713]], dtype=float32)
@ -488,7 +488,7 @@ def lecun_normal(in_axis: int | Sequence[int] = -2,
>>> import jax, jax.numpy as jnp >>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_normal() >>> initializer = jax.nn.initializers.lecun_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 0.46700746, 0.8414632 , 0.8518669 ], Array([[ 0.46700746, 0.8414632 , 0.8518669 ],
[-0.61677957, -0.67402434, 0.09683388]], dtype=float32) [-0.61677957, -0.67402434, 0.09683388]], dtype=float32)
@ -524,7 +524,7 @@ def he_uniform(in_axis: int | Sequence[int] = -2,
>>> import jax, jax.numpy as jnp >>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_uniform() >>> initializer = jax.nn.initializers.he_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 0.79611576, 1.2789248 , 1.2896855 ], Array([[ 0.79611576, 1.2789248 , 1.2896855 ],
[-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32) [-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
@ -562,7 +562,7 @@ def he_normal(in_axis: int | Sequence[int] = -2,
>>> import jax, jax.numpy as jnp >>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_normal() >>> initializer = jax.nn.initializers.he_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 0.6604483 , 1.1900088 , 1.2047218 ], Array([[ 0.6604483 , 1.1900088 , 1.2047218 ],
[-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32) [-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
@ -595,7 +595,7 @@ def orthogonal(scale: RealNumeric = 1.0,
>>> import jax, jax.numpy as jnp >>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.orthogonal() >>> initializer = jax.nn.initializers.orthogonal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01], Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
[ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32) [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
""" """
@ -638,7 +638,7 @@ def delta_orthogonal(
>>> import jax, jax.numpy as jnp >>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.delta_orthogonal() >>> initializer = jax.nn.initializers.delta_orthogonal()
>>> initializer(jax.random.PRNGKey(42), (3, 3, 3), jnp.float32) # doctest: +SKIP >>> initializer(jax.random.key(42), (3, 3, 3), jnp.float32) # doctest: +SKIP
Array([[[ 0. , 0. , 0. ], Array([[[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ], [ 0. , 0. , 0. ],
[ 0. , 0. , 0. ]], [ 0. , 0. , 0. ]],

View File

@ -244,7 +244,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
"""Folds in data to a PRNG key to form a new PRNG key. """Folds in data to a PRNG key to form a new PRNG key.
Args: Args:
key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``). key: a PRNG key (from ``key``, ``split``, ``fold_in``).
data: a 32bit integer representing data to be folded in to the key. data: a 32bit integer representing data to be folded in to the key.
Returns: Returns:
@ -274,7 +274,7 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray:
"""Splits a PRNG key into `num` new keys by adding a leading axis. """Splits a PRNG key into `num` new keys by adding a leading axis.
Args: Args:
key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``). key: a PRNG key (from ``key``, ``split``, ``fold_in``).
num: optional, a positive integer (or tuple of integers) indicating num: optional, a positive integer (or tuple of integers) indicating
the number (or shape) of keys to produce. Defaults to 2. the number (or shape) of keys to produce. Defaults to 2.

View File

@ -22,24 +22,25 @@ Basic usage
>>> seed = 1701 >>> seed = 1701
>>> num_steps = 100 >>> num_steps = 100
>>> key = jax.random.PRNGKey(seed) >>> key = jax.random.key(seed)
>>> for i in range(num_steps): >>> for i in range(num_steps):
... key, subkey = jax.random.split(key) ... key, subkey = jax.random.split(key)
... params = compiled_update(subkey, params, next(batches)) # doctest: +SKIP ... params = compiled_update(subkey, params, next(batches)) # doctest: +SKIP
PRNG Keys PRNG keys
--------- ---------
Unlike the *stateful* pseudorandom number generators (PRNGs) that users of NumPy and Unlike the *stateful* pseudorandom number generators (PRNGs) that users of NumPy and
SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to
be passed as a first argument. be passed as a first argument.
The random state is described by two unsigned 32-bit integers that we call a **key**, The random state is described by a special array element type that we call a **key**,
usually generated by the :py:func:`jax.random.PRNGKey` function:: usually generated by the :py:func:`jax.random.key` function::
>>> from jax import random >>> from jax import random
>>> key = random.PRNGKey(0) >>> key = random.key(0)
>>> key >>> key
Array([0, 0], dtype=uint32) Array((), dtype=key<fry>) overlaying:
[0 0]
This key can then be used in any of JAX's random number generation routines:: This key can then be used in any of JAX's random number generation routines::
@ -60,8 +61,8 @@ If you need a new random number, you can use :meth:`jax.random.split` to generat
Advanced Advanced
-------- --------
Design and Context Design and background
================== =====================
**TLDR**: JAX PRNG = `Threefry counter PRNG <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_ **TLDR**: JAX PRNG = `Threefry counter PRNG <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_
+ a functional array-oriented `splitting model <https://dl.acm.org/citation.cfm?id=2503784>`_ + a functional array-oriented `splitting model <https://dl.acm.org/citation.cfm?id=2503784>`_
@ -79,16 +80,19 @@ To summarize, among other requirements, the JAX PRNG aims to:
Advanced RNG configuration Advanced RNG configuration
========================== ==========================
JAX provides several PRNG implementations (controlled by the JAX provides several PRNG implementations. A specific one can be
`jax_default_prng_impl` flag). selected with the optional `impl` keyword argument to
`jax.random.key`. When no `impl` option is passed to the `key`
constructor, the implementation is determined by the global
`jax_default_prng_impl` configuration flag.
- **default** - **default**, `"threefry2x32"`:
`A counter-based PRNG built around the Threefry hash function <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_. `A counter-based PRNG built around the Threefry hash function <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_.
- *experimental* A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See - *experimental* A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See
`TF doc <https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator>`_. `TF doc <https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator>`_.
- "rbg" uses ThreeFry for splitting, and XLA RBG for data generation. - `"rbg"` uses ThreeFry for splitting, and XLA RBG for data generation.
- "unsafe_rbg" exists only for demonstration purposes, using RBG both for - `"unsafe_rbg"` exists only for demonstration purposes, using RBG both for
splitting (using an untested made up algorithm) and generating. splitting (using an untested made up algorithm) and generating.
The random streams generated by these experimental implementations haven't The random streams generated by these experimental implementations haven't
@ -126,7 +130,7 @@ robust hash functions for `jax.random.split` and `jax.random.fold_in`. Therefore
less safe in the sense that the quality of random streams it generates from less safe in the sense that the quality of random streams it generates from
different keys is less well understood. different keys is less well understood.
For more about jax_threefry_partitionable, see For more about `jax_threefry_partitionable`, see
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
""" """