diff --git a/docs/jax.nn.initializers.rst b/docs/jax.nn.initializers.rst index d96ba43f0..246e0cdbe 100644 --- a/docs/jax.nn.initializers.rst +++ b/docs/jax.nn.initializers.rst @@ -14,8 +14,8 @@ consistent with definitions used in Keras and Sonnet. An initializer is a function that takes three arguments: ``(key, shape, dtype)`` and returns an array with dimensions ``shape`` and -data type ``dtype``. Argument ``key`` is a :class:`jax.random.PRNGKey` random -key used when generating random numbers to initialize the array. +data type ``dtype``. Argument ``key`` is a PRNG key (e.g. from +:func:`jax.random.key`), used to generate random numbers to initialize the array. .. autosummary:: :toctree: _autosummary diff --git a/jax/_src/api.py b/jax/_src/api.py index 005e0ceca..2fecc4fd7 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -280,7 +280,7 @@ def jit( ... def selu(x, alpha=1.67, lmbda=1.05): ... return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha) >>> - >>> key = jax.random.PRNGKey(0) + >>> key = jax.random.key(0) >>> x = jax.random.normal(key, (10,)) >>> print(selu(x)) # doctest: +SKIP [-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748 diff --git a/jax/_src/core.py b/jax/_src/core.py index c351e6980..b207b39e5 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1277,7 +1277,7 @@ def ensure_compile_time_eval(): @jax.jit def jax_fn(x): with jax.ensure_compile_time_eval(): - y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100) + y = random.randint(random.key(0), (1000,1000), 0, 100) y2 = y @ y x2 = jnp.sum(y2) * x return x2 @@ -1285,7 +1285,7 @@ def ensure_compile_time_eval(): A similar behavior can often be achieved simply by 'hoisting' the constant expression out of the corresponding staging API:: - y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100) + y = random.randint(random.key(0), (1000,1000), 0, 100) @jax.jit def jax_fn(x): diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index e939f2170..0c6ae2fe0 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -2101,7 +2101,7 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0): Example 2: partial products of an array of matrices - >>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2)) + >>> mats = jax.random.uniform(jax.random.key(0), (4, 2, 2)) >>> partial_prods = lax.associative_scan(jnp.matmul, mats) >>> partial_prods.shape (4, 2, 2) diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index 5468fd663..d7353c396 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -62,7 +62,7 @@ def zeros(key: KeyArray, The ``key`` argument is ignored. >>> import jax, jax.numpy as jnp - >>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32) + >>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32) Array([[0., 0., 0.], [0., 0., 0.]], dtype=float32) """ @@ -77,7 +77,7 @@ def ones(key: KeyArray, The ``key`` argument is ignored. >>> import jax, jax.numpy as jnp - >>> jax.nn.initializers.ones(jax.random.PRNGKey(42), (3, 2), jnp.float32) + >>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32) Array([[1., 1.], [1., 1.], [1., 1.]], dtype=float32) @@ -96,7 +96,7 @@ def constant(value: ArrayLike, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.constant(-7) - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[-7., -7., -7.], [-7., -7., -7.]], dtype=float32) """ @@ -122,7 +122,7 @@ def uniform(scale: RealNumeric = 1e-2, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.uniform(10.0) - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[7.298188 , 8.691938 , 8.7230015], [2.0818567, 1.8662417, 5.5022564]], dtype=float32) """ @@ -148,7 +148,7 @@ def normal(stddev: RealNumeric = 1e-2, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.normal(5.0) - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 3.0613258 , 5.6129413 , 5.6866574 ], [-4.063663 , -4.4520254 , 0.63115686]], dtype=float32) """ @@ -376,7 +376,7 @@ def glorot_uniform(in_axis: int | Sequence[int] = -2, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.glorot_uniform() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 0.50350785, 0.8088631 , 0.81566876], [-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32) @@ -414,7 +414,7 @@ def glorot_normal(in_axis: int | Sequence[int] = -2, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.glorot_normal() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 0.41770416, 0.75262755, 0.7619329 ], [-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32) @@ -452,7 +452,7 @@ def lecun_uniform(in_axis: int | Sequence[int] = -2, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.lecun_uniform() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 0.56293887, 0.90433645, 0.9119454 ], [-0.71479625, -0.7676109 , 0.12302713]], dtype=float32) @@ -488,7 +488,7 @@ def lecun_normal(in_axis: int | Sequence[int] = -2, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.lecun_normal() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 0.46700746, 0.8414632 , 0.8518669 ], [-0.61677957, -0.67402434, 0.09683388]], dtype=float32) @@ -524,7 +524,7 @@ def he_uniform(in_axis: int | Sequence[int] = -2, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.he_uniform() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 0.79611576, 1.2789248 , 1.2896855 ], [-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32) @@ -562,7 +562,7 @@ def he_normal(in_axis: int | Sequence[int] = -2, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.he_normal() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 0.6604483 , 1.1900088 , 1.2047218 ], [-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32) @@ -595,7 +595,7 @@ def orthogonal(scale: RealNumeric = 1.0, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.orthogonal() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01], [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32) """ @@ -638,7 +638,7 @@ def delta_orthogonal( >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.delta_orthogonal() - >>> initializer(jax.random.PRNGKey(42), (3, 3, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (3, 3, 3), jnp.float32) # doctest: +SKIP Array([[[ 0. , 0. , 0. ], [ 0. , 0. , 0. ], [ 0. , 0. , 0. ]], diff --git a/jax/_src/random.py b/jax/_src/random.py index f9045ebf3..2696b488a 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -244,7 +244,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray: """Folds in data to a PRNG key to form a new PRNG key. Args: - key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``). + key: a PRNG key (from ``key``, ``split``, ``fold_in``). data: a 32bit integer representing data to be folded in to the key. Returns: @@ -274,7 +274,7 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray: """Splits a PRNG key into `num` new keys by adding a leading axis. Args: - key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``). + key: a PRNG key (from ``key``, ``split``, ``fold_in``). num: optional, a positive integer (or tuple of integers) indicating the number (or shape) of keys to produce. Defaults to 2. diff --git a/jax/random.py b/jax/random.py index c06f48b35..f65ec5858 100644 --- a/jax/random.py +++ b/jax/random.py @@ -22,24 +22,25 @@ Basic usage >>> seed = 1701 >>> num_steps = 100 ->>> key = jax.random.PRNGKey(seed) +>>> key = jax.random.key(seed) >>> for i in range(num_steps): ... key, subkey = jax.random.split(key) ... params = compiled_update(subkey, params, next(batches)) # doctest: +SKIP -PRNG Keys +PRNG keys --------- Unlike the *stateful* pseudorandom number generators (PRNGs) that users of NumPy and SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to be passed as a first argument. -The random state is described by two unsigned 32-bit integers that we call a **key**, -usually generated by the :py:func:`jax.random.PRNGKey` function:: +The random state is described by a special array element type that we call a **key**, +usually generated by the :py:func:`jax.random.key` function:: >>> from jax import random - >>> key = random.PRNGKey(0) + >>> key = random.key(0) >>> key - Array([0, 0], dtype=uint32) + Array((), dtype=key) overlaying: + [0 0] This key can then be used in any of JAX's random number generation routines:: @@ -60,8 +61,8 @@ If you need a new random number, you can use :meth:`jax.random.split` to generat Advanced -------- -Design and Context -================== +Design and background +===================== **TLDR**: JAX PRNG = `Threefry counter PRNG `_ + a functional array-oriented `splitting model `_ @@ -79,16 +80,19 @@ To summarize, among other requirements, the JAX PRNG aims to: Advanced RNG configuration ========================== -JAX provides several PRNG implementations (controlled by the -`jax_default_prng_impl` flag). +JAX provides several PRNG implementations. A specific one can be +selected with the optional `impl` keyword argument to +`jax.random.key`. When no `impl` option is passed to the `key` +constructor, the implementation is determined by the global +`jax_default_prng_impl` configuration flag. -- **default** +- **default**, `"threefry2x32"`: `A counter-based PRNG built around the Threefry hash function `_. - *experimental* A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See `TF doc `_. - - "rbg" uses ThreeFry for splitting, and XLA RBG for data generation. - - "unsafe_rbg" exists only for demonstration purposes, using RBG both for + - `"rbg"` uses ThreeFry for splitting, and XLA RBG for data generation. + - `"unsafe_rbg"` exists only for demonstration purposes, using RBG both for splitting (using an untested made up algorithm) and generating. The random streams generated by these experimental implementations haven't @@ -126,7 +130,7 @@ robust hash functions for `jax.random.split` and `jax.random.fold_in`. Therefore less safe in the sense that the quality of random streams it generates from different keys is less well understood. -For more about jax_threefry_partitionable, see +For more about `jax_threefry_partitionable`, see https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers """