mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
update package/API reference docs to new-style typed PRNG keys
This commit is contained in:
parent
f0afc1b43d
commit
98f790f5d5
@ -14,8 +14,8 @@ consistent with definitions used in Keras and Sonnet.
|
||||
|
||||
An initializer is a function that takes three arguments:
|
||||
``(key, shape, dtype)`` and returns an array with dimensions ``shape`` and
|
||||
data type ``dtype``. Argument ``key`` is a :class:`jax.random.PRNGKey` random
|
||||
key used when generating random numbers to initialize the array.
|
||||
data type ``dtype``. Argument ``key`` is a PRNG key (e.g. from
|
||||
:func:`jax.random.key`), used to generate random numbers to initialize the array.
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
@ -280,7 +280,7 @@ def jit(
|
||||
... def selu(x, alpha=1.67, lmbda=1.05):
|
||||
... return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
|
||||
>>>
|
||||
>>> key = jax.random.PRNGKey(0)
|
||||
>>> key = jax.random.key(0)
|
||||
>>> x = jax.random.normal(key, (10,))
|
||||
>>> print(selu(x)) # doctest: +SKIP
|
||||
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748
|
||||
|
@ -1277,7 +1277,7 @@ def ensure_compile_time_eval():
|
||||
@jax.jit
|
||||
def jax_fn(x):
|
||||
with jax.ensure_compile_time_eval():
|
||||
y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)
|
||||
y = random.randint(random.key(0), (1000,1000), 0, 100)
|
||||
y2 = y @ y
|
||||
x2 = jnp.sum(y2) * x
|
||||
return x2
|
||||
@ -1285,7 +1285,7 @@ def ensure_compile_time_eval():
|
||||
A similar behavior can often be achieved simply by 'hoisting' the constant
|
||||
expression out of the corresponding staging API::
|
||||
|
||||
y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)
|
||||
y = random.randint(random.key(0), (1000,1000), 0, 100)
|
||||
|
||||
@jax.jit
|
||||
def jax_fn(x):
|
||||
|
@ -2101,7 +2101,7 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
|
||||
|
||||
Example 2: partial products of an array of matrices
|
||||
|
||||
>>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2))
|
||||
>>> mats = jax.random.uniform(jax.random.key(0), (4, 2, 2))
|
||||
>>> partial_prods = lax.associative_scan(jnp.matmul, mats)
|
||||
>>> partial_prods.shape
|
||||
(4, 2, 2)
|
||||
|
@ -62,7 +62,7 @@ def zeros(key: KeyArray,
|
||||
The ``key`` argument is ignored.
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
|
||||
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
|
||||
Array([[0., 0., 0.],
|
||||
[0., 0., 0.]], dtype=float32)
|
||||
"""
|
||||
@ -77,7 +77,7 @@ def ones(key: KeyArray,
|
||||
The ``key`` argument is ignored.
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> jax.nn.initializers.ones(jax.random.PRNGKey(42), (3, 2), jnp.float32)
|
||||
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
|
||||
Array([[1., 1.],
|
||||
[1., 1.],
|
||||
[1., 1.]], dtype=float32)
|
||||
@ -96,7 +96,7 @@ def constant(value: ArrayLike,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.constant(-7)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)
|
||||
Array([[-7., -7., -7.],
|
||||
[-7., -7., -7.]], dtype=float32)
|
||||
"""
|
||||
@ -122,7 +122,7 @@ def uniform(scale: RealNumeric = 1e-2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.uniform(10.0)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[7.298188 , 8.691938 , 8.7230015],
|
||||
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
|
||||
"""
|
||||
@ -148,7 +148,7 @@ def normal(stddev: RealNumeric = 1e-2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.normal(5.0)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[ 3.0613258 , 5.6129413 , 5.6866574 ],
|
||||
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
|
||||
"""
|
||||
@ -376,7 +376,7 @@ def glorot_uniform(in_axis: int | Sequence[int] = -2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.glorot_uniform()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[ 0.50350785, 0.8088631 , 0.81566876],
|
||||
[-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
|
||||
|
||||
@ -414,7 +414,7 @@ def glorot_normal(in_axis: int | Sequence[int] = -2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.glorot_normal()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[ 0.41770416, 0.75262755, 0.7619329 ],
|
||||
[-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
|
||||
|
||||
@ -452,7 +452,7 @@ def lecun_uniform(in_axis: int | Sequence[int] = -2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.lecun_uniform()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[ 0.56293887, 0.90433645, 0.9119454 ],
|
||||
[-0.71479625, -0.7676109 , 0.12302713]], dtype=float32)
|
||||
|
||||
@ -488,7 +488,7 @@ def lecun_normal(in_axis: int | Sequence[int] = -2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.lecun_normal()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[ 0.46700746, 0.8414632 , 0.8518669 ],
|
||||
[-0.61677957, -0.67402434, 0.09683388]], dtype=float32)
|
||||
|
||||
@ -524,7 +524,7 @@ def he_uniform(in_axis: int | Sequence[int] = -2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.he_uniform()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[ 0.79611576, 1.2789248 , 1.2896855 ],
|
||||
[-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
|
||||
|
||||
@ -562,7 +562,7 @@ def he_normal(in_axis: int | Sequence[int] = -2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.he_normal()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[ 0.6604483 , 1.1900088 , 1.2047218 ],
|
||||
[-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
|
||||
|
||||
@ -595,7 +595,7 @@ def orthogonal(scale: RealNumeric = 1.0,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.orthogonal()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
|
||||
[ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
|
||||
"""
|
||||
@ -638,7 +638,7 @@ def delta_orthogonal(
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.delta_orthogonal()
|
||||
>>> initializer(jax.random.PRNGKey(42), (3, 3, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (3, 3, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[[ 0. , 0. , 0. ],
|
||||
[ 0. , 0. , 0. ],
|
||||
[ 0. , 0. , 0. ]],
|
||||
|
@ -244,7 +244,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
|
||||
"""Folds in data to a PRNG key to form a new PRNG key.
|
||||
|
||||
Args:
|
||||
key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``).
|
||||
key: a PRNG key (from ``key``, ``split``, ``fold_in``).
|
||||
data: a 32bit integer representing data to be folded in to the key.
|
||||
|
||||
Returns:
|
||||
@ -274,7 +274,7 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray:
|
||||
"""Splits a PRNG key into `num` new keys by adding a leading axis.
|
||||
|
||||
Args:
|
||||
key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``).
|
||||
key: a PRNG key (from ``key``, ``split``, ``fold_in``).
|
||||
num: optional, a positive integer (or tuple of integers) indicating
|
||||
the number (or shape) of keys to produce. Defaults to 2.
|
||||
|
||||
|
@ -22,24 +22,25 @@ Basic usage
|
||||
|
||||
>>> seed = 1701
|
||||
>>> num_steps = 100
|
||||
>>> key = jax.random.PRNGKey(seed)
|
||||
>>> key = jax.random.key(seed)
|
||||
>>> for i in range(num_steps):
|
||||
... key, subkey = jax.random.split(key)
|
||||
... params = compiled_update(subkey, params, next(batches)) # doctest: +SKIP
|
||||
|
||||
PRNG Keys
|
||||
PRNG keys
|
||||
---------
|
||||
|
||||
Unlike the *stateful* pseudorandom number generators (PRNGs) that users of NumPy and
|
||||
SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to
|
||||
be passed as a first argument.
|
||||
The random state is described by two unsigned 32-bit integers that we call a **key**,
|
||||
usually generated by the :py:func:`jax.random.PRNGKey` function::
|
||||
The random state is described by a special array element type that we call a **key**,
|
||||
usually generated by the :py:func:`jax.random.key` function::
|
||||
|
||||
>>> from jax import random
|
||||
>>> key = random.PRNGKey(0)
|
||||
>>> key = random.key(0)
|
||||
>>> key
|
||||
Array([0, 0], dtype=uint32)
|
||||
Array((), dtype=key<fry>) overlaying:
|
||||
[0 0]
|
||||
|
||||
This key can then be used in any of JAX's random number generation routines::
|
||||
|
||||
@ -60,8 +61,8 @@ If you need a new random number, you can use :meth:`jax.random.split` to generat
|
||||
Advanced
|
||||
--------
|
||||
|
||||
Design and Context
|
||||
==================
|
||||
Design and background
|
||||
=====================
|
||||
|
||||
**TLDR**: JAX PRNG = `Threefry counter PRNG <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_
|
||||
+ a functional array-oriented `splitting model <https://dl.acm.org/citation.cfm?id=2503784>`_
|
||||
@ -79,16 +80,19 @@ To summarize, among other requirements, the JAX PRNG aims to:
|
||||
Advanced RNG configuration
|
||||
==========================
|
||||
|
||||
JAX provides several PRNG implementations (controlled by the
|
||||
`jax_default_prng_impl` flag).
|
||||
JAX provides several PRNG implementations. A specific one can be
|
||||
selected with the optional `impl` keyword argument to
|
||||
`jax.random.key`. When no `impl` option is passed to the `key`
|
||||
constructor, the implementation is determined by the global
|
||||
`jax_default_prng_impl` configuration flag.
|
||||
|
||||
- **default**
|
||||
- **default**, `"threefry2x32"`:
|
||||
`A counter-based PRNG built around the Threefry hash function <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_.
|
||||
- *experimental* A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See
|
||||
`TF doc <https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator>`_.
|
||||
|
||||
- "rbg" uses ThreeFry for splitting, and XLA RBG for data generation.
|
||||
- "unsafe_rbg" exists only for demonstration purposes, using RBG both for
|
||||
- `"rbg"` uses ThreeFry for splitting, and XLA RBG for data generation.
|
||||
- `"unsafe_rbg"` exists only for demonstration purposes, using RBG both for
|
||||
splitting (using an untested made up algorithm) and generating.
|
||||
|
||||
The random streams generated by these experimental implementations haven't
|
||||
@ -126,7 +130,7 @@ robust hash functions for `jax.random.split` and `jax.random.fold_in`. Therefore
|
||||
less safe in the sense that the quality of random streams it generates from
|
||||
different keys is less well understood.
|
||||
|
||||
For more about jax_threefry_partitionable, see
|
||||
For more about `jax_threefry_partitionable`, see
|
||||
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
|
||||
"""
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user