mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
update package/API reference docs to new-style typed PRNG keys
This commit is contained in:
parent
f0afc1b43d
commit
98f790f5d5
@ -14,8 +14,8 @@ consistent with definitions used in Keras and Sonnet.
|
|||||||
|
|
||||||
An initializer is a function that takes three arguments:
|
An initializer is a function that takes three arguments:
|
||||||
``(key, shape, dtype)`` and returns an array with dimensions ``shape`` and
|
``(key, shape, dtype)`` and returns an array with dimensions ``shape`` and
|
||||||
data type ``dtype``. Argument ``key`` is a :class:`jax.random.PRNGKey` random
|
data type ``dtype``. Argument ``key`` is a PRNG key (e.g. from
|
||||||
key used when generating random numbers to initialize the array.
|
:func:`jax.random.key`), used to generate random numbers to initialize the array.
|
||||||
|
|
||||||
.. autosummary::
|
.. autosummary::
|
||||||
:toctree: _autosummary
|
:toctree: _autosummary
|
||||||
|
@ -280,7 +280,7 @@ def jit(
|
|||||||
... def selu(x, alpha=1.67, lmbda=1.05):
|
... def selu(x, alpha=1.67, lmbda=1.05):
|
||||||
... return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
|
... return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
|
||||||
>>>
|
>>>
|
||||||
>>> key = jax.random.PRNGKey(0)
|
>>> key = jax.random.key(0)
|
||||||
>>> x = jax.random.normal(key, (10,))
|
>>> x = jax.random.normal(key, (10,))
|
||||||
>>> print(selu(x)) # doctest: +SKIP
|
>>> print(selu(x)) # doctest: +SKIP
|
||||||
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748
|
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748
|
||||||
|
@ -1277,7 +1277,7 @@ def ensure_compile_time_eval():
|
|||||||
@jax.jit
|
@jax.jit
|
||||||
def jax_fn(x):
|
def jax_fn(x):
|
||||||
with jax.ensure_compile_time_eval():
|
with jax.ensure_compile_time_eval():
|
||||||
y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)
|
y = random.randint(random.key(0), (1000,1000), 0, 100)
|
||||||
y2 = y @ y
|
y2 = y @ y
|
||||||
x2 = jnp.sum(y2) * x
|
x2 = jnp.sum(y2) * x
|
||||||
return x2
|
return x2
|
||||||
@ -1285,7 +1285,7 @@ def ensure_compile_time_eval():
|
|||||||
A similar behavior can often be achieved simply by 'hoisting' the constant
|
A similar behavior can often be achieved simply by 'hoisting' the constant
|
||||||
expression out of the corresponding staging API::
|
expression out of the corresponding staging API::
|
||||||
|
|
||||||
y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)
|
y = random.randint(random.key(0), (1000,1000), 0, 100)
|
||||||
|
|
||||||
@jax.jit
|
@jax.jit
|
||||||
def jax_fn(x):
|
def jax_fn(x):
|
||||||
|
@ -2101,7 +2101,7 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
|
|||||||
|
|
||||||
Example 2: partial products of an array of matrices
|
Example 2: partial products of an array of matrices
|
||||||
|
|
||||||
>>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2))
|
>>> mats = jax.random.uniform(jax.random.key(0), (4, 2, 2))
|
||||||
>>> partial_prods = lax.associative_scan(jnp.matmul, mats)
|
>>> partial_prods = lax.associative_scan(jnp.matmul, mats)
|
||||||
>>> partial_prods.shape
|
>>> partial_prods.shape
|
||||||
(4, 2, 2)
|
(4, 2, 2)
|
||||||
|
@ -62,7 +62,7 @@ def zeros(key: KeyArray,
|
|||||||
The ``key`` argument is ignored.
|
The ``key`` argument is ignored.
|
||||||
|
|
||||||
>>> import jax, jax.numpy as jnp
|
>>> import jax, jax.numpy as jnp
|
||||||
>>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
|
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
|
||||||
Array([[0., 0., 0.],
|
Array([[0., 0., 0.],
|
||||||
[0., 0., 0.]], dtype=float32)
|
[0., 0., 0.]], dtype=float32)
|
||||||
"""
|
"""
|
||||||
@ -77,7 +77,7 @@ def ones(key: KeyArray,
|
|||||||
The ``key`` argument is ignored.
|
The ``key`` argument is ignored.
|
||||||
|
|
||||||
>>> import jax, jax.numpy as jnp
|
>>> import jax, jax.numpy as jnp
|
||||||
>>> jax.nn.initializers.ones(jax.random.PRNGKey(42), (3, 2), jnp.float32)
|
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
|
||||||
Array([[1., 1.],
|
Array([[1., 1.],
|
||||||
[1., 1.],
|
[1., 1.],
|
||||||
[1., 1.]], dtype=float32)
|
[1., 1.]], dtype=float32)
|
||||||
@ -96,7 +96,7 @@ def constant(value: ArrayLike,
|
|||||||
|
|
||||||
>>> import jax, jax.numpy as jnp
|
>>> import jax, jax.numpy as jnp
|
||||||
>>> initializer = jax.nn.initializers.constant(-7)
|
>>> initializer = jax.nn.initializers.constant(-7)
|
||||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
|
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)
|
||||||
Array([[-7., -7., -7.],
|
Array([[-7., -7., -7.],
|
||||||
[-7., -7., -7.]], dtype=float32)
|
[-7., -7., -7.]], dtype=float32)
|
||||||
"""
|
"""
|
||||||
@ -122,7 +122,7 @@ def uniform(scale: RealNumeric = 1e-2,
|
|||||||
|
|
||||||
>>> import jax, jax.numpy as jnp
|
>>> import jax, jax.numpy as jnp
|
||||||
>>> initializer = jax.nn.initializers.uniform(10.0)
|
>>> initializer = jax.nn.initializers.uniform(10.0)
|
||||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||||
Array([[7.298188 , 8.691938 , 8.7230015],
|
Array([[7.298188 , 8.691938 , 8.7230015],
|
||||||
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
|
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
|
||||||
"""
|
"""
|
||||||
@ -148,7 +148,7 @@ def normal(stddev: RealNumeric = 1e-2,
|
|||||||
|
|
||||||
>>> import jax, jax.numpy as jnp
|
>>> import jax, jax.numpy as jnp
|
||||||
>>> initializer = jax.nn.initializers.normal(5.0)
|
>>> initializer = jax.nn.initializers.normal(5.0)
|
||||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||||
Array([[ 3.0613258 , 5.6129413 , 5.6866574 ],
|
Array([[ 3.0613258 , 5.6129413 , 5.6866574 ],
|
||||||
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
|
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
|
||||||
"""
|
"""
|
||||||
@ -376,7 +376,7 @@ def glorot_uniform(in_axis: int | Sequence[int] = -2,
|
|||||||
|
|
||||||
>>> import jax, jax.numpy as jnp
|
>>> import jax, jax.numpy as jnp
|
||||||
>>> initializer = jax.nn.initializers.glorot_uniform()
|
>>> initializer = jax.nn.initializers.glorot_uniform()
|
||||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||||
Array([[ 0.50350785, 0.8088631 , 0.81566876],
|
Array([[ 0.50350785, 0.8088631 , 0.81566876],
|
||||||
[-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
|
[-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
|
||||||
|
|
||||||
@ -414,7 +414,7 @@ def glorot_normal(in_axis: int | Sequence[int] = -2,
|
|||||||
|
|
||||||
>>> import jax, jax.numpy as jnp
|
>>> import jax, jax.numpy as jnp
|
||||||
>>> initializer = jax.nn.initializers.glorot_normal()
|
>>> initializer = jax.nn.initializers.glorot_normal()
|
||||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||||
Array([[ 0.41770416, 0.75262755, 0.7619329 ],
|
Array([[ 0.41770416, 0.75262755, 0.7619329 ],
|
||||||
[-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
|
[-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
|
||||||
|
|
||||||
@ -452,7 +452,7 @@ def lecun_uniform(in_axis: int | Sequence[int] = -2,
|
|||||||
|
|
||||||
>>> import jax, jax.numpy as jnp
|
>>> import jax, jax.numpy as jnp
|
||||||
>>> initializer = jax.nn.initializers.lecun_uniform()
|
>>> initializer = jax.nn.initializers.lecun_uniform()
|
||||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||||
Array([[ 0.56293887, 0.90433645, 0.9119454 ],
|
Array([[ 0.56293887, 0.90433645, 0.9119454 ],
|
||||||
[-0.71479625, -0.7676109 , 0.12302713]], dtype=float32)
|
[-0.71479625, -0.7676109 , 0.12302713]], dtype=float32)
|
||||||
|
|
||||||
@ -488,7 +488,7 @@ def lecun_normal(in_axis: int | Sequence[int] = -2,
|
|||||||
|
|
||||||
>>> import jax, jax.numpy as jnp
|
>>> import jax, jax.numpy as jnp
|
||||||
>>> initializer = jax.nn.initializers.lecun_normal()
|
>>> initializer = jax.nn.initializers.lecun_normal()
|
||||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||||
Array([[ 0.46700746, 0.8414632 , 0.8518669 ],
|
Array([[ 0.46700746, 0.8414632 , 0.8518669 ],
|
||||||
[-0.61677957, -0.67402434, 0.09683388]], dtype=float32)
|
[-0.61677957, -0.67402434, 0.09683388]], dtype=float32)
|
||||||
|
|
||||||
@ -524,7 +524,7 @@ def he_uniform(in_axis: int | Sequence[int] = -2,
|
|||||||
|
|
||||||
>>> import jax, jax.numpy as jnp
|
>>> import jax, jax.numpy as jnp
|
||||||
>>> initializer = jax.nn.initializers.he_uniform()
|
>>> initializer = jax.nn.initializers.he_uniform()
|
||||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||||
Array([[ 0.79611576, 1.2789248 , 1.2896855 ],
|
Array([[ 0.79611576, 1.2789248 , 1.2896855 ],
|
||||||
[-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
|
[-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
|
||||||
|
|
||||||
@ -562,7 +562,7 @@ def he_normal(in_axis: int | Sequence[int] = -2,
|
|||||||
|
|
||||||
>>> import jax, jax.numpy as jnp
|
>>> import jax, jax.numpy as jnp
|
||||||
>>> initializer = jax.nn.initializers.he_normal()
|
>>> initializer = jax.nn.initializers.he_normal()
|
||||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||||
Array([[ 0.6604483 , 1.1900088 , 1.2047218 ],
|
Array([[ 0.6604483 , 1.1900088 , 1.2047218 ],
|
||||||
[-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
|
[-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
|
||||||
|
|
||||||
@ -595,7 +595,7 @@ def orthogonal(scale: RealNumeric = 1.0,
|
|||||||
|
|
||||||
>>> import jax, jax.numpy as jnp
|
>>> import jax, jax.numpy as jnp
|
||||||
>>> initializer = jax.nn.initializers.orthogonal()
|
>>> initializer = jax.nn.initializers.orthogonal()
|
||||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||||
Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
|
Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
|
||||||
[ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
|
[ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
|
||||||
"""
|
"""
|
||||||
@ -638,7 +638,7 @@ def delta_orthogonal(
|
|||||||
|
|
||||||
>>> import jax, jax.numpy as jnp
|
>>> import jax, jax.numpy as jnp
|
||||||
>>> initializer = jax.nn.initializers.delta_orthogonal()
|
>>> initializer = jax.nn.initializers.delta_orthogonal()
|
||||||
>>> initializer(jax.random.PRNGKey(42), (3, 3, 3), jnp.float32) # doctest: +SKIP
|
>>> initializer(jax.random.key(42), (3, 3, 3), jnp.float32) # doctest: +SKIP
|
||||||
Array([[[ 0. , 0. , 0. ],
|
Array([[[ 0. , 0. , 0. ],
|
||||||
[ 0. , 0. , 0. ],
|
[ 0. , 0. , 0. ],
|
||||||
[ 0. , 0. , 0. ]],
|
[ 0. , 0. , 0. ]],
|
||||||
|
@ -244,7 +244,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
|
|||||||
"""Folds in data to a PRNG key to form a new PRNG key.
|
"""Folds in data to a PRNG key to form a new PRNG key.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``).
|
key: a PRNG key (from ``key``, ``split``, ``fold_in``).
|
||||||
data: a 32-bit integer representing data to be folded in to the key.
|
data: a 32-bit integer representing data to be folded in to the key.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -274,7 +274,7 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray:
|
|||||||
"""Splits a PRNG key into `num` new keys by adding a leading axis.
|
"""Splits a PRNG key into `num` new keys by adding a leading axis.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``).
|
key: a PRNG key (from ``key``, ``split``, ``fold_in``).
|
||||||
num: optional, a positive integer (or tuple of integers) indicating
|
num: optional, a positive integer (or tuple of integers) indicating
|
||||||
the number (or shape) of keys to produce. Defaults to 2.
|
the number (or shape) of keys to produce. Defaults to 2.
|
||||||
|
|
||||||
|
@ -22,24 +22,25 @@ Basic usage
|
|||||||
|
|
||||||
>>> seed = 1701
|
>>> seed = 1701
|
||||||
>>> num_steps = 100
|
>>> num_steps = 100
|
||||||
>>> key = jax.random.PRNGKey(seed)
|
>>> key = jax.random.key(seed)
|
||||||
>>> for i in range(num_steps):
|
>>> for i in range(num_steps):
|
||||||
... key, subkey = jax.random.split(key)
|
... key, subkey = jax.random.split(key)
|
||||||
... params = compiled_update(subkey, params, next(batches)) # doctest: +SKIP
|
... params = compiled_update(subkey, params, next(batches)) # doctest: +SKIP
|
||||||
|
|
||||||
PRNG Keys
|
PRNG keys
|
||||||
---------
|
---------
|
||||||
|
|
||||||
Unlike the *stateful* pseudorandom number generators (PRNGs) that users of NumPy and
|
Unlike the *stateful* pseudorandom number generators (PRNGs) that users of NumPy and
|
||||||
SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to
|
SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to
|
||||||
be passed as a first argument.
|
be passed as a first argument.
|
||||||
The random state is described by two unsigned 32-bit integers that we call a **key**,
|
The random state is described by a special array element type that we call a **key**,
|
||||||
usually generated by the :py:func:`jax.random.PRNGKey` function::
|
usually generated by the :py:func:`jax.random.key` function::
|
||||||
|
|
||||||
>>> from jax import random
|
>>> from jax import random
|
||||||
>>> key = random.PRNGKey(0)
|
>>> key = random.key(0)
|
||||||
>>> key
|
>>> key
|
||||||
Array([0, 0], dtype=uint32)
|
Array((), dtype=key<fry>) overlaying:
|
||||||
|
[0 0]
|
||||||
|
|
||||||
This key can then be used in any of JAX's random number generation routines::
|
This key can then be used in any of JAX's random number generation routines::
|
||||||
|
|
||||||
@ -60,8 +61,8 @@ If you need a new random number, you can use :meth:`jax.random.split` to generat
|
|||||||
Advanced
|
Advanced
|
||||||
--------
|
--------
|
||||||
|
|
||||||
Design and Context
|
Design and background
|
||||||
==================
|
=====================
|
||||||
|
|
||||||
**TLDR**: JAX PRNG = `Threefry counter PRNG <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_
|
**TLDR**: JAX PRNG = `Threefry counter PRNG <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_
|
||||||
+ a functional array-oriented `splitting model <https://dl.acm.org/citation.cfm?id=2503784>`_
|
+ a functional array-oriented `splitting model <https://dl.acm.org/citation.cfm?id=2503784>`_
|
||||||
@ -79,16 +80,19 @@ To summarize, among other requirements, the JAX PRNG aims to:
|
|||||||
Advanced RNG configuration
|
Advanced RNG configuration
|
||||||
==========================
|
==========================
|
||||||
|
|
||||||
JAX provides several PRNG implementations (controlled by the
|
JAX provides several PRNG implementations. A specific one can be
|
||||||
`jax_default_prng_impl` flag).
|
selected with the optional `impl` keyword argument to
|
||||||
|
`jax.random.key`. When no `impl` option is passed to the `key`
|
||||||
|
constructor, the implementation is determined by the global
|
||||||
|
`jax_default_prng_impl` configuration flag.
|
||||||
|
|
||||||
- **default**
|
- **default**, `"threefry2x32"`:
|
||||||
`A counter-based PRNG built around the Threefry hash function <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_.
|
`A counter-based PRNG built around the Threefry hash function <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_.
|
||||||
- *experimental* A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See
|
- *experimental* A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See
|
||||||
`TF doc <https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator>`_.
|
`TF doc <https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator>`_.
|
||||||
|
|
||||||
- "rbg" uses ThreeFry for splitting, and XLA RBG for data generation.
|
- `"rbg"` uses ThreeFry for splitting, and XLA RBG for data generation.
|
||||||
- "unsafe_rbg" exists only for demonstration purposes, using RBG both for
|
- `"unsafe_rbg"` exists only for demonstration purposes, using RBG both for
|
||||||
splitting (using an untested, made-up algorithm) and generating.
|
splitting (using an untested, made-up algorithm) and generating.
|
||||||
|
|
||||||
The random streams generated by these experimental implementations haven't
|
The random streams generated by these experimental implementations haven't
|
||||||
@ -126,7 +130,7 @@ robust hash functions for `jax.random.split` and `jax.random.fold_in`. Therefore
|
|||||||
less safe in the sense that the quality of random streams it generates from
|
less safe in the sense that the quality of random streams it generates from
|
||||||
different keys is less well understood.
|
different keys is less well understood.
|
||||||
|
|
||||||
For more about jax_threefry_partitionable, see
|
For more about `jax_threefry_partitionable`, see
|
||||||
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
|
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user