rocm_jax/jax/_src/random.py

2772 lines
102 KiB
Python

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
from functools import partial
import math
from operator import index
import typing
from typing import Union
import warnings
import numpy as np
import jax.numpy as jnp
from jax import lax
from jax.numpy.linalg import cholesky, svd, eigh
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import prng
from jax._src import xla_bridge
from jax._src.api import jit, vmap
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lax import lax as lax_internal
from jax._src.numpy.lax_numpy import _convert_and_clip_integer
from jax._src.numpy.util import _arraylike, check_arraylike, promote_dtypes_inexact
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.util import canonicalize_axis
# Public annotation aliases used by the sampler signatures below. They are
# currently plain aliases of the generic ArrayLike / DTypeLike types.
RealArray = ArrayLike
IntegerArray = ArrayLike
# TODO: Import or define these to match
# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
DTypeLikeInt = DTypeLikeUInt = DTypeLikeFloat = DTypeLike
Shape = Sequence[int]

# Re-exports from the PRNG implementation module. UINT_DTYPES is indexed by a
# bit width (8/16/32/64) below to obtain the unsigned dtype of that width.
PRNGImpl = prng.PRNGImpl
UINT_DTYPES = prng.UINT_DTYPES

### utilities

# Shorthand used below as _lax_const(ref, value); presumably builds a constant
# with the dtype of `ref` (see lax_internal._const to confirm).
_lax_const = lax_internal._const
def _isnan(x: ArrayLike) -> Array:
  """Return an elementwise boolean mask that is True where ``x`` is NaN.

  Uses the fact that NaN is the only value that compares unequal to itself.
  """
  differs_from_itself = lax.ne(x, x)
  return differs_from_itself
def _check_prng_key(name: str, key: ArrayLike, *,
                    allow_batched: bool = False) -> tuple[Array, bool]:
  """Validate ``key`` and return it as a typed PRNG key array.

  A key whose dtype is already a PRNG key dtype is returned as-is. Any other
  array-like is treated as legacy raw key data and wrapped with the default
  PRNG implementation. The ``jax_legacy_prng_key`` flag then decides whether
  that raises (``'error'``) or warns (``'warn'``). Otherwise a
  ``FutureWarning`` is emitted when ``jax_enable_custom_prng`` is set.

  Args:
    name: the public function name, interpolated into error messages.
    key: the key supplied by the caller.
    allow_batched: whether a key array with ``ndim > 0`` is acceptable.

  Returns:
    ``(typed_key, was_wrapped)``, where ``was_wrapped`` is True iff ``key``
    was raw key data that had to be wrapped.

  Raises:
    TypeError: if ``key`` is neither a typed key array nor array-like.
    ValueError: if raw keys are disallowed by ``jax_legacy_prng_key``, or if
      ``allow_batched`` is False and the key is batched.
  """
  if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key):
    typed_key, was_wrapped = key, False
  elif not _arraylike(key):
    raise TypeError(f'unexpected PRNG key type {type(key)}')
  else:
    # Wrap before consulting the flags so malformed raw keys error out here.
    typed_key = prng.random_wrap(key, impl=default_prng_impl())
    was_wrapped = True
    legacy_msg = (
        'Legacy uint32 key array passed as key to jax.random function. '
        'Please create keys using jax.random.key(). If use of a raw key array '
        'was intended, set jax_legacy_prng_key="allow".')
    legacy_mode = config.legacy_prng_key.value
    if legacy_mode == 'error':
      raise ValueError(legacy_msg)
    if legacy_mode == 'warn':
      warnings.warn(legacy_msg, stacklevel=2)
    elif config.enable_custom_prng.value:
      # TODO(jakevdp): possibly remove this warning condition.
      warnings.warn(
          'Raw arrays as random keys to jax.random functions are deprecated. '
          'Assuming valid threefry2x32 key for now.',
          FutureWarning)

  if not allow_batched and typed_key.ndim:
    raise ValueError(f"{name} accepts a single key, but was given a key array of"
                     f" shape {np.shape(key)} != (). Use jax.vmap for batching.")
  return typed_key, was_wrapped
def _return_prng_keys(was_wrapped, key):
  """Return a typed ``key`` in the form the caller originally supplied.

  If the caller passed raw key data (``was_wrapped``) and custom-PRNG mode is
  off, the typed key is unwrapped back to its raw data. In every other case
  the typed key is returned unchanged.
  """
  # TODO(frostig): remove once we always enable_custom_prng
  assert jnp.issubdtype(key.dtype, dtypes.prng_key)
  if was_wrapped and not config.enable_custom_prng.value:
    return prng.random_unwrap(key)
  return key
def _random_bits(key: Array, bit_width: int, shape: Shape) -> Array:
  """Draw raw unsigned random bits of width ``bit_width`` from a typed key.

  This is a thin wrapper over ``prng.random_bits``. It first asserts that
  ``key`` is a typed PRNG key array rather than legacy raw key data.
  """
  key_dtype = key.dtype
  assert jnp.issubdtype(key_dtype, dtypes.prng_key)
  return prng.random_bits(key, shape=shape, bit_width=bit_width)
# TODO(frostig,vanderplas): remove from public API altogether, or at
# least change to return after asserting presence in `prng.prngs`
def default_prng_impl():
  """Return the PRNG implementation currently configured as the default.

  The implementation is looked up by name, using the value of the
  ``jax_default_prng_impl`` config option, in the ``prng.prngs`` registry.
  """
  name = config.default_prng_impl.value
  registry = prng.prngs
  assert name in registry, name
  return registry[name]
### key operations
# Wrapper around prng.PRNGImpl meant to hide its attributes from the
# public API.
# TODO(frostig,vanderplas): consider hiding all the attributes of
# PRNGImpl and directly returning it.
class PRNGSpec:
  """Specifies a PRNG key implementation.

  An opaque, hashable handle around a ``PRNGImpl``. Equality and hashing
  both delegate to the wrapped implementation.
  """
  __slots__ = ['_impl']
  _impl: PRNGImpl

  def __init__(self, impl):
    self._impl = impl

  def __repr__(self) -> str:
    impl_name = self._impl.name
    return f"PRNGSpec({impl_name!r})"

  def __str__(self) -> str:
    return str(self._impl)

  def __hash__(self) -> int:
    return hash(self._impl)

  def __eq__(self, other) -> bool:
    if not isinstance(other, PRNGSpec):
      return False
    return self._impl == other._impl
# TODO(frostig,vanderplas): remove PRNGImpl from this union when it's
# no longer in the public API because `default_prng_impl` is gone
# Accepted ways to name a PRNG implementation: a registered name string, a
# PRNGSpec, or a PRNGImpl. All three are turned into a PRNGImpl by
# resolve_prng_impl.
PRNGSpecDesc = Union[str, PRNGSpec, PRNGImpl]
def resolve_prng_impl(impl_spec: PRNGSpecDesc | None) -> PRNGImpl:
  """Resolve any accepted PRNG implementation descriptor to a ``PRNGImpl``.

  Args:
    impl_spec: one of the following:
      - ``None``, meaning use the default implementation;
      - a ``PRNGImpl``;
      - a ``PRNGSpec``;
      - the registered name of an implementation.

  Returns:
    The corresponding ``PRNGImpl``.

  Raises:
    ValueError: if ``impl_spec`` is a string that names no registered
      implementation.
    TypeError: if ``impl_spec`` has any other type. Types are matched
      exactly, so subclasses are rejected.
  """
  if impl_spec is None:
    return default_prng_impl()

  spec_type = type(impl_spec)
  # TODO(frostig,vanderplas): remove the PRNGImpl case once we remove
  # default_prng_impl (and thus PRNGImpl) from the public API and
  # PRNGImpl from jex. We won't need to handle these then, and we
  # can remove them from the input type annotation above as well.
  if spec_type is PRNGImpl:
    return impl_spec
  if spec_type is PRNGSpec:
    return impl_spec._impl
  if spec_type is not str:
    raise TypeError(f'unrecognized type {spec_type} for specifying PRNG implementation.')

  registry = prng.prngs
  if impl_spec not in registry:
    keys_fmt = ', '.join(f'"{s}"' for s in registry.keys())
    raise ValueError(f'unrecognized PRNG implementation "{impl_spec}". '
                     f'Did you mean one of: {keys_fmt}?')
  return registry[impl_spec]
def _key(ctor_name: str, seed: int | ArrayLike,
         impl_spec: PRNGSpecDesc | None) -> Array:
  """Shared implementation of ``key`` and ``PRNGKey``: seed a typed key.

  Args:
    ctor_name: the public constructor name, used in error messages.
    seed: a scalar integer seed.
    impl_spec: a PRNG implementation descriptor, resolved with
      ``resolve_prng_impl``.

  Returns:
    A scalar typed PRNG key array.

  Raises:
    TypeError: if ``seed`` is itself a PRNG key, or is not a scalar.
  """
  impl = resolve_prng_impl(impl_spec)
  seed_is_key = hasattr(seed, 'dtype') and jnp.issubdtype(seed.dtype, dtypes.prng_key)
  if seed_is_key:
    raise TypeError(
        f"{ctor_name} accepts a scalar seed, but was given a PRNG key.")
  if np.ndim(seed) != 0:
    raise TypeError(
        f"{ctor_name} accepts a scalar seed, but was given an array of "
        f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
  return prng.random_seed(seed, impl=impl)
def key(seed: int | ArrayLike, *,
        impl: PRNGSpecDesc | None = None) -> Array:
  """Create a pseudo-random number generator (PRNG) key given an integer seed.

  The result is a scalar typed key array. Its dtype records the PRNG
  implementation: the one chosen by ``impl`` or, if ``impl`` is omitted, the
  one named by the ``jax_default_prng_impl`` config flag at call time.

  Args:
    seed: a 64- or 32-bit integer used as the value of the key.
    impl: optional PRNG implementation, given as a registered name (e.g.
      ``'threefry2x32'``), a ``PRNGSpec``, or a ``PRNGImpl``.

  Returns:
    A scalar PRNG key array, consumable by random functions as well as
    ``split`` and ``fold_in``.
  """
  return _key(ctor_name='key', seed=seed, impl_spec=impl)
def PRNGKey(seed: int | ArrayLike, *,
            impl: PRNGSpecDesc | None = None) -> Array:
  """Create a legacy PRNG key given an integer seed.

  Unlike :func:`jax.random.key`, which returns a typed key array, this
  produces an old-style legacy key: an array of dtype ``uint32`` holding the
  raw key data. For more, see the note in the `PRNG keys
  <https://jax.readthedocs.io/en/latest/jax.random.html#prng-keys>`_
  section. Prefer :func:`jax.random.key` where possible.

  A legacy key does not record which PRNG implementation made it. It is built
  with the implementation chosen by ``impl`` or, otherwise, by the
  ``jax_default_prng_impl`` config flag. Callers must keep that same
  implementation configured as the default when later passing the key to
  other functions (such as ``jax.random.split`` and ``jax.random.normal``).

  Args:
    seed: a 64- or 32-bit integer used as the value of the key.
    impl: optional PRNG implementation, e.g. ``'threefry2x32'``.

  Returns:
    A PRNG key, consumable by random functions as well as ``split``
    and ``fold_in``.
  """
  typed_key = _key('PRNGKey', seed, impl)
  # Report the key as caller-wrapped so it is unwrapped to raw data (unless
  # custom-PRNG mode keeps keys typed).
  return _return_prng_keys(True, typed_key)
def fold_in(key: ArrayLike, data: IntegerArray) -> Array:
  """Folds in data to a PRNG key to form a new PRNG key.

  Args:
    key: a PRNG key (from ``key``, ``split``, ``fold_in``).
    data: a 32-bit integer representing data to be folded into the key.

  Returns:
    A new PRNG key that is a deterministic function of the inputs and is
    statistically safe for producing a stream of new pseudo-random values.
    It has the same form (typed or legacy raw) as the ``key`` passed in.

  Raises:
    TypeError: if ``data`` is not a scalar (use ``jax.vmap`` to batch).
    ValueError: if ``key`` is a batched key array.
  """
  key, wrapped = _check_prng_key("fold_in", key)
  if np.ndim(data):
    # The two literals are concatenated, so the first one needs a trailing
    # space to avoid rendering "array ofshape".
    raise TypeError("fold_in accepts a scalar, but was given an array of "
                    f"shape {np.shape(data)} != (). Use jax.vmap for batching.")
  key_out = prng.random_fold_in(key, jnp.uint32(data))
  return _return_prng_keys(wrapped, key_out)
def _split(key: Array, num: int | tuple[int, ...] = 2) -> Array:
  """Split one typed key into new keys with leading shape ``num``.

  This is the internal counterpart of ``split`` used inside the samplers. It
  expects an already-typed key array.

  Raises:
    TypeError: if ``key`` is not a single (scalar) key.
  """
  # TODO(frostig): remove and use split(); we no longer need to wait
  # to always enable_custom_prng
  assert jnp.issubdtype(key.dtype, dtypes.prng_key)
  if key.ndim != 0:
    raise TypeError("split accepts a single key, but was given a key array of "
                    f"shape {key.shape} != (). Use jax.vmap for batching.")
  if isinstance(num, Sequence):
    out_shape = tuple(num)
  else:
    out_shape = (num,)
  return prng.random_split(key, shape=out_shape)
def split(key: ArrayLike, num: int | tuple[int, ...] = 2) -> Array:
  """Splits a PRNG key into `num` new keys by adding a leading axis.

  Args:
    key: a PRNG key (from ``key``, ``split``, ``fold_in``).
    num: optional positive integer (or tuple of integers) giving the number
      (or shape) of keys to produce. Defaults to 2.

  Returns:
    An array-like object of `num` new PRNG keys, in the same form (typed or
    legacy raw) as ``key``.
  """
  typed_key, wrapped = _check_prng_key("split", key)
  new_keys = _split(typed_key, num)
  return _return_prng_keys(wrapped, new_keys)
def _key_impl(keys: Array) -> PRNGImpl:
  """Return the ``PRNGImpl`` recorded in the dtype of a typed key array."""
  assert jnp.issubdtype(keys.dtype, dtypes.prng_key)
  return typing.cast(prng.KeyTy, keys.dtype)._impl
def _key_spec(keys: Array) -> str | PRNGSpec:
  """Describe the implementation of typed ``keys`` in user-facing form.

  Returns the implementation's name if that name is registered in
  ``prng.prngs``; otherwise returns a ``PRNGSpec`` wrapping it.
  """
  impl = _key_impl(keys)
  if impl.name in prng.prngs:
    return impl.name
  return PRNGSpec(impl)
def key_impl(keys: ArrayLike) -> str | PRNGSpec:
  """Return the PRNG implementation underlying ``keys``.

  Accepts a single key or a batched key array. The result is the
  implementation's registered name when it is registered, otherwise a
  ``PRNGSpec`` wrapping it. Raw (legacy) key data is interpreted with the
  default implementation.
  """
  typed_keys = _check_prng_key("key_impl", keys, allow_batched=True)[0]
  return _key_spec(typed_keys)
def _key_data(keys: Array) -> Array:
  """Return the raw key-data array underlying a typed key array ``keys``."""
  assert jnp.issubdtype(keys.dtype, dtypes.prng_key)
  raw_bits = prng.random_unwrap(keys)
  return raw_bits
def key_data(keys: ArrayLike) -> Array:
  """Recover the bits of key data underlying a PRNG key array.

  Batched key arrays are accepted.
  """
  typed_keys = _check_prng_key("key_data", keys, allow_batched=True)[0]
  return _key_data(typed_keys)
def wrap_key_data(key_bits_array: Array, *,
                  impl: PRNGSpecDesc | None = None):
  """Wrap an array of key data bits into a PRNG key array.

  Args:
    key_bits_array: a ``uint32`` array whose trailing shape matches the key
      shape of the PRNG implementation selected by ``impl``.
    impl: optional PRNG implementation descriptor, as in ``random.key``.

  Returns:
    A PRNG key array whose dtype is a subdtype of ``jax.dtypes.prng_key`` for
    ``impl``. Its shape is the leading part of ``key_bits_array.shape``, with
    the key-bit dimensions removed.
  """
  return prng.random_wrap(key_bits_array, impl=resolve_prng_impl(impl))
### random samplers
def _check_shape(name: str, shape: Shape, *param_shapes) -> None:
  """Check that parameter shapes broadcast to exactly ``shape``.

  Args:
    name: the sampler name, used in the error message.
    shape: the requested output shape.
    *param_shapes: shapes of the sampler's array parameters. If none are
      given, the check is a no-op.

  Raises:
    ValueError: if broadcasting ``shape`` with ``param_shapes`` yields
      anything other than ``shape``. Incompatible shapes may instead raise
      from ``lax.broadcast_shapes`` itself.
  """
  if not param_shapes:
    return
  broadcast = lax.broadcast_shapes(shape, *param_shapes)  # type: ignore
  if shape == broadcast:
    return
  msg = ("{} parameter shapes must be broadcast-compatible with shape "
         "argument, and the result of broadcasting the shapes must equal "
         "the shape argument, but got result {} for shape argument {}.")
  raise ValueError(msg.format(name, broadcast, shape))
def bits(key: ArrayLike,
         shape: Shape = (),
         dtype: DTypeLikeUInt | None = None) -> Array:
  """Sample uniform bits in the form of unsigned integers.

  Args:
    key: a PRNG key used as the random key.
    shape: optional tuple of nonnegative integers giving the result shape.
      Default ``()``.
    dtype: optional unsigned integer dtype of the result. Defaults to
      ``uint64`` if ``jax_enable_x64`` is true, otherwise ``uint32``.

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not an unsigned integer dtype.
  """
  typed_key, _ = _check_prng_key("bits", key)
  if dtype is not None:
    dtypes.check_user_dtype_supported(dtype)
    requested = dtype
  else:
    requested = dtypes.canonicalize_dtype(jnp.uint)
  if not dtypes.issubdtype(requested, np.unsignedinteger):
    raise ValueError("dtype argument to `bits` must be an unsigned int dtype, "
                     f"got {requested}")
  out_dtype = dtypes.canonicalize_dtype(requested)
  out_shape = core.canonicalize_shape(shape)
  return _random_bits(typed_key, 8 * out_dtype.itemsize, out_shape)
def uniform(key: ArrayLike,
            shape: Shape = (),
            dtype: DTypeLikeFloat = float,
            minval: RealArray = 0.,
            maxval: RealArray = 1.) -> Array:
  """Sample uniform random values in [minval, maxval) with given shape/dtype.

  Args:
    key: a PRNG key used as the random key.
    shape: optional tuple of nonnegative integers giving the result shape.
      Default ().
    dtype: optional float dtype of the result. Defaults to float64 if
      jax_enable_x64 is true, otherwise float32.
    minval: optional inclusive lower bound, broadcast-compatible with
      ``shape``. Default 0.
    maxval: optional exclusive upper bound, broadcast-compatible with
      ``shape``. Default 1.

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a floating dtype.
  """
  typed_key, _ = _check_prng_key("uniform", key)
  dtypes.check_user_dtype_supported(dtype)
  out_shape = core.canonicalize_shape(shape)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `uniform` must be a float dtype, "
                     f"got {dtype}")
  out_dtype = dtypes.canonicalize_dtype(dtype)
  return _uniform(typed_key, out_shape, out_dtype, minval, maxval)
@partial(jit, static_argnums=(1, 2))
def _uniform(key, shape, dtype, minval, maxval) -> Array:
  """Jitted core of ``uniform``. ``shape`` and ``dtype`` are static.

  The sampler works in three steps:
    1. Random bits are placed in the mantissa of the bit pattern of ``1.0``,
       giving floats in ``[1, 2)``.
    2. Subtracting 1 gives values in ``[0, 1)``.
    3. The result is affinely mapped onto ``[minval, maxval)``.

  Raises:
    TypeError: if ``dtype`` is not a floating dtype of width 8, 16, 32 or 64.
  """
  _check_shape("uniform", shape)
  if not jnp.issubdtype(dtype, np.floating):
    raise TypeError("uniform only accepts floating point dtypes.")

  minval = lax.convert_element_type(minval, dtype)
  maxval = lax.convert_element_type(maxval, dtype)
  minval = lax.broadcast_to_rank(minval, len(shape))
  maxval = lax.broadcast_to_rank(maxval, len(shape))

  finfo = jnp.finfo(dtype)
  nbits, nmant = finfo.bits, finfo.nmant

  if nbits not in (8, 16, 32, 64):
    # Previously rendered as "...64-bit dtypesgot ..." (missing ", ").
    raise TypeError(
        f"uniform only accepts 8-, 16-, 32-, or 64-bit dtypes, got {dtype}."
    )

  # Formats with fewer than 8 mantissa bits draw 8 random bits and widen them
  # to the full float width below.
  rng_bits = nbits
  if nmant < 8:
    rng_bits = 8
  bits = _random_bits(key, rng_bits, shape)
  uint_dtype = UINT_DTYPES[nbits]
  if rng_bits != nbits:
    bits = lax.convert_element_type(bits, uint_dtype)

  # The strategy here is to randomize only the mantissa bits with an exponent of
  # 1 (after applying the bias), then shift and scale to the desired range. The
  # bit-level transformation we use relies on Numpy and XLA having bit-for-bit
  # equivalent float representations, which might not be true on all platforms.
  # Shifting right by (rng_bits - nmant) keeps nmant random bits in the
  # mantissa field; OR-ing with the bits of 1.0 fixes sign and exponent.
  float_bits = lax.bitwise_or(
      lax.shift_right_logical(bits, np.array(rng_bits - nmant, uint_dtype)),
      np.array(1.0, dtype).view(uint_dtype),
  )
  floats = lax.bitcast_convert_type(float_bits, dtype) - np.array(1., dtype)
  # Clamp from below so no sample falls under minval (presumably guarding
  # against floating-point rounding in the affine map).
  return lax.max(
      minval,
      lax.reshape(floats * (maxval - minval) + minval, shape))
def randint(key: ArrayLike,
            shape: Shape,
            minval: IntegerArray,
            maxval: IntegerArray,
            dtype: DTypeLikeInt = int) -> Array:
  """Sample uniform random values in [minval, maxval) with given shape/dtype.

  Args:
    key: a PRNG key used as the random key.
    shape: tuple of nonnegative integers giving the result shape.
    minval: int or array of ints broadcast-compatible with ``shape``; the
      inclusive lower bound.
    maxval: int or array of ints broadcast-compatible with ``shape``; the
      exclusive upper bound. Wherever ``maxval <= minval`` the result is
      ``minval``.
    dtype: optional integer dtype of the result. Defaults to int64 if
      jax_enable_x64 is true, otherwise int32.

  Returns:
    A random array with the specified shape and dtype.
  """
  typed_key, _ = _check_prng_key("randint", key)
  dtypes.check_user_dtype_supported(dtype)
  out_dtype = dtypes.canonicalize_dtype(dtype)
  out_shape = core.canonicalize_shape(shape)
  return _randint(typed_key, out_shape, minval, maxval, out_dtype)
@partial(jit, static_argnums=(1, 4))
def _randint(key, shape, minval, maxval, dtype) -> Array:
  """Jitted core of ``randint``. ``shape`` and ``dtype`` are static.

  Draws two independent words of ``nbits`` random bits per output element.
  It reduces their combined ``2 * nbits``-bit value modulo the span
  ``maxval - minval`` and then offsets the result by ``minval``.

  Raises:
    TypeError: if ``dtype`` is not an integer dtype of width 8, 16, 32 or 64.
  """
  _check_shape("randint", shape, np.shape(minval), np.shape(maxval))
  if not jnp.issubdtype(dtype, np.integer):
    raise TypeError(f"randint only accepts integer dtypes, got {dtype}")

  check_arraylike("randint", minval, maxval)
  minval = jnp.asarray(minval)
  maxval = jnp.asarray(maxval)
  # Non-integer bounds are cast to the default integer type.
  if not jnp.issubdtype(minval.dtype, np.integer):
    minval = minval.astype(int)
  if not jnp.issubdtype(maxval.dtype, np.integer):
    maxval = maxval.astype(int)

  # Flag where maxval is greater than the maximum value of dtype
  # in order to handle cases like randint(key, shape, 0, 256, 'uint8')
  maxval_out_of_range = lax.gt(
      maxval, _convert_and_clip_integer(jnp.array(jnp.iinfo(dtype).max, dtype), maxval.dtype))

  # Convert the bounds to the output dtype (by its name,
  # _convert_and_clip_integer presumably clips out-of-range values).
  minval = _convert_and_clip_integer(minval, dtype)
  maxval = _convert_and_clip_integer(maxval, dtype)
  minval = lax.broadcast_to_rank(minval, len(shape))
  maxval = lax.broadcast_to_rank(maxval, len(shape))
  nbits = jnp.iinfo(dtype).bits

  if nbits not in (8, 16, 32, 64):
    raise TypeError(f"randint only accepts 8-, 16-, 32-, or 64-bit dtypes, got {dtype}")

  # This algorithm is biased whenever (maxval - minval) is not a power of 2.
  # We generate double the number of random bits required by the dtype so as to
  # reduce that bias.
  k1, k2 = _split(key)
  rbits = lambda key: _random_bits(key, nbits, shape)
  higher_bits, lower_bits = rbits(k1), rbits(k2)

  unsigned_dtype = UINT_DTYPES[nbits]
  span = lax.convert_element_type(maxval - minval, unsigned_dtype)

  # Ensure that span=1 when maxval <= minval, so minval is always returned;
  # https://github.com/jax-ml/jax/issues/222
  span = lax.select(maxval <= minval, lax.full_like(span, 1), span)

  # When maxval is out of range, the span has to be one larger.
  # If span is already the maximum representable value, this will wrap to zero,
  # causing remainders below to have no effect, which is the correct semantics.
  span = lax.select(
    maxval_out_of_range & (maxval > minval),
    lax.add(span, _lax_const(span, 1)),
    span)

  # To compute a remainder operation on an integer that might have twice as many
  # bits as we can represent in the native unsigned dtype, we compute a
  # multiplier equal to 2**nbits % span. To avoid overflow, we use the identity:
  # (a * b) % N = [(a % N) * (b % N)] % N
  multiplier = lax.rem(_lax_const(span, 2 ** (nbits // 2)), span)
  multiplier = lax.rem(lax.mul(multiplier, multiplier), span)

  # (higher * 2**nbits + lower) % span, computed without overflow.
  random_offset = lax.add(lax.mul(lax.rem(higher_bits, span), multiplier),
                          lax.rem(lower_bits, span))
  random_offset = lax.rem(random_offset, span)
  return lax.add(minval, lax.convert_element_type(random_offset, dtype))
def permutation(key: ArrayLike,
                x: int | ArrayLike,
                axis: int = 0,
                independent: bool = False) -> Array:
  """Returns a randomly permuted array or range.

  Args:
    key: a PRNG key used as the random key.
    x: an integer ``n``, which permutes ``arange(n)``, or an array to permute.
    axis: the axis of ``x`` along which to permute. Default 0.
    independent: if True, every 1-D slice of ``x`` along ``axis`` is
      permuted independently. If False (the default), whole slices are
      reordered with one shared permutation.

  Returns:
    The permuted array (or permuted range).

  Raises:
    TypeError: if ``x`` is a scalar of non-integer dtype.
  """
  key, _ = _check_prng_key("permutation", key)
  check_arraylike("permutation", x)
  ndim = np.ndim(x)
  axis = canonicalize_axis(axis, ndim or 1)

  if ndim == 0:
    if not np.issubdtype(lax.dtype(x), np.integer):
      raise TypeError("x must be an integer or at least 1-dimensional")
    length = core.concrete_or_error(int, x, 'argument x of jax.random.permutation()')
    return _shuffle(key, jnp.arange(length), axis)

  if independent or ndim == 1:
    return _shuffle(key, x, axis)

  # Shuffle a vector of indices once, then gather whole slices with it.
  order = _shuffle(key, jnp.arange(x.shape[axis]), 0)  # type: ignore[union-attr]
  return jnp.take(x, order, axis, unique_indices=True)
@partial(jit, static_argnums=(2,))
def _shuffle(key, x, axis) -> Array:
  """Randomly permute ``x`` along ``axis`` by repeated random-key sorting.

  Each round sorts ``x`` along ``axis`` by fresh 32-bit random sort keys. The
  number of rounds grows with ``log(x.size)``. The sort keys have the full
  shape of ``x``, so when ``x.ndim > 1`` each 1-D slice along ``axis`` is
  ordered by its own keys, i.e. slices are permuted independently.

  Raises:
    NotImplementedError: if ``x.size`` is not a constant (shape-polymorphic).
  """
  # On parallel architectures, Fisher-Yates is more expensive than doing
  # multiple sorts. This algorithm is based on one developed and analyzed by
  # tjablin@. We sort according to randomly-generated 32bit keys, but those keys
  # may have collisions. If we repeat the process, using fresh 32bit keys for
  # each sort, then whenever all pairs of elements have been assigned distinct
  # keys at some iteration (or equivalently when the strings formed by
  # concatenating the successive keys for each element are all distinct) then we
  # are guaranteed to have a perfect sample (assuming that either the sort is
  # stable or that any bias is not value-dependent). Since checking uniqueness
  # at runtime may be expensive, we use a heuristic static stop criterion
  # developed by tjablin@. See tensorflow/compiler/tf2xla/random_ops.cc for more
  # info, and for the original implementation of this algorithm. See also
  # Section 2 of http://people.csail.mit.edu/costis/6896sp11/lec5s.pdf for
  # another analysis (where the keys are generated one bit at a time).
  exponent = 3  # see tjablin@'s analysis for explanation of this parameter
  uint32max = jnp.iinfo(np.uint32).max
  if not core.is_constant_dim(x.size):
    raise NotImplementedError(
        "shape polymorphism for `permutation` or `shuffle`"
        f" for arrays of non-constant size: {x.size}")
  # Static round count: ceil(exponent * log(size) / log(2**32 - 1)). It is 0
  # when size <= 1, in which case x is returned unchanged.
  num_rounds = int(np.ceil(exponent * np.log(max(1, x.size)) / np.log(uint32max)))
  for _ in range(num_rounds):
    key, subkey = _split(key)
    sort_keys = _random_bits(subkey, 32, x.shape)
    _, x = lax.sort_key_val(sort_keys, x, axis)
  return x
def choice(key: ArrayLike,
           a: int | ArrayLike,
           shape: Shape = (),
           replace: bool = True,
           p: RealArray | None = None,
           axis: int = 0) -> Array:
  """Generates a random sample from a given array.

  .. warning::
    If ``p`` has fewer non-zero elements than the requested number of samples,
    as specified in ``shape``, and ``replace=False``, the output of this
    function is ill-defined. Please make sure to use appropriate inputs.

  Args:
    key: a PRNG key used as the random key.
    a : array or int. If an ndarray, a random sample is generated from
      its elements. If an int, the random sample is generated as if a were
      arange(a).
    shape : tuple of ints, optional. Output shape. If the given shape is,
      e.g., ``(m, n)``, then ``m * n`` samples are drawn. Default is (),
      in which case a single value is returned.
    replace : boolean. Whether the sample is with or without replacement.
      Default is True.
    p : 1-D array-like, The probabilities associated with each entry in a.
      If not given the sample assumes a uniform distribution over all
      entries in a.
    axis: int, optional. The axis along which the selection is performed.
      The default, 0, selects by row.

  Returns:
    An array containing samples from ``a``. If ``a`` is an int (or 0-d), the
    result has shape ``shape``. Otherwise ``shape`` takes the place of
    dimension ``axis`` of ``a``, giving ``a.shape[:axis] + shape +
    a.shape[axis + 1:]``. This also holds when no samples are drawn, i.e.
    when ``shape`` has a zero entry.

  Raises:
    TypeError: if ``shape`` is not a sequence.
    ValueError: raised in any of these cases:
      - there is nothing to draw from (no elements along ``axis``, or a
        non-positive int ``a``) while samples are requested;
      - more samples than inputs are requested with ``replace=False``;
      - ``p`` does not have shape ``(a.shape[axis],)``.
  """
  key, _ = _check_prng_key("choice", key)
  if not isinstance(shape, Sequence):
    raise TypeError("shape argument of jax.random.choice must be a sequence, "
                    f"got {shape}")
  check_arraylike("choice", a)
  arr = jnp.asarray(a)
  if arr.ndim == 0:
    n_inputs = core.concrete_or_error(int, a, "The error occurred in jax.random.choice()")
  else:
    axis = canonicalize_axis(axis, arr.ndim)
    n_inputs = arr.shape[axis]

  # Shape of the result. For array-valued ``a`` the sample shape replaces the
  # sampled axis. It is computed up front so that the empty-sample early
  # return below agrees with the shape of non-empty results (it used to
  # return plain ``shape`` even for multi-dimensional ``a``).
  if arr.ndim == 0:
    out_shape = tuple(shape)
  else:
    out_shape = arr.shape[0:axis] + tuple(shape) + arr.shape[axis+1:]

  n_draws = math.prod(shape)
  if n_draws == 0:
    return jnp.zeros(out_shape, dtype=arr.dtype)
  if n_inputs <= 0:
    raise ValueError("a must be greater than 0 unless no samples are taken")
  if not replace and n_draws > n_inputs:
    raise ValueError(
        f"Cannot take a larger sample (size {n_draws}) than "
        f"population (size {n_inputs}) when 'replace=False'")

  if p is None:
    if replace:
      ind = randint(key, shape, 0, n_inputs)
      result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis)
    else:
      # Without replacement: permute along ``axis`` and keep the first n_draws.
      slices = (slice(None),) * axis + (slice(n_draws),)
      result = permutation(key, n_inputs if arr.ndim == 0 else arr, axis)[slices]
  else:
    check_arraylike("choice", p)
    p_arr, = promote_dtypes_inexact(p)
    if p_arr.shape != (n_inputs,):
      raise ValueError(
          "p must be None or a 1D vector with the same size as a.shape[axis]. "
          f"p has shape {p_arr.shape} and a.shape[axis] is {n_inputs}.")
    if replace:
      # Inverse-CDF sampling against the (unnormalized) cumulative weights.
      p_cuml = jnp.cumsum(p_arr)
      r = p_cuml[-1] * (1 - uniform(key, shape, dtype=p_cuml.dtype))
      ind = jnp.searchsorted(p_cuml, r).astype(int)
    else:
      # Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
      g = gumbel(key, (n_inputs,), dtype=p_arr.dtype) + jnp.log(p_arr)
      ind = lax.top_k(g, k=n_draws)[1].astype(int)
    result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis)
  return result.reshape(out_shape)
def normal(key: ArrayLike,
           shape: Shape = (),
           dtype: DTypeLikeFloat = float) -> Array:
  r"""Sample standard normal random values with given shape and float dtype.

  The values are returned according to the probability density function:

  .. math::
     f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2}

  on the domain :math:`-\infty < x < \infty`

  Args:
    key: a PRNG key used as the random key.
    shape: optional tuple of nonnegative integers giving the result shape.
      Default ().
    dtype: optional float or complex dtype of the result. Defaults to
      float64 if jax_enable_x64 is true, otherwise float32.

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is neither a float nor a complex dtype.
  """
  typed_key, _ = _check_prng_key("normal", key)
  out_shape = core.canonicalize_shape(shape)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.inexact):
    raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, "
                     f"got {dtype}")
  out_dtype = dtypes.canonicalize_dtype(dtype)
  return _normal(typed_key, out_shape, out_dtype)
@partial(jit, static_argnums=(1, 2))
def _normal(key, shape, dtype) -> Array:
  """Jitted core of ``normal``. ``shape`` and ``dtype`` are static.

  For complex dtypes, the real and imaginary parts come from independent
  real normal draws (using separate split keys). Both are divided by
  ``sqrt(2)``.
  """
  if not dtypes.issubdtype(dtype, np.complexfloating):
    return _normal_real(key, shape, dtype)
  key_re, key_im = _split(key)
  part_dtype = np.array(0, dtype).real.dtype
  re_part = _normal_real(key_re, shape, part_dtype).astype(dtype)
  im_part = _normal_real(key_im, shape, part_dtype).astype(dtype)
  sqrt2 = np.array(np.sqrt(2), dtype)
  return (re_part + 1j * im_part) / sqrt2
@partial(jit, static_argnums=(1, 2))
def _normal_real(key, shape, dtype) -> Array:
  """Standard normal samples for a real float ``dtype`` via the inverse error function."""
  _check_shape("normal", shape)
  # Draw u from (-1, 1). The lower bound is nudged just above -1 because
  # erf_inv(-1) is -inf, and uniform already excludes its upper bound.
  lower = np.nextafter(np.array(-1., dtype), np.array(0., dtype), dtype=dtype)
  upper = np.array(1., dtype)
  u = uniform(key, shape, dtype, lower, upper)
  scale = np.array(np.sqrt(2), dtype)
  return lax.mul(scale, lax.erf_inv(u))
def multivariate_normal(key: ArrayLike,
                        mean: RealArray,
                        cov: RealArray,
                        shape: Shape | None = None,
                        dtype: DTypeLikeFloat | None = None,
                        method: str = 'cholesky') -> Array:
  r"""Sample multivariate normal random values with given mean and covariance.

  The values are returned according to the probability density function:

  .. math::
     f(x;\mu, \Sigma) = (2\pi)^{-k/2} \det(\Sigma)^{-1/2}e^{-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)}

  where :math:`k` is the dimension, :math:`\mu` is the mean (given by ``mean``) and
  :math:`\Sigma` is the covariance matrix (given by ``cov``).

  Args:
    key: a PRNG key used as the random key.
    mean: a mean vector of shape ``(..., n)``.
    cov: a positive definite covariance matrix of shape ``(..., n, n)``. The
      batch shape ``...`` must be broadcast-compatible with that of ``mean``.
    shape: optional, a tuple of nonnegative integers specifying the result
      batch shape; that is, the prefix of the result shape excluding the last
      axis. Must be broadcast-compatible with ``mean.shape[:-1]`` and
      ``cov.shape[:-2]``. The default (None) produces a result batch shape by
      broadcasting together the batch shapes of ``mean`` and ``cov``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    method: optional, a method to compute the factor of ``cov``.
      Must be one of 'svd', 'eigh', and 'cholesky'. Default 'cholesky'. For
      singular covariance matrices, use 'svd' or 'eigh'.

  Returns:
    A random array with the specified dtype and shape given by
    ``shape + mean.shape[-1:]`` if ``shape`` is not None, or else
    ``broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]``.

  Raises:
    ValueError: if ``method`` is not a supported method, or if ``dtype`` is
      not a floating-point dtype.
  """
  key, _ = _check_prng_key("multivariate_normal", key)
  dtypes.check_user_dtype_supported(dtype)
  mean, cov = promote_dtypes_inexact(mean, cov)
  if method not in {'svd', 'eigh', 'cholesky'}:
    raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
  # With no explicit dtype, sample in the promoted dtype of mean/cov.
  if dtype is None:
    dtype = mean.dtype
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `multivariate_normal` must be a float "
                     f"dtype, got {dtype}")
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _multivariate_normal(key, mean, cov, shape, dtype, method)
@partial(jit, static_argnums=(3, 4, 5))
def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array:
  """Jitted core of ``multivariate_normal``.

  ``shape``, ``dtype`` and ``method`` are static. The function computes a
  matrix ``factor`` with ``factor @ factor.T == cov`` (for symmetric PSD
  ``cov``), using SVD, eigendecomposition or Cholesky. It then returns
  ``mean + factor @ z`` for standard normal ``z``.

  Raises:
    ValueError: if ``mean``/``cov`` have too few dimensions, if ``cov`` is not
      ``(..., n, n)`` with ``n = mean.shape[-1]``, or if ``shape`` is
      incompatible with the batch shapes of ``mean`` and ``cov``.
  """
  if not np.ndim(mean) >= 1:
    msg = "multivariate_normal requires mean.ndim >= 1, got mean.ndim == {}"
    raise ValueError(msg.format(np.ndim(mean)))
  if not np.ndim(cov) >= 2:
    msg = "multivariate_normal requires cov.ndim >= 2, got cov.ndim == {}"
    raise ValueError(msg.format(np.ndim(cov)))
  n = mean.shape[-1]
  if np.shape(cov)[-2:] != (n, n):
    msg = ("multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
           "but got cov.shape == {shape}.")
    raise ValueError(msg.format(n=n, shape=np.shape(cov)))
  if shape is None:
    shape = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
  else:
    # Name this function (not "normal") in the resulting error message.
    _check_shape("multivariate_normal", shape, mean.shape[:-1], cov.shape[:-2])

  if method == 'svd':
    (u, s, _) = svd(cov)
    factor = u * jnp.sqrt(s[..., None, :])
  elif method == 'eigh':
    (w, v) = eigh(cov)
    factor = v * jnp.sqrt(w[..., None, :])
  else:  # 'cholesky'
    factor = cholesky(cov)
  normal_samples = normal(key, shape + mean.shape[-1:], dtype)
  # The batch dims of mean and of the einsum result may differ in rank, so
  # rank promotion is explicitly allowed here.
  with config.numpy_rank_promotion('allow'):
    result = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
  return result
def truncated_normal(key: ArrayLike,
                     lower: RealArray,
                     upper: RealArray,
                     shape: Shape | None = None,
                     dtype: DTypeLikeFloat = float) -> Array:
  r"""Sample truncated standard normal random values with given shape and dtype.

  The values are returned according to the probability density function:

  .. math::
     f(x) \propto e^{-x^2/2}

  on the domain :math:`\rm{lower} < x < \rm{upper}`.

  Args:
    key: a PRNG key used as the random key.
    lower: float or array of floats giving the lower truncation bound.
      Must be broadcast-compatible with ``upper``.
    upper: float or array of floats giving the upper truncation bound.
      Must be broadcast-compatible with ``lower``.
    shape: optional tuple of nonnegative integers giving the result shape.
      Must be broadcast-compatible with ``lower`` and ``upper``. If None
      (the default), the shape is obtained by broadcasting ``lower`` and
      ``upper``.
    dtype: optional float dtype of the result. Defaults to float64 if
      jax_enable_x64 is true, otherwise float32.

  Returns:
    A random array with the specified dtype. Its shape is ``shape`` if given,
    otherwise the broadcast shape of ``lower`` and ``upper``. Values lie in
    the open interval ``(lower, upper)``.

  Raises:
    ValueError: if ``dtype`` is not a floating dtype.
  """
  out_shape = None if shape is None else core.canonicalize_shape(shape)
  typed_key, _ = _check_prng_key("truncated_normal", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `truncated_normal` must be a float "
                     f"dtype, got {dtype}")
  out_dtype = dtypes.canonicalize_dtype(dtype)
  return _truncated_normal(typed_key, lower, upper, out_shape, out_dtype)
@partial(jit, static_argnums=(3, 4))
def _truncated_normal(key, lower, upper, shape, dtype) -> Array:
  """Jitted core of ``truncated_normal``. ``shape`` and ``dtype`` are static.

  Uses inverse-CDF sampling:
    1. Draw ``u`` uniformly between ``erf(lower / sqrt(2))`` and
       ``erf(upper / sqrt(2))``.
    2. Map it back with ``sqrt(2) * erf_inv(u)``.
    3. Clamp the result strictly inside ``(lower, upper)``.

  Raises:
    TypeError: if ``dtype`` is not a floating dtype.
  """
  # Validate the dtype before any arithmetic. The check used to come after
  # lax.erf, so a non-float dtype failed there before reaching this message.
  if not jnp.issubdtype(dtype, np.floating):
    raise TypeError("truncated_normal only accepts floating point dtypes.")
  if shape is None:
    shape = lax.broadcast_shapes(np.shape(lower), np.shape(upper))
  else:
    _check_shape("truncated_normal", shape, np.shape(lower), np.shape(upper))

  sqrt2 = np.array(np.sqrt(2), dtype)
  lower = lax.convert_element_type(lower, dtype)
  upper = lax.convert_element_type(upper, dtype)
  a = lax.erf(lower / sqrt2)
  b = lax.erf(upper / sqrt2)
  u = uniform(key, shape, dtype, minval=a, maxval=b)
  out = sqrt2 * lax.erf_inv(u)
  # Clamp the value to the open interval (lower, upper) to make sure that
  # rounding (or if we chose `a` for `u`) doesn't push us outside of the range.
  return jnp.clip(
      out,
      lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
      lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype)))
def bernoulli(key: ArrayLike,
              p: RealArray = np.float32(0.5),
              shape: Shape | None = None) -> Array:
  r"""Draw boolean Bernoulli samples with mean ``p``.

  Each sample follows the probability mass function:

  .. math::
     f(k; p) = p^k(1 - p)^{1 - k}

  where :math:`k \in \{0, 1\}` and :math:`0 \le p \le 1`.

  Args:
    key: a PRNG key used as the random key.
    p: optional, a float or array of floats giving the mean of each random
      variable. Must be broadcast-compatible with ``shape``. Default 0.5.
    shape: optional, a tuple of nonnegative integers for the result shape.
      Must be broadcast-compatible with ``p.shape``. The default (None)
      produces a result shape equal to ``p.shape``.

  Returns:
    A boolean random array whose shape is ``shape`` when given, otherwise
    ``p.shape``.

  Raises:
    TypeError: if ``p`` does not have a floating-point dtype.
  """
  shape = None if shape is None else core.canonicalize_shape(shape)
  key, _ = _check_prng_key("bernoulli", key)
  p_dtype = dtypes.canonicalize_dtype(lax.dtype(p))
  if jnp.issubdtype(p_dtype, np.floating):
    return _bernoulli(key, lax.convert_element_type(p, p_dtype), shape)
  msg = "bernoulli probability `p` must have a floating dtype, got {}."
  raise TypeError(msg.format(p_dtype))
@partial(jit, static_argnums=(2,))
def _bernoulli(key, p, shape) -> Array:
  """Traced core of `bernoulli`: True wherever a uniform draw lies below ``p``."""
  p_shape = np.shape(p)
  if shape is not None:
    _check_shape("bernoulli", shape, p_shape)
  else:
    # TODO: Use the named part of `p` as well
    shape = p_shape
  draws = uniform(key, shape, lax.dtype(p))
  return draws < p
def beta(key: ArrayLike,
         a: RealArray,
         b: RealArray,
         shape: Shape | None = None,
         dtype: DTypeLikeFloat = float) -> Array:
  r"""Draw Beta(a, b) samples of the requested shape and float dtype.

  The density of each sample is

  .. math::
     f(x;a,b) \propto x^{a - 1}(1 - x)^{b - 1}

  on the domain :math:`0 \le x \le 1`.

  Args:
    key: a PRNG key used as the random key.
    a: a float or array of floats, broadcast-compatible with ``shape``, for
      the first parameter "alpha".
    b: a float or array of floats, broadcast-compatible with ``shape``, for
      the second parameter "beta".
    shape: optional, a tuple of nonnegative integers for the result shape.
      Must be broadcast-compatible with ``a`` and ``b``. The default (None)
      broadcasts ``a`` against ``b``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array of dtype ``dtype`` whose shape is ``shape`` when given,
    otherwise the broadcast of ``a`` and ``b``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key("beta", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(
        f"dtype argument to `beta` must be a float dtype, got {dtype}")
  float_dtype = dtypes.canonicalize_dtype(dtype)
  out_shape = None if shape is None else core.canonicalize_shape(shape)
  return _beta(key, a, b, out_shape, float_dtype)
@partial(jit, static_argnums=(3, 4))
def _beta(key, a, b, shape, dtype) -> Array:
  """Traced core of `beta`; ``shape`` and ``dtype`` are static.

  Draws ``X ~ Gamma(a)`` and ``Y ~ Gamma(b)`` in log space and returns
  ``X / (X + Y)``, which is Beta(a, b) distributed. ``shape=None`` means
  "broadcast ``a`` against ``b``".

  Jitted like the other sampler cores in this module, so the whole
  broadcast / loggamma / normalization pipeline compiles as one program
  instead of dispatching op by op.
  """
  if shape is None:
    shape = lax.broadcast_shapes(np.shape(a), np.shape(b))
  else:
    _check_shape("beta", shape, np.shape(a), np.shape(b))

  a = lax.convert_element_type(a, dtype)
  b = lax.convert_element_type(b, dtype)
  key_a, key_b = _split(key)
  a = jnp.broadcast_to(a, shape)
  b = jnp.broadcast_to(b, shape)
  log_gamma_a = loggamma(key_a, a, shape, dtype)
  log_gamma_b = loggamma(key_b, b, shape, dtype)
  # Compute gamma_a / (gamma_a + gamma_b) without losing precision: shifting
  # both log-samples by their maximum before exponentiating makes the larger
  # term exactly exp(0) == 1, so tiny gamma samples cannot underflow to 0/0.
  log_max = lax.max(log_gamma_a, log_gamma_b)
  gamma_a_scaled = jnp.exp(log_gamma_a - log_max)
  gamma_b_scaled = jnp.exp(log_gamma_b - log_max)
  return gamma_a_scaled / (gamma_a_scaled + gamma_b_scaled)
def cauchy(key: ArrayLike,
           shape: Shape = (),
           dtype: DTypeLikeFloat = float) -> Array:
  r"""Draw standard Cauchy samples of the requested shape and float dtype.

  The density of each sample is

  .. math::
     f(x) \propto \frac{1}{x^2 + 1}

  on the domain :math:`-\infty < x < \infty`.

  Args:
    key: a PRNG key used as the random key.
    shape: optional, a tuple of nonnegative integers giving the result shape.
      Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key("cauchy", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(
        f"dtype argument to `cauchy` must be a float dtype, got {dtype}")
  float_dtype = dtypes.canonicalize_dtype(dtype)
  return _cauchy(key, core.canonicalize_shape(shape), float_dtype)
@partial(jit, static_argnums=(1, 2))
def _cauchy(key, shape, dtype) -> Array:
  """Traced core of `cauchy`: inverse CDF ``tan(pi * (u - 1/2))``.

  The uniform draw starts at machine epsilon rather than 0.
  """
  _check_shape("cauchy", shape)
  u = uniform(key, shape, dtype, minval=jnp.finfo(dtype).eps, maxval=1.)
  centered = lax.sub(u, _lax_const(u, 0.5))
  return lax.tan(lax.mul(_lax_const(u, np.pi), centered))
def dirichlet(key: ArrayLike,
              alpha: RealArray,
              shape: Shape | None = None,
              dtype: DTypeLikeFloat = float) -> Array:
  r"""Draw Dirichlet samples of the requested batch shape and float dtype.

  The density of each sample is

  .. math::
     f(\{x_i\}; \{\alpha_i\}) \propto \prod_{i=1}^k x_i^{\alpha_i - 1}

  Where :math:`k` is the dimension, and :math:`\{x_i\}` satisfies

  .. math::
     \sum_{i=1}^k x_i = 1

  and :math:`0 \le x_i \le 1` for all :math:`x_i`.

  Args:
    key: a PRNG key used as the random key.
    alpha: an array of shape ``(..., n)`` holding the concentration
      parameters of the random variables.
    shape: optional, a tuple of nonnegative integers giving the result batch
      shape, i.e. the result shape without its trailing dimension of size
      ``n``. Must be broadcast-compatible with ``alpha.shape[:-1]``. The
      default (None) produces a result shape equal to ``alpha.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array of dtype ``dtype`` with shape
    ``shape + (alpha.shape[-1],)`` when ``shape`` is given, otherwise
    ``alpha.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key("dirichlet", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(
        f"dtype argument to `dirichlet` must be a float dtype, got {dtype}")
  float_dtype = dtypes.canonicalize_dtype(dtype)
  batch_shape = None if shape is None else core.canonicalize_shape(shape)
  return _dirichlet(key, alpha, batch_shape, float_dtype)
@partial(jit, static_argnums=(2, 3))
def _dirichlet(key, alpha, shape, dtype) -> Array:
  """Traced core of `dirichlet`: normalize independent Gamma(alpha_i) draws.

  The gamma samples are drawn in log space and normalized with a softmax,
  which avoids poor behavior for small ``alpha``.
  """
  alpha_ndim = np.ndim(alpha)
  if alpha_ndim < 1:
    raise ValueError(
        "dirichlet requires alpha.ndim >= 1, got alpha.ndim == {}".format(
            alpha_ndim))

  alpha_batch = np.shape(alpha)[:-1]
  alpha_event = np.shape(alpha)[-1:]
  if shape is not None:
    _check_shape("dirichlet", shape, alpha_batch)
  else:
    shape = alpha_batch

  alpha = lax.convert_element_type(alpha, dtype)
  log_gammas = loggamma(key, alpha, shape + alpha_event, dtype)
  return _softmax(log_gammas, -1)
def _softmax(x, axis) -> Array:
  """Numerically stable softmax of ``x`` along ``axis``.

  The per-slice maximum, with gradients stopped, is subtracted before
  exponentiating so large inputs do not overflow.

  Raises:
    TypeError: if ``x`` does not have a floating-point dtype.
  """
  if dtypes.issubdtype(x.dtype, np.floating):
    shift = lax.stop_gradient(jnp.max(x, axis, keepdims=True))
    exps = jnp.exp(x - shift)
    return exps / exps.sum(axis, keepdims=True)
  raise TypeError(f"_softmax only accepts floating dtypes, got {x.dtype}")
def exponential(key: ArrayLike,
                shape: Shape = (),
                dtype: DTypeLikeFloat = float) -> Array:
  r"""Draw unit-rate Exponential samples of the requested shape and dtype.

  The density of each sample is

  .. math::
     f(x) = e^{-x}

  on the domain :math:`0 \le x < \infty`.

  Args:
    key: a PRNG key used as the random key.
    shape: optional, a tuple of nonnegative integers giving the result shape.
      Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key("exponential", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(
        f"dtype argument to `exponential` must be a float dtype, got {dtype}")
  float_dtype = dtypes.canonicalize_dtype(dtype)
  return _exponential(key, core.canonicalize_shape(shape), float_dtype)
@partial(jit, static_argnums=(1, 2))
def _exponential(key, shape, dtype) -> Array:
  """Traced core of `exponential`: inverse CDF ``-log(1 - u)``."""
  _check_shape("exponential", shape)
  u = uniform(key, shape, dtype)
  # u lies in [0, 1); using 1 - u moves the log's argument into (0, 1], so
  # the log never sees zero.
  log_one_minus_u = lax.log1p(lax.neg(u))
  return lax.neg(log_one_minus_u)
def _gamma_one(key: Array, alpha, log_space) -> Array:
  """Draw one Gamma(alpha, 1) sample, or its logarithm.

  Implements the Marsaglia-Tsang rejection sampler (reference below). For
  ``alpha < 1`` it samples Gamma(alpha + 1) and applies a uniform "boost"
  factor to reach Gamma(alpha).

  Args:
    key: PRNG key consumed by this draw.
    alpha: the shape parameter. All internal draws use shape ``()``, so this
      is presumably a scalar (``_gamma_impl`` maps this function over the
      flattened alphas). NOTE(review): confirm before calling with arrays.
    log_space: if true, return ``log`` of the sample and compute the
      ``alpha < 1`` boost in log space to avoid underflow.

  Returns:
    A value with ``alpha``'s dtype: the sample itself, or its log when
    ``log_space`` is true.
  """
  # Ref: A simple method for generating gamma variables, George Marsaglia and Wai Wan Tsang
  # The algorithm can also be found in:
  # https://en.wikipedia.org/wiki/Gamma_distribution#Generating_gamma-distributed_random_variables
  zero = _lax_const(alpha, 0)
  one = _lax_const(alpha, 1)
  minus_one = _lax_const(alpha, -1)
  one_over_two = _lax_const(alpha, 0.5)
  one_over_three = _lax_const(alpha, 1. / 3.)
  squeeze_const = _lax_const(alpha, 0.0331)
  dtype = lax.dtype(alpha)

  # for alpha < 1, we boost alpha to alpha + 1 and get a sample according to
  #   Gamma(alpha) ~ Gamma(alpha+1) * Uniform()^(1 / alpha)
  # When alpha is very small, this boost can be problematic because it may result
  # in floating point underflow; for this reason we compute it in log space if
  # specified by the `log_space` argument:
  #   log[Gamma(alpha)] ~ log[Gamma(alpha + 1)] + log[Uniform()] / alpha
  # Note that log[Uniform()] ~ -Exponential(), but to avoid problems at x=0
  # exponential is computed in terms of log[1 - Uniform()]; we must account for this
  # so that log-space and non-log-space samples match.
  boost_mask = lax.ge(alpha, one)
  alpha_orig = alpha
  # From here on the sampler always runs with an effective alpha >= 1.
  alpha = lax.select(boost_mask, alpha, lax.add(alpha, one))

  # Marsaglia-Tsang constants: d = alpha - 1/3 and
  # c = 1 / sqrt(9 d) = (1/3) / sqrt(d).
  d = lax.sub(alpha, one_over_three)
  c = lax.div(one_over_three, lax.sqrt(d))

  def _cond_fn(kXVU):
    # State is (key, X = x**2, V = v**3, U). Returns True while the current
    # proposal must be REJECTED, so the outer while_loop keeps iterating.
    # A proposal is accepted when U < 1 - 0.0331 * x**4 (the cheap "squeeze"
    # test) or log(U) < x**2 / 2 + d * (1 - V + log(V)); the condition below
    # is the conjunction of the negations of both tests.
    _, X, V, U = kXVU
    # TODO: use lax.cond when its batching rule is supported
    # The reason is to avoid evaluating second condition which involves log+log
    # if the first condition is satisfied
    cond = lax.bitwise_and(lax.ge(U, lax.sub(one, lax.mul(squeeze_const, lax.mul(X, X)))),
                           lax.ge(lax.log(U), lax.add(lax.mul(X, one_over_two),
                                                      lax.mul(d, lax.add(lax.sub(one, V),
                                                                         lax.log(V))))))
    return cond

  def _body_fn(kXVU):
    # Builds a fresh proposal (X, V, U) using new subkeys.
    def _next_kxv(kxv):
      # Draw x ~ Normal(0, 1) and form the candidate v = 1 + c * x.
      key = kxv[0]
      key, subkey = _split(key)
      x = normal(subkey, (), dtype=dtype)
      v = lax.add(one, lax.mul(x, c))
      return key, x, v

    key = kXVU[0]
    key, x_key, U_key = _split(key, 3)
    # Redraw x until v > 0; the initial v = -1 forces at least one draw.
    _, x, v = lax.while_loop(lambda kxv: lax.le(kxv[2], zero), _next_kxv, (x_key, zero, minus_one))
    X = lax.mul(x, x)
    V = lax.mul(lax.mul(v, v), v)
    U = uniform(U_key, (), dtype=dtype)
    return key, X, V, U

  # initial state is chosen such that _cond_fn will return True
  # (U = 2 >= 1 and log(2) >= 0 with X = 0, V = 1), so the body runs at least once.
  key, subkey = _split(key)
  _, _, V, _ = lax.while_loop(_cond_fn, _body_fn, (key, zero, one, _lax_const(alpha, 2)))
  # The accepted Gamma(alpha) sample (for the boosted alpha) is d * V.
  if log_space:
    # log(1 - u) expressed as -Exponential(); the boost is skipped where
    # alpha >= 1. The extra `log_samples == 0` guard presumably avoids
    # 0 * inf = NaN when alpha_orig is 0.
    log_samples = lax.neg(exponential(subkey, (), dtype=dtype))
    log_boost = lax.select(boost_mask | (log_samples == 0), zero, lax.mul(log_samples, lax.div(one, alpha_orig)))
    return lax.add(lax.add(lax.log(d), lax.log(V)), log_boost)
  else:
    # 1 - u lies in (0, 1], matching the log-space path above.
    samples = 1 - uniform(subkey, (), dtype=dtype)
    boost = lax.select(boost_mask, one, lax.pow(samples, lax.div(one, alpha_orig)))
    return lax.mul(lax.mul(d, V), boost)
def _gamma_grad(sample, a, *, log_space):
  """Derivative of gamma samples with respect to the concentration ``a``.

  Args:
    sample: the gamma samples (or their logs when ``log_space`` is true).
    a: concentration parameters with the same number of elements as
      ``sample``.
    log_space: whether ``sample`` holds log-gamma values. In that case the
      result is ``d[log(sample)] / da = (d[sample] / da) / sample``.

  Returns:
    An array shaped like ``a`` holding the per-element derivatives.
  """
  flat_alphas = jnp.reshape(a, -1)
  flat_samples = jnp.reshape(sample, -1)
  if not log_space:
    grad_fn = lax.random_gamma_grad
  else:
    # exp(log_sample) can round to zero; replace such zeros with the smallest
    # positive normal float so the division below stays finite.
    flat_samples = lax.exp(flat_samples)
    smallest = lax.full_like(flat_samples, jnp.finfo(flat_samples.dtype).tiny)
    underflowed = lax.eq(flat_samples, lax_internal._const(sample, 0))
    flat_samples = lax.select(underflowed, smallest, flat_samples)

    def grad_fn(alpha, s):
      return lax.random_gamma_grad(alpha, s) / s

  on_cpu = xla_bridge.get_backend().platform == 'cpu'
  if on_cpu:
    # CPU evaluates the elements sequentially with lax.map instead of vmap.
    grads = lax.map(lambda args: grad_fn(*args), (flat_alphas, flat_samples))
  else:
    grads = vmap(grad_fn)(flat_alphas, flat_samples)
  return grads.reshape(np.shape(a))
def _gamma_impl(key, a, *, log_space, use_vmap=False):
  """Sample Gamma(a) for every element of ``a`` using per-element subkeys.

  Each key in ``key`` is split into ``prod(a.shape[key.ndim:])`` subkeys, so
  there is one subkey per element of ``a``. This presumably requires
  ``key.shape`` to be a prefix of ``a.shape`` (TODO confirm against callers).
  The per-element sampler runs under ``vmap`` only when ``use_vmap`` is set
  and the key uses the threefry implementation; otherwise it runs
  sequentially under ``lax.map``.
  """
  out_shape = jnp.shape(a)
  per_key = math.prod(out_shape[key.ndim:])
  subkeys = vmap(_split, in_axes=(0, None))(key.flatten(), per_key).flatten()
  flat_alphas = a.flatten()
  sample_one = partial(_gamma_one, log_space=log_space)

  if use_vmap and _key_impl(key) is prng.threefry_prng_impl:
    flat_samples = vmap(sample_one)(subkeys, flat_alphas)
  else:
    flat_samples = lax.map(lambda args: sample_one(*args), (subkeys, flat_alphas))

  return jnp.reshape(flat_samples, out_shape)
def _gamma_batching_rule(batched_args, batch_dims, *, log_space):
  """Batching (vmap) rule for ``random_gamma_p``.

  Moves the batch dimension of both the keys and the concentrations to axis
  0, broadcasting whichever operand is unbatched. It then re-binds the
  primitive, whose impl handles leading key dimensions, and reports the
  output batched at axis 0.
  """
  keys, alphas = batched_args
  keys_bdim, alphas_bdim = batch_dims
  axis_size = next(arg.shape[d] for arg, d in zip(batched_args, batch_dims)
                   if d is not None)
  keys = batching.bdim_at_front(keys, keys_bdim, axis_size)
  alphas = batching.bdim_at_front(alphas, alphas_bdim, axis_size)
  return random_gamma_p.bind(keys, alphas, log_space=log_space), 0
# Primitive backing `gamma`/`loggamma`: binds (key, a) with a static
# `log_space` flag and returns an array shaped like `a`.
random_gamma_p = core.Primitive('random_gamma')
random_gamma_p.def_impl(_gamma_impl)
random_gamma_p.def_abstract_eval(lambda key, a, **_: a)
# Differentiable w.r.t. `a` only (no JVP rule for the key).
ad.defjvp2(
    random_gamma_p, None,
    lambda tangent, ans, key, a, **kwds: tangent * _gamma_grad(ans, a, **kwds))
# Default lowering: vmap the per-element rejection sampler (for threefry keys).
mlir.register_lowering(random_gamma_p, mlir.lower_fun(
    partial(_gamma_impl, use_vmap=True),
    multiple_results=False))
# CPU override: evaluate elements sequentially with lax.map (use_vmap=False),
# mirroring the CPU special case in `_gamma_grad`. Registering use_vmap=True
# here would make this override identical to the default lowering above and
# thus a no-op. Presumably lax.map is preferred on CPU because a vmapped
# while_loop keeps iterating until every lane's rejection loop has finished.
mlir.register_lowering(random_gamma_p, mlir.lower_fun(
    partial(_gamma_impl, use_vmap=False),
    multiple_results=False), platform='cpu')
batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule
def gamma(key: ArrayLike,
          a: RealArray,
          shape: Shape | None = None,
          dtype: DTypeLikeFloat = float) -> Array:
  r"""Draw standard Gamma samples of the requested shape and float dtype.

  The density of each sample is

  .. math::
     f(x;a) \propto x^{a - 1} e^{-x}

  on the domain :math:`0 \le x < \infty`, with :math:`a > 0`.

  This is the standard gamma density, with a unit scale/rate parameter.
  Dividing the samples by a rate gives *gamma(a, rate)*; multiplying them by
  a scale gives *gamma(a, scale)*.

  Args:
    key: a PRNG key used as the random key.
    a: a float or array of floats, broadcast-compatible with ``shape``, for
      the parameter of the distribution.
    shape: optional, a tuple of nonnegative integers for the result shape.
      Must be broadcast-compatible with ``a``. The default (None) produces a
      result shape equal to ``a.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array of dtype ``dtype`` whose shape is ``shape`` when given,
    otherwise ``a.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.

  See Also:
    loggamma : sample gamma values in log-space, which can provide improved
      accuracy for small values of ``a``.
  """
  key, _ = _check_prng_key("gamma", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(
        f"dtype argument to `gamma` must be a float dtype, got {dtype}")
  float_dtype = dtypes.canonicalize_dtype(dtype)
  out_shape = None if shape is None else core.canonicalize_shape(shape)
  return _gamma(key, a, shape=out_shape, dtype=float_dtype)
def loggamma(key: ArrayLike,
             a: RealArray,
             shape: Shape | None = None,
             dtype: DTypeLikeFloat = float) -> Array:
  """Sample log-gamma random values with given shape and float dtype.

  This function is implemented such that the following will hold for a
  dtype-appropriate tolerance::

    np.testing.assert_allclose(jnp.exp(loggamma(*args)), gamma(*args), rtol=rtol)

  The benefit of log-gamma is better precision for samples very close to
  zero, which occur frequently when `a << 1`.

  Args:
    key: a PRNG key used as the random key.
    a: a float or array of floats broadcast-compatible with ``shape``
      representing the parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``a``. The default (None)
      produces a result shape equal to ``a.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape`` if
    ``shape`` is not None, or else by ``a.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.

  See Also:
    gamma : standard gamma sampler.
  """
  key, _ = _check_prng_key("loggamma", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    # Name this function (not `gamma`) so the error points at the real caller.
    raise ValueError(f"dtype argument to `loggamma` must be a float "
                     f"dtype, got {dtype}")
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _gamma(key, a, shape=shape, dtype=dtype, log_space=True)
@partial(jit, static_argnames=('shape', 'dtype', 'log_space'))
def _gamma(key, a, shape, dtype, log_space=False) -> Array:
  """Traced core of `gamma`/`loggamma`.

  Resolves the output shape, casts and broadcasts ``a`` to it, then binds the
  ``random_gamma`` primitive.
  """
  a_shape = np.shape(a)
  if shape is not None:
    _check_shape("gamma", shape, a_shape)
  else:
    shape = a_shape
  alphas = lax.convert_element_type(a, dtype)
  if np.shape(alphas) != shape:
    alphas = jnp.broadcast_to(alphas, shape)
  return random_gamma_p.bind(key, alphas, log_space=log_space)
@partial(jit, static_argnums=(2, 3, 4))
def _poisson_knuth(key, lam, shape, dtype, max_iters) -> Array:
  """Knuth's multiplication method for Poisson variates (suited to small rates).

  Accumulates ``log(u_1) + log(u_2) + ...`` until it drops to ``-lam`` or
  below, counting the steps taken. All elements advance together until every
  element has finished or ``max_iters`` is reached.
  Reference:
  https://en.wikipedia.org/wiki/Poisson_distribution#Generating_Poisson-distributed_random_variables
  """
  threshold = -lam

  def still_running(state):
    step, _, _, log_prod = state
    return (log_prod > threshold).any() & (step < max_iters)

  def advance(state):
    step, count, rng, log_prod = state
    rng, subkey = _split(rng)
    # Only elements that have not yet crossed the threshold keep counting.
    count = lax.select(log_prod > threshold, count + 1, count)
    u = uniform(subkey, shape, np.float32)
    return step + 1, count, rng, log_prod + jnp.log(u)

  init_state = (0,
                lax.full_like(lam, 0, dtype, shape),
                key,
                lax.full_like(lam, 0, np.float32, shape))
  _, count, _, _ = lax.while_loop(still_running, advance, init_state)
  # The counter includes the final (threshold-crossing) step; drop it.
  return (count - 1).astype(dtype)
@partial(jit, static_argnums=(2, 3, 4))
def _poisson_rejection(key, lam, shape, dtype, max_iters) -> Array:
  """Transformed-rejection Poisson sampler (used by `_poisson` for large rates).

  Every element proposes a candidate each iteration. Once accepted, an
  element's value is kept via the ``accepted`` mask. The loop stops when all
  elements are accepted or after ``max_iters`` iterations; elements never
  accepted keep the initial value -1.

  Args:
    key: PRNG key.
    lam: float rate array; its dtype is used for the uniform draws.
    shape: static output shape.
    dtype: static output dtype.
    max_iters: static iteration cap.

  Returns:
    The accepted integer-valued candidates cast to ``dtype``.
  """
  # Transformed rejection due to Hormann.
  # Reference:
  # http://citeseer.ist.psu.edu/viewdoc/citations;jsessionid=1BEB35946CC807879F55D42512E5490C?doi=10.1.1.48.3054.
  # Per-rate constants of the method (numeric values taken from the reference).
  log_lam = lax.log(lam)
  b = 0.931 + 2.53 * lax.sqrt(lam)
  a = -0.059 + 0.02483 * b
  inv_alpha = 1.1239 + 1.1328 / (b - 3.4)
  v_r = 0.9277 - 3.6224 / (b - 2)

  def body_fn(carry):
    # carry = (iteration, best candidate so far, accepted mask, key).
    i, k_out, accepted, key = carry
    key, subkey_0, subkey_1 = _split(key, 3)

    # u is uniform on [-0.5, 0.5); v is uniform on [0, 1).
    u = uniform(subkey_0, shape, lam.dtype) - 0.5
    v = uniform(subkey_1, shape, lam.dtype)
    u_shifted = 0.5 - abs(u)

    # Candidate k from the transformation; s and t are the log-scale
    # quantities compared in the full acceptance test (accept2).
    k = lax.floor((2 * a / u_shifted + b) * u + lam + 0.43)
    s = lax.log(v * inv_alpha / (a / (u_shifted * u_shifted) + b))
    t = -lam + k * log_lam - lax.lgamma(k + 1)

    # accept1 is a fast acceptance region; reject excludes negative k and a
    # fast rejection region; otherwise fall back to the full test accept2.
    accept1 = (u_shifted >= 0.07) & (v <= v_r)
    reject = (k < 0) | ((u_shifted < 0.013) & (v > u_shifted))
    accept2 = s <= t
    accept = accept1 | (~reject & accept2)

    # Elements accepted in an earlier iteration may be overwritten only by a
    # new accepted candidate, never by a rejected one.
    k_out = lax.select(accept, k, k_out)
    accepted |= accept

    return i + 1, k_out, accepted, key

  def cond_fn(carry):
    # Continue while any element is still unaccepted and the cap is not hit.
    i, k_out, accepted, key = carry
    return (~accepted).any() & (i < max_iters)

  k_init = lax.full_like(lam, -1, lam.dtype, shape)
  accepted = lax.full_like(lam, False, jnp.bool_, shape)
  k = lax.while_loop(cond_fn, body_fn, (0, k_init, accepted, key))[1]
  return k.astype(dtype)
@partial(jit, static_argnums=(2, 3))
def _poisson(key, lam, shape, dtype) -> Array:
  """Traced core of `poisson`: choose a sampler per element by rate.

  The implementation matches TensorFlow and NumPy:
  https://github.com/tensorflow/tensorflow/blob/v2.2.0-rc3/tensorflow/core/kernels/random_poisson_op.cc
  https://github.com/numpy/numpy/blob/v1.18.3/numpy/random/src/distributions/distributions.c#L574
  Rates below 10 (and NaN rates) use Knuth's algorithm. Larger rates use
  transformed rejection sampling. A rate of exactly 0 yields 0.
  """
  knuth_mask = _isnan(lam) | (lam < 10)
  # Both samplers run on every element, so each one gets a harmless stand-in
  # rate wherever the other sampler's result will be selected. The rejection
  # sampler's acceptance probability tops out near 89% as lambda -> inf, so an
  # arbitrary large stand-in (1e5) is fine there.
  knuth_lam = lax.select(knuth_mask, lam, lax.full_like(lam, 0.0))
  rejection_lam = lax.select(knuth_mask, lax.full_like(lam, 1e5), lam)
  iter_cap = dtype.type(jnp.iinfo(dtype).max)  # insanely conservative
  knuth_draws = _poisson_knuth(key, knuth_lam, shape, dtype, iter_cap)
  rejection_draws = _poisson_rejection(key, rejection_lam, shape, dtype, iter_cap)
  draws = lax.select(knuth_mask, knuth_draws, rejection_draws)
  return lax.select(lam == 0, jnp.zeros_like(draws), draws)
def poisson(key: ArrayLike,
            lam: RealArray,
            shape: Shape | None = None,
            dtype: DTypeLikeInt = int) -> Array:
  r"""Sample Poisson random values with given shape and integer dtype.

  The values are distributed according to the probability mass function:

  .. math::
     f(k; \lambda) = \frac{\lambda^k e^{-\lambda}}{k!}

  Where `k` is a non-negative integer and :math:`\lambda \ge 0` (a rate of
  exactly zero yields all-zero samples).

  Args:
    key: a PRNG key used as the random key. Only keys using the threefry2x32
      PRNG implementation are currently supported.
    lam: rate parameter (mean of the distribution), must be >= 0. Must be
      broadcast-compatible with ``shape``. It is converted to float32 before
      sampling.
    shape: optional, a tuple of nonnegative integers representing the result
      shape. Default (None) produces a result shape equal to ``lam.shape``.
    dtype: optional, an integer dtype for the returned values (default int64
      if jax_enable_x64 is true, otherwise int32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape``
    if ``shape`` is not None, or else by ``lam.shape``.

  Raises:
    NotImplementedError: if ``key`` does not use the threefry2x32 PRNG
      implementation.
  """
  key, _ = _check_prng_key("poisson", key)
  dtypes.check_user_dtype_supported(dtype)
  # TODO(frostig): generalize underlying poisson implementation and
  # remove this check
  keys_dtype = typing.cast(prng.KeyTy, key.dtype)
  key_impl = keys_dtype._impl
  if key_impl is not prng.threefry_prng_impl:
    raise NotImplementedError(
        '`poisson` is only implemented for the threefry2x32 RNG, '
        f'not {key_impl}')
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  else:
    shape = np.shape(lam)
  lam = jnp.broadcast_to(lam, shape)
  lam = lax.convert_element_type(lam, np.float32)
  return _poisson(key, lam, shape, dtype)
def gumbel(key: ArrayLike,
           shape: Shape = (),
           dtype: DTypeLikeFloat = float,
           mode: str | None = None) -> Array:
  """Draw standard Gumbel samples of the requested shape and float dtype.

  The density of each sample is

  .. math::
     f(x) = e^{-(x + e^{-x})}

  Args:
    key: a PRNG key used as the random key.
    shape: optional, a tuple of nonnegative integers giving the result shape.
      Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    mode: optional, "high" or "low", selecting how many bits to use when
      sampling. When None, the ``jax_use_high_dynamic_range_gumbel`` config
      flag picks "high", otherwise "low" is used.

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype, or ``mode`` is
      neither "high" nor "low".
  """
  key, _ = _check_prng_key("gumbel", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(
        f"dtype argument to `gumbel` must be a float dtype, got {dtype}")
  float_dtype = dtypes.canonicalize_dtype(dtype)
  out_shape = core.canonicalize_shape(shape)
  if mode is None:
    use_high = config.use_high_dynamic_range_gumbel.value
    mode = "high" if use_high else "low"
  if mode not in ("high", "low"):
    raise ValueError("Must provide valid mode for gumbel got: %s" % mode)
  return _gumbel(key, out_shape, float_dtype, mode)
@partial(jit, static_argnums=(1, 2, 3))
def _gumbel(key, shape, dtype, mode) -> Array:
  """Traced core of `gumbel`: inverse CDF ``-log(-log(u))``.

  "low" mode uses one uniform draw per sample, starting at the smallest
  positive normal float. "high" mode combines two draws to add extra low-order
  bits when the first draw is below 0.5.
  """
  _check_shape("gumbel", shape)
  info = jnp.finfo(dtype)
  if mode != "high":
    u = _uniform(key, shape, dtype, minval=info.tiny, maxval=1.)
    return -jnp.log(-jnp.log(u))
  high, low = _uniform(key, (2,) + shape, dtype, minval=0., maxval=1.)
  # TODO(parkers): The condition is to protect against rounding up but
  # we should be able to add safely with the right addition operation.
  x = jnp.where(high >= 0.5, high,
                high + 2 ** -(info.nmant) * low + info.tiny)
  return -jnp.log(-jnp.log1p(-x))
def categorical(
    key: ArrayLike,
    logits: RealArray,
    axis: int = -1,
    shape: Shape | None = None,
    replace: bool = True,
) -> Array:
  """Sample random values from categorical distributions.

  Sampling with replacement uses the Gumbel max trick. Sampling without
  replacement uses the Gumbel top-k trick. See [1] for reference.

  Args:
    key: a PRNG key used as the random key.
    logits: Unnormalized log probabilities of the categorical distribution(s)
      to sample from, so that `softmax(logits, axis)` gives the corresponding
      probabilities.
    axis: Axis along which logits belong to the same categorical distribution.
    shape: Optional, a tuple of nonnegative integers representing the result
      shape. Must be broadcast-compatible with
      ``np.delete(logits.shape, axis)``. The default (None) produces a result
      shape equal to ``np.delete(logits.shape, axis)``.
    replace: If True (the default), perform sampling with replacement. If
      False, perform sampling without replacement. In that case the number of
      samples per distribution is the product of the leading dimensions of
      ``shape`` beyond ``np.delete(logits.shape, axis)``, and it must not
      exceed ``logits.shape[axis]``.

  Returns:
    A random array with int dtype and shape given by ``shape`` if ``shape``
    is not None, or else ``np.delete(logits.shape, axis)``.

  Raises:
    ValueError: if ``replace`` is False and more samples are requested than
      there are categories along ``axis``.

  References:
    .. [1] Wouter Kool, Herke van Hoof, Max Welling. "Stochastic Beams and Where to Find
      Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement".
      Proceedings of the 36th International Conference on Machine Learning, PMLR
      97:3499-3508, 2019. https://proceedings.mlr.press/v97/kool19a.html.
  """
  key, _ = _check_prng_key("categorical", key)
  check_arraylike("categorical", logits)
  logits_arr = jnp.asarray(logits)

  batch_shape = tuple(np.delete(logits_arr.shape, axis))
  if shape is None:
    shape = batch_shape
  else:
    shape = core.canonicalize_shape(shape)
    _check_shape("categorical", shape, batch_shape)

  # Extra leading dims of `shape` = independent samples per distribution.
  shape_prefix = shape[:len(shape)-len(batch_shape)]

  if replace:
    # Use a negative axis so it still addresses the category axis after
    # `shape_prefix` dims are prepended to the noise/logits below.
    if axis >= 0:
      axis -= len(logits_arr.shape)

    logits_shape = list(shape[len(shape) - len(batch_shape):])
    logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
    # Gumbel-max: argmax of (logits + Gumbel noise) is a categorical sample.
    return jnp.argmax(
        gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
        lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
        axis=axis)
  else:
    # Gumbel-top-k: perturb the logits once; the top k indices are k draws
    # without replacement.
    logits_arr += gumbel(key, logits_arr.shape, logits_arr.dtype)
    k = math.prod(shape_prefix)
    if k > logits_arr.shape[axis]:
      raise ValueError(
          f"Number of samples without replacement ({k}) cannot exceed number of "
          f"categories ({logits_arr.shape[axis]})."
      )

    _, indices = lax.top_k(jnp.moveaxis(logits_arr, axis, -1), k)
    assert indices.shape == batch_shape + (k,)
    assert shape == shape_prefix + batch_shape

    # Move the k-sample axis to the front, then unflatten it into shape_prefix.
    dimensions = (indices.ndim - 1, *range(indices.ndim - 1))
    indices = lax.reshape(indices, shape, dimensions)
    assert indices.shape == shape
    return indices
def laplace(key: ArrayLike,
            shape: Shape = (),
            dtype: DTypeLikeFloat = float) -> Array:
  r"""Draw standard Laplace samples of the requested shape and float dtype.

  The density of each sample is

  .. math::
     f(x) = \frac{1}{2}e^{-|x|}

  Args:
    key: a PRNG key used as the random key.
    shape: optional, a tuple of nonnegative integers giving the result shape.
      Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key("laplace", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(
        f"dtype argument to `laplace` must be a float dtype, got {dtype}")
  float_dtype = dtypes.canonicalize_dtype(dtype)
  return _laplace(key, core.canonicalize_shape(shape), float_dtype)
@partial(jit, static_argnums=(1, 2))
def _laplace(key, shape, dtype) -> Array:
  """Traced core of `laplace`: ``sign(u) * log(1 - |u|)`` with u in (-1, 1).

  The lower bound is nudged up by ``epsneg`` so ``1 - |u|`` never reaches 0.
  """
  _check_shape("laplace", shape)
  lower = -1. + jnp.finfo(dtype).epsneg
  u = uniform(key, shape, dtype, minval=lower, maxval=1.)
  magnitude = lax.log1p(lax.neg(lax.abs(u)))
  return lax.mul(lax.sign(u), magnitude)
def logistic(key: ArrayLike,
             shape: Shape = (),
             dtype: DTypeLikeFloat = float) -> Array:
  r"""Draw standard logistic samples of the requested shape and float dtype.

  The density of each sample is

  .. math::
     f(x) = \frac{e^{-x}}{(1 + e^{-x})^2}

  Args:
    key: a PRNG key used as the random key.
    shape: optional, a tuple of nonnegative integers giving the result shape.
      Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key("logistic", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(
        f"dtype argument to `logistic` must be a float dtype, got {dtype}")
  float_dtype = dtypes.canonicalize_dtype(dtype)
  return _logistic(key, core.canonicalize_shape(shape), float_dtype)
@partial(jit, static_argnums=(1, 2))
def _logistic(key, shape, dtype):
  """Traced core of `logistic`: inverse CDF ``log(u / (1 - u))``.

  The uniform draw starts at machine epsilon so the log argument is never 0.
  """
  _check_shape("logistic", shape)
  u = uniform(key, shape, dtype, minval=jnp.finfo(dtype).eps, maxval=1.)
  odds = lax.div(u, lax.sub(_lax_const(u, 1), u))
  return lax.log(odds)
def pareto(key: ArrayLike,
           b: RealArray,
           shape: Shape | None = None,
           dtype: DTypeLikeFloat = float) -> Array:
  r"""Draw Pareto samples of the requested shape and float dtype.

  The density of each sample is

  .. math::
     f(x; b) = b / x^{b + 1}

  on the domain :math:`1 \le x < \infty` with :math:`b > 0`.

  Args:
    key: a PRNG key used as the random key.
    b: a float or array of floats, broadcast-compatible with ``shape``, for
      the parameter of the distribution.
    shape: optional, a tuple of nonnegative integers for the result shape.
      Must be broadcast-compatible with ``b``. The default (None) produces a
      result shape equal to ``b.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array of dtype ``dtype`` whose shape is ``shape`` when given,
    otherwise ``b.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key("pareto", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(
        f"dtype argument to `pareto` must be a float dtype, got {dtype}")
  float_dtype = dtypes.canonicalize_dtype(dtype)
  out_shape = None if shape is None else core.canonicalize_shape(shape)
  return _pareto(key, b, out_shape, float_dtype)
@partial(jit, static_argnums=(2, 3))
def _pareto(key, b, shape, dtype) -> Array:
  """Traced core of `pareto`: ``exp(E / b)`` with ``E ~ Exponential(1)``.

  ``shape=None`` means "use ``b``'s shape". An explicit ``shape`` is checked
  against ``b``'s shape, so the result always has exactly ``shape``.

  Raises:
    ValueError: if ``b``'s shape is not broadcast-compatible with ``shape``,
      or would broadcast the result beyond ``shape``.
  """
  if shape is None:
    shape = np.shape(b)
  else:
    # Include b's shape, as every other sampler here does. Without it a `b`
    # that broadcasts beyond `shape` slips through, and `e / b` below silently
    # returns an array larger than the requested shape.
    _check_shape("pareto", shape, np.shape(b))

  b = lax.convert_element_type(b, dtype)
  e = exponential(key, shape, dtype)
  return lax.exp(e / b)
def t(key: ArrayLike,
      df: RealArray,
      shape: Shape | None = None,
      dtype: DTypeLikeFloat = float) -> Array:
  r"""Sample Student's t random values with given shape and float dtype.

  The values are distributed according to the probability density function:

  .. math::
     f(t; \nu) \propto \left(1 + \frac{t^2}{\nu}\right)^{-(\nu + 1)/2}

  Where :math:`\nu > 0` is the degrees of freedom, given by the parameter ``df``.

  Args:
    key: a PRNG key used as the random key.
    df: a float or array of floats broadcast-compatible with ``shape``
      representing the degrees of freedom parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``df``. The default (None)
      produces a result shape equal to ``df.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape`` if
    ``shape`` is not None, or else by ``df.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key("t", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `t` must be a float "
                     f"dtype, got {dtype}")
  dtype = dtypes.canonicalize_dtype(dtype)
  # The default used to be (), which contradicted the documented behavior:
  # array-valued `df` failed without an explicit shape, and shape=None crashed
  # here. `_t` already resolves None to np.shape(df), which for scalar df is
  # still (), so every previously valid call behaves the same.
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _t(key, df, shape, dtype)
@partial(jit, static_argnums=(2, 3))
def _t(key, df, shape, dtype) -> Array:
  """Traced core of `t`: ``Z * sqrt((df / 2) / G)``.

  Here ``Z ~ Normal(0, 1)`` and ``G ~ Gamma(df / 2)``. ``shape=None`` means
  "use ``df``'s shape".
  """
  df_shape = np.shape(df)
  if shape is not None:
    _check_shape("t", shape, df_shape)
  else:
    shape = df_shape
  df = lax.convert_element_type(df, dtype)
  normal_key, gamma_key = _split(key)
  z = normal(normal_key, shape, dtype)
  half_df = lax.div(df, _lax_const(z, 2))
  g = gamma(gamma_key, half_df, shape, dtype)
  return z * jnp.sqrt(half_df / g)
def chisquare(key: ArrayLike,
              df: RealArray,
              shape: Shape | None = None,
              dtype: DTypeLikeFloat = float) -> Array:
  r"""Draw chi-square distributed samples of a given shape and float dtype.

  Samples follow the probability density function:

  .. math::
     f(x; \nu) \propto x^{\nu/2 - 1}e^{-x/2}

  supported on :math:`0 < x < \infty`. The degrees of freedom
  :math:`\nu > 0` are supplied through ``df``.

  Args:
    key: a PRNG key used as the random key.
    df: a float or array of floats broadcast-compatible with ``shape``
      giving the degrees-of-freedom parameter.
    shape: optional, a tuple of nonnegative integers for the result shape.
      It must be broadcast-compatible with ``df``. Leaving it as None yields
      an output shaped like ``df``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    An array of samples in ``dtype`` whose shape is ``shape`` when given and
    ``df.shape`` otherwise.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key("chisquare", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError("dtype argument to `chisquare` must be a float "
                     f"dtype, got {dtype}")
  float_dtype = dtypes.canonicalize_dtype(dtype)
  out_shape = None if shape is None else core.canonicalize_shape(shape)
  return _chisquare(key, df, out_shape, float_dtype)
@partial(jit, static_argnums=(2, 3))
def _chisquare(key, df, shape, dtype) -> Array:
  """Jitted chi-square sampler computing ``2 * Gamma(df / 2)``.

  The gamma variate is drawn in log space with ``loggamma`` and then
  exponentiated. A ``shape`` of None means the output takes the shape of
  ``df``.
  """
  if shape is not None:
    _check_shape("chisquare", shape, np.shape(df))
  else:
    shape = np.shape(df)
  df = lax.convert_element_type(df, dtype)
  two = _lax_const(df, 2)
  log_gamma_draws = loggamma(key, a=lax.div(df, two), shape=shape, dtype=dtype)
  return lax.mul(jnp.exp(log_gamma_draws), two)
def f(key: ArrayLike,
      dfnum: RealArray,
      dfden: RealArray,
      shape: Shape | None = None,
      dtype: DTypeLikeFloat = float) -> Array:
  r"""Sample F-distribution random values with given shape and float dtype.

  The values are distributed according to the probability density function:

  .. math::
     f(x; \nu_1, \nu_2) \propto x^{\nu_1/2 - 1}\left(1 + \frac{\nu_1}{\nu_2}x\right)^{
      -(\nu_1 + \nu_2) / 2}

  on the domain :math:`0 < x < \infty`. Here :math:`\nu_1` is the degrees of
  freedom of the numerator (``dfnum``), and :math:`\nu_2` is the degrees of
  freedom of the denominator (``dfden``).

  Args:
    key: a PRNG key used as the random key.
    dfnum: a float or array of floats broadcast-compatible with ``shape``
      representing the numerator's ``df`` of the distribution.
    dfden: a float or array of floats broadcast-compatible with ``shape``
      representing the denominator's ``df`` of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``dfnum`` and ``dfden``.
      The default (None) produces a result shape equal to the broadcast of
      ``dfnum.shape`` and ``dfden.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape`` if
    ``shape`` is not None, or else by the broadcast of ``dfnum.shape`` and
    ``dfden.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key("f", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError("dtype argument to `f` must be a float "
                     f"dtype, got {dtype}")
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _f(key, dfnum, dfden, shape, dtype)
@partial(jit, static_argnums=(3, 4))
def _f(key, dfnum, dfden, shape, dtype) -> Array:
  """Jitted F sampler: ``(X1 / dfnum) / (X2 / dfden)`` with chi-square X1, X2.

  A ``shape`` of None means the output takes the broadcast shape of ``dfden``
  and ``dfnum``.
  """
  if shape is not None:
    _check_shape("f", shape, np.shape(dfden), np.shape(dfnum))
  else:
    shape = lax.broadcast_shapes(np.shape(dfden), np.shape(dfnum))
  dfden = lax.convert_element_type(dfden, dtype)
  dfnum = lax.convert_element_type(dfnum, dtype)
  den_key, num_key = _split(key)
  num_chi2 = chisquare(num_key, dfnum, shape, dtype)
  den_chi2 = chisquare(den_key, dfden, shape, dtype)
  # lax.div does not broadcast implicitly, so expand both dfs to `shape`.
  num_ratio = lax.div(num_chi2, jnp.broadcast_to(dfnum, shape))
  den_ratio = lax.div(den_chi2, jnp.broadcast_to(dfden, shape))
  return lax.div(num_ratio, den_ratio)
def rademacher(key: ArrayLike,
               shape: Shape = (),
               dtype: DTypeLikeInt = int) -> Array:
  r"""Sample from a Rademacher distribution.

  The values are distributed according to the probability mass function:

  .. math::
     f(k) = \frac{1}{2}(\delta(k - 1) + \delta(k + 1))

  on the domain :math:`k \in \{-1, 1\}`, where :math:`\delta(x)` is the
  Kronecker delta (1 at :math:`x = 0`, 0 elsewhere).

  Args:
    key: a PRNG key.
    shape: The shape of the returned samples. Default ().
    dtype: The type used for samples.

  Returns:
    A jnp.array of samples, of shape `shape`. Each element in the output has
    a 50% chance of being 1 or -1.
  """
  key, _ = _check_prng_key("rademacher", key)
  dtypes.check_user_dtype_supported(dtype)
  dtype = dtypes.canonicalize_dtype(dtype)
  shape = core.canonicalize_shape(shape)
  return _rademacher(key, shape, dtype)
@partial(jit, static_argnums=(1, 2))
def _rademacher(key, shape, dtype) -> Array:
  """Jitted sampler mapping fair Bernoulli draws {0, 1} onto signs {-1, +1}."""
  coin_flips = bernoulli(key=key, p=0.5, shape=shape).astype(dtype)
  signs = 2 * coin_flips - 1
  return signs.astype(dtype)
def maxwell(key: ArrayLike,
            shape: Shape = (),
            dtype: DTypeLikeFloat = float) -> Array:
  r"""Draw samples from the one-sided Maxwell distribution.

  Samples follow the probability density function:

  .. math::
     f(x) \propto x^2 e^{-x^2 / 2}

  supported on :math:`0 \le x < \infty`.

  Args:
    key: a PRNG key.
    shape: The shape of the returned samples.
    dtype: The type used for samples.

  Returns:
    A jnp.array of samples, of shape `shape`.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  # Each sample is the Euclidean length of three iid standard normals:
  # sqrt(X^2 + Y^2 + Z^2), X, Y, Z ~ N(0, 1).
  key, _ = _check_prng_key("maxwell", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `maxwell` must be a float "
                     f"dtype, got {dtype}")
  float_dtype = dtypes.canonicalize_dtype(dtype)
  batch_shape = core.canonicalize_shape(shape)
  return _maxwell(key, batch_shape, float_dtype)
@partial(jit, static_argnums=(1, 2))
def _maxwell(key, shape, dtype) -> Array:
  """Jitted Maxwell sampler: norm over a trailing axis of 3 standard normals."""
  components = normal(key=key, shape=(*shape, 3), dtype=dtype)
  return jnp.linalg.norm(components, axis=-1)
def double_sided_maxwell(key: ArrayLike,
                         loc: RealArray,
                         scale: RealArray,
                         shape: Shape = (),
                         dtype: DTypeLikeFloat = float) -> Array:
  r"""Draw samples from the double-sided (symmetric) Maxwell distribution.

  Samples follow the probability density function:

  .. math::
     f(x;\mu,\sigma) \propto z^2 e^{-z^2 / 2}

  with :math:`z = (x - \mu) / \sigma`. The center :math:`\mu` comes from
  ``loc`` and the scale :math:`\sigma` from ``scale``.

  Args:
    key: a PRNG key.
    loc: The location parameter of the distribution.
    scale: The scale parameter of the distribution.
    shape: Batch shape prepended to the broadcast shape of ``loc`` and
      ``scale``.
    dtype: The type used for samples.

  Returns:
    A jnp.array of samples.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key("double_sided_maxwell", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `double_sided_maxwell` must be a float"
                     f" dtype, got {dtype}")
  float_dtype = dtypes.canonicalize_dtype(dtype)
  batch_shape = core.canonicalize_shape(shape)
  return _double_sided_maxwell(key, loc, scale, batch_shape, float_dtype)
@partial(jit, static_argnums=(3, 4))
def _double_sided_maxwell(key, loc, scale, shape, dtype) -> Array:
  """Jitted double-sided Maxwell sampler.

  The output shape is ``shape + broadcast_shapes(loc.shape, scale.shape)``.
  A one-sided Maxwell variate is multiplied by a random sign, then scaled by
  ``scale`` and shifted by ``loc``.
  """
  params_shapes = lax.broadcast_shapes(np.shape(loc), np.shape(scale))
  # ``shape`` is a batch prefix in front of the parameter shape. An empty
  # ``shape`` must not be replaced by ``params_shapes`` before this
  # concatenation, because that would duplicate the parameter dimensions
  # (e.g. loc.shape == (3,) would yield (3, 3) samples instead of (3,)).
  shape = shape + params_shapes
  maxwell_key, rademacher_key = _split(key)
  maxwell_rvs = maxwell(maxwell_key, shape=shape, dtype=dtype)
  # Generate random signs for the symmetric variates.
  random_sign = rademacher(rademacher_key, shape=shape, dtype=dtype)
  assert random_sign.shape == maxwell_rvs.shape
  return random_sign * maxwell_rvs * scale + loc
def weibull_min(key: ArrayLike,
                scale: RealArray,
                concentration: RealArray,
                shape: Shape = (),
                dtype: DTypeLikeFloat = float) -> Array:
  r"""Sample from a Weibull distribution.

  The values are distributed according to the probability density function:

  .. math::
     f(x;\sigma,c) \propto x^{c - 1} \exp(-(x / \sigma)^c)

  on the domain :math:`0 < x < \infty`, where :math:`c > 0` is the concentration
  parameter, and :math:`\sigma > 0` is the scale parameter.

  Args:
    key: a PRNG key.
    scale: The scale parameter of the distribution.
    concentration: The concentration parameter of the distribution.
    shape: The shape of the underlying uniform draws. The result has this
      shape broadcast against the shapes of ``scale`` and ``concentration``.
    dtype: The type used for samples.

  Returns:
    A jnp.array of samples.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key("weibull_min", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `weibull_min` must be a float "
                     f"dtype, got {dtype}")
  dtype = dtypes.canonicalize_dtype(dtype)
  shape = core.canonicalize_shape(shape)
  return _weibull_min(key, scale, concentration, shape, dtype)
@partial(jit, static_argnums=(3, 4))
def _weibull_min(key, scale, concentration, shape, dtype) -> Array:
  """Jitted Weibull sampler via the inverse CDF."""
  u = uniform(key=key, shape=shape, minval=0, maxval=1, dtype=dtype)
  # F^{-1}(u) = scale * (-log(1 - u)) ** (1 / concentration)
  neg_log_survival = -jnp.log1p(-u)
  return jnp.power(neg_log_survival, 1.0 / concentration) * scale
def orthogonal(
  key: ArrayLike,
  n: int,
  shape: Shape = (),
  dtype: DTypeLikeFloat = float,
  m: int | None = None,
) -> Array:
  r"""Sample uniformly from the orthogonal group O(n).

  If the dtype is complex, sample uniformly from the unitary group U(n).

  For unequal rows and columns, this samples a semi-orthogonal matrix instead.
  That is, if :math:`A` is the resulting matrix and :math:`A^*` is its conjugate
  transpose, then:

  - If :math:`n \leq m`, the rows are mutually orthonormal: :math:`A A^* = I_n`.
  - If :math:`m \leq n`, the columns are mutually orthonormal: :math:`A^* A = I_m`.

  Args:
    key: a PRNG key used as the random key.
    n: an integer indicating the number of rows.
    shape: optional, the batch dimensions of the result. Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    m: an integer indicating the number of columns. Defaults to `n`.

  Returns:
    A random array of shape `(*shape, n, m)` (i.e. `(*shape, n, n)` when `m` is
    not given) and specified dtype.

  References:
    .. [1] Mezzadri, Francesco. (2007). "How to generate random matrices from
           the classical compact groups". Notices of the American Mathematical
           Society, 54(5), 592-604. https://arxiv.org/abs/math-ph/0609050.
  """
  _m = n if m is None else m
  shape = core.canonicalize_shape(shape)
  key, _ = _check_prng_key("orthogonal", key)
  dtypes.check_user_dtype_supported(dtype)
  _check_shape("orthogonal", shape)
  n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
  _m = core.concrete_or_error(index, _m, "The error occurred in jax.random.orthogonal()")
  # QR-factor a tall Gaussian matrix; rescaling Q's columns by sign(diag(R))
  # removes the QR sign ambiguity so Q is Haar-distributed (see [1]).
  z = normal(key, (*shape, max(n, _m), min(n, _m)), dtype)
  q, r = jnp.linalg.qr(z)
  d = jnp.linalg.diagonal(r)
  x = q * jnp.expand_dims(jnp.sign(d), -2)
  # x has shape (..., max(n, m), min(n, m)); transpose for wide outputs.
  if n < _m:
    return x.mT
  return x
def generalized_normal(
  key: ArrayLike,
  p: float,
  shape: Shape = (),
  dtype: DTypeLikeFloat = float
) -> Array:
  r"""Draw samples from the generalized normal distribution.

  Samples follow the probability density function:

  .. math::
     f(x;p) \propto e^{-|x|^p}

  supported on :math:`-\infty < x < \infty`, where the shape parameter
  satisfies :math:`p > 0`.

  Args:
    key: a PRNG key used as the random key.
    p: a float representing the shape parameter.
    shape: optional, the batch dimensions of the result. Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified shape and dtype.
  """
  shape = core.canonicalize_shape(shape)
  key, _ = _check_prng_key("generalized_normal", key)
  dtypes.check_user_dtype_supported(dtype)
  _check_shape("generalized_normal", shape)
  gamma_key, sign_key = split(key)
  inv_p = 1 / p
  # |X| ** p ~ Gamma(1 / p), so a Gamma draw raised to 1/p gives |X|; a
  # random sign makes the result symmetric.
  magnitude = gamma(gamma_key, inv_p, shape, dtype) ** inv_p
  sign = rademacher(sign_key, shape, dtype)
  return sign * magnitude
def ball(
  key: ArrayLike,
  d: int,
  p: float = 2,
  shape: Shape = (),
  dtype: DTypeLikeFloat = float
):
  """Draw points uniformly from the unit Lp ball.

  Reference: https://arxiv.org/abs/math/0503650.

  Args:
    key: a PRNG key used as the random key.
    d: a nonnegative int representing the dimensionality of the ball.
    p: a float representing the p parameter of the Lp norm.
    shape: optional, the batch dimensions of the result. Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array of shape `(*shape, d)` and specified dtype.
  """
  shape = core.canonicalize_shape(shape)
  key, _ = _check_prng_key("ball", key)
  dtypes.check_user_dtype_supported(dtype)
  _check_shape("ball", shape)
  d = core.concrete_or_error(index, d, "The error occurred in jax.random.ball()")
  normal_key, exp_key = split(key)
  g = generalized_normal(normal_key, p, (*shape, d), dtype)
  e = exponential(exp_key, shape, dtype)
  # Per the reference: g / (sum |g_i|^p + E)^(1/p) is uniform in the ball.
  denom = ((jnp.abs(g) ** p).sum(-1) + e) ** (1 / p)
  return g / jnp.expand_dims(denom, -1)
def rayleigh(key: ArrayLike,
             scale: RealArray,
             shape: Shape | None = None,
             dtype: DTypeLikeFloat = float) -> Array:
  r"""Sample Rayleigh random values with given shape and float dtype.

  The values are returned according to the probability density function:

  .. math::
     f(x;\sigma) \propto xe^{-x^2/(2\sigma^2)}

  on the domain :math:`0 \le x < \infty`, and where :math:`\sigma > 0` is the scale
  parameter of the distribution.

  Args:
    key: a PRNG key used as the random key.
    scale: a float or array of floats broadcast-compatible with ``shape``
      representing the parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``scale``. The default (None)
      produces a result shape equal to ``scale.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape`` if
    ``shape`` is not None, or else by ``scale.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key("rayleigh", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError("dtype argument to `rayleigh` must be a float "
                     f"dtype, got {dtype}")
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _rayleigh(key, scale, shape, dtype)
@partial(jit, static_argnums=(2, 3))
def _rayleigh(key, scale, shape, dtype) -> Array:
  """Jitted Rayleigh sampler: ``scale * sqrt(-2 * log(1 - U))``, U uniform.

  A ``shape`` of None means the output takes the shape of ``scale``.
  """
  if shape is None:
    shape = np.shape(scale)
  else:
    _check_shape("rayleigh", shape, np.shape(scale))
  u = uniform(key, shape, dtype)
  # convert_element_type (as in the other samplers) also accepts plain Python
  # scalars, unlike the ``.astype`` method.
  scale = lax.convert_element_type(scale, dtype)
  scale = jnp.broadcast_to(scale, shape)
  # Inverse CDF. uniform() can return exactly 0, where log(u) = -inf would give
  # an infinite sample. 1 - u has the same distribution on (0, 1], so
  # log1p(-u) is always finite.
  log_u = lax.log1p(lax.neg(u))
  n_two = _lax_const(scale, -2)
  sqrt_u = lax.sqrt(lax.mul(log_u, n_two))
  return lax.mul(scale, sqrt_u)
def wald(key: ArrayLike,
         mean: RealArray,
         shape: Shape | None = None,
         dtype: DTypeLikeFloat = float) -> Array:
  r"""Sample Wald random values with given shape and float dtype.

  The values are returned according to the probability density function:

  .. math::
     f(x;\mu) = \frac{1}{\sqrt{2\pi x^3}} \exp\left(-\frac{(x - \mu)^2}{2\mu^2 x}\right)

  on the domain :math:`0 < x < \infty`, and where :math:`\mu > 0` is the mean
  parameter of the distribution.

  Args:
    key: a PRNG key used as the random key.
    mean: a float or array of floats broadcast-compatible with ``shape``
      representing the mean parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``mean``. The default
      (None) produces a result shape equal to ``np.shape(mean)``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape`` if
    ``shape`` is not None, or else by ``mean.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key("wald", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError("dtype argument to `wald` must be a float "
                     f"dtype, got {dtype}")
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _wald(key, mean, shape, dtype)
@partial(jit, static_argnums=(2, 3))
def _wald(key, mean, shape, dtype) -> Array:
  """Jitted Wald (inverse Gaussian, unit shape) sampler.

  Uses a normal draw ``v`` and a uniform draw ``z``: a candidate root ``x`` is
  computed from ``y = v**2``, and either ``x`` or ``mean**2 / x`` is returned
  depending on whether ``z <= mean / (mean + x)``.
  """
  if shape is not None:
    _check_shape("wald", shape, np.shape(mean))
  else:
    shape = np.shape(mean)
  normal_key, uniform_key = _split(key, 2)
  mu = jnp.broadcast_to(mean.astype(dtype), shape)
  v = normal(normal_key, shape, dtype)
  z = uniform(uniform_key, shape, dtype)
  y = lax.integer_pow(v, 2)
  mu_sq = lax.integer_pow(mu, 2)
  root = lax.sqrt(4 * mu * y + mu_sq * lax.integer_pow(y, 2))
  x = mu + mu_sq * y / 2 - mu / 2 * root
  return lax.select(lax.le(z, mu / (mu + x)), x, mu_sq / x)
def geometric(key: ArrayLike,
              p: RealArray,
              shape: Shape | None = None,
              dtype: DTypeLikeInt = int) -> Array:
  r"""Sample Geometric random values with given shape and integer dtype.

  The values are returned according to the probability mass function:

  .. math::
      f(k;p) = p(1-p)^{k-1}

  on the domain :math:`k \in \{1, 2, \ldots\}`, where :math:`0 < p \le 1` is
  the probability of success of each trial.

  Args:
    key: a PRNG key used as the random key.
    p: a float or array of floats broadcast-compatible with ``shape``
      representing the probability of success of an individual trial.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``p``. The default
      (None) produces a result shape equal to ``np.shape(p)``.
    dtype: optional, a int dtype for the returned values (default int64 if
      jax_enable_x64 is true, otherwise int32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape`` if
    ``shape`` is not None, or else by ``p.shape``.

  Raises:
    ValueError: if ``dtype`` is not an integer dtype.
  """
  key, _ = _check_prng_key("geometric", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.integer):
    raise ValueError("dtype argument to `geometric` must be an int "
                     f"dtype, got {dtype}")
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _geometric(key, p, shape, dtype)
@partial(jit, static_argnums=(2, 3))
def _geometric(key, p, shape, dtype) -> Array:
  """Jitted geometric sampler: ``floor(log(1 - U) / log(1 - p)) + 1``.

  ``U`` is uniform in the inexact dtype that ``p`` promotes to. The result is
  cast to the integer ``dtype``. A ``shape`` of None means the output takes
  the shape of ``p``.
  """
  if shape is None:
    shape = np.shape(p)
  else:
    _check_shape("geometric", shape, np.shape(p))
  check_arraylike("geometric", p)
  p, = promote_dtypes_inexact(p)
  u = uniform(key, shape, p.dtype)
  # uniform() can return exactly 0, where log(u) = -inf would give an infinite
  # value before the integer cast. 1 - u has the same distribution on (0, 1],
  # so log1p(-u) is always finite.
  log_u = lax.log1p(lax.neg(u))
  log_one_minus_p = lax.log1p(-p)
  log_one_minus_p = jnp.broadcast_to(log_one_minus_p, shape)
  g = lax.floor(lax.div(log_u, log_one_minus_p)) + 1
  return g.astype(dtype)
def triangular(key: ArrayLike,
               left: RealArray,
               mode: RealArray,
               right: RealArray,
               shape: Shape | None = None,
               dtype: DTypeLikeFloat = float) -> Array:
  r"""Draw triangular-distributed samples of a given shape and float dtype.

  Samples follow the probability density function:

  .. math::
     f(x; a, b, c) = \frac{2}{c-a} \left\{ \begin{array}{ll} \frac{x-a}{b-a} & a \leq x \leq b \\ \frac{c-x}{c-b} & b \leq x \leq c \end{array} \right.

  supported on :math:`a \leq x \leq c`.

  Args:
    key: a PRNG key used as the random key.
    left: a float or array of floats broadcast-compatible with ``shape``
      giving the lower limit :math:`a`.
    mode: a float or array of floats broadcast-compatible with ``shape``
      giving the peak :math:`b`; it should satisfy ``left <= mode <= right``.
    right: a float or array of floats broadcast-compatible with ``shape``
      giving the upper limit :math:`c`, which should exceed ``left``.
    shape: optional, a tuple of nonnegative integers for the result shape. It
      must be broadcast-compatible with ``left``, ``mode`` and ``right``.
      Leaving it as None yields the broadcast shape of the three parameters.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    An array of samples whose shape is ``shape`` when given, or else the
    broadcast shape of ``left``, ``mode`` and ``right``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key("triangular", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError("dtype argument to `triangular` must be a float "
                     f"dtype, got {dtype}")
  float_dtype = dtypes.canonicalize_dtype(dtype)
  out_shape = None if shape is None else core.canonicalize_shape(shape)
  return _triangular(key, left, mode, right, out_shape, float_dtype)
@partial(jit, static_argnums=(4, 5), inline=True)
def _triangular(key, left, mode, right, shape, dtype) -> Array:
  """Jitted triangular sampler using the inverse CDF.

  A ``shape`` of None means the output takes the broadcast shape of ``left``,
  ``mode`` and ``right``. The output has dtype ``dtype``.
  """
  # https://en.wikipedia.org/wiki/Triangular_distribution#Generating_triangular-distributed_random_variates
  if shape is None:
    shape = lax.broadcast_shapes(np.shape(left), np.shape(mode), np.shape(right))
  else:
    _check_shape("triangular", shape, np.shape(left), np.shape(mode), np.shape(right))
  # Cast the parameters to ``dtype`` (as the other samplers do). Otherwise
  # strongly-typed parameters promote the result away from the requested
  # dtype, e.g. float32 params with dtype=float16 would return float32.
  left = jnp.broadcast_to(lax.convert_element_type(left, dtype), shape)
  mode = jnp.broadcast_to(lax.convert_element_type(mode, dtype), shape)
  right = jnp.broadcast_to(lax.convert_element_type(right, dtype), shape)
  # CDF value at the mode: u below it maps onto the rising edge.
  fc = (mode - left) / (right - left)
  u = uniform(key, shape, dtype)
  out1 = left + lax.sqrt(u * (right - left) * (mode - left))
  out2 = right - lax.sqrt((1 - u) * (right - left) * (right - mode))
  return lax.select(u < fc, out1, out2)
def lognormal(key: ArrayLike,
              sigma: RealArray = np.float32(1),
              shape: Shape | None = None,
              dtype: DTypeLikeFloat = float) -> Array:
  r""" Sample lognormal random values with given shape and float dtype.

  The values are distributed according to the probability density function:

  .. math::
     f(x) = \frac{1}{x\sqrt{2\pi\sigma^2}}\exp\left(-\frac{(\log x)^2}{2\sigma^2}\right)

  on the domain :math:`x > 0`.

  Args:
    key: a PRNG key used as the random key.
    sigma: a float or array of floats broadcast-compatible with ``shape`` representing
      the standard deviation of the underlying normal distribution. Default 1.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``sigma``. The default (None)
      produces a result shape equal to ``np.shape(sigma)`` (``()`` for the
      default scalar ``sigma``).
    dtype: optional, a float or complex dtype for the returned values (default
      float64 if jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape``
    if ``shape`` is not None, or else by ``np.shape(sigma)``.

  Raises:
    ValueError: if ``dtype`` is neither a floating-point nor a complex dtype.
  """
  key, _ = _check_prng_key("lognormal", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.inexact):
    raise ValueError(f"dtype argument to `lognormal` must be a float or complex dtype, "
                     f"got {dtype}")
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _lognormal(key, sigma, shape, dtype)
@partial(jit, static_argnums=(2, 3), inline=True)
def _lognormal(key, sigma, shape, dtype) -> Array:
  """Jitted lognormal sampler: ``exp(sigma * Z)`` with ``Z ~ N(0, 1)``.

  A ``shape`` of None means the output takes the shape of ``sigma``.
  """
  if shape is None:
    shape = np.shape(sigma)
  else:
    # Report shape errors under this sampler's own name.
    _check_shape("lognormal", shape, np.shape(sigma))
  # Cast sigma to ``dtype`` so a strongly-typed sigma (e.g. the np.float32
  # default) cannot promote the result away from the requested dtype, such as
  # float16 or bfloat16.
  sigma = lax.convert_element_type(sigma, dtype)
  sigma = jnp.broadcast_to(sigma, shape)
  scaled_norm = normal(key, shape, dtype) * sigma
  return lax.exp(scaled_norm)
def _stirling_approx_tail(k):
stirling_tail_vals = jnp.array(
[
0.0810614667953272,
0.0413406959554092,
0.0276779256849983,
0.02079067210376509,
0.0166446911898211,
0.0138761288230707,
0.0118967099458917,
0.0104112652619720,
0.00925546218271273,
0.00833056343336287,
],
dtype=k.dtype,
)
use_tail_values = k <= 9
k = lax.clamp(_lax_const(k, 0.0), k, _lax_const(k, 9.0))
kp1sq = (k + 1) * (k + 1)
approx = (1.0 / 12 - (1.0 / 360 - 1.0 / 1260 / kp1sq) / kp1sq) / (k + 1)
k = jnp.floor(k)
return lax.select(use_tail_values, stirling_tail_vals[jnp.int32(k)], approx)
@partial(jit, static_argnums=(3, 4, 5), inline=True)
def _binomial_inversion(key, count, prob, shape, dtype, max_iters):
  """Binomial sampling by summing geometric waiting times.

  Each round draws geometric variates ``ceil(log(u) / log1p(-prob))`` and adds
  them to a running sum per element. The result is the number of geometric
  draws whose cumulative sum stays ``<= count``, cast to ``dtype``. The loop
  stops once every element's sum exceeds ``count``, or after ``max_iters``
  rounds.

  Args:
    key: PRNG key; split once per round.
    count: number of trials (floating point).
    prob: success probability; must be a floating dtype (checked only when
      ``config.enable_checks`` is on).
    shape: static output shape.
    dtype: static output dtype.
    max_iters: static upper bound on the number of rounds.

  Returns:
    An array of shape ``shape`` with dtype ``dtype``.
  """
  if config.enable_checks.value:
    assert jnp.issubdtype(prob.dtype, jnp.floating)
  log1minusprob = jnp.log1p(-prob)
  def body_fn(carry):
    # carry = (round index, draws counted so far, running sum, key)
    i, num_geom, geom_sum, key = carry
    subkey, key = split(key)
    # Count this round only for elements whose sum has not yet exceeded count.
    # Because the first round always counts (sum starts at 0), the final
    # tally is one more than the number of sums that stayed <= count.
    num_geom_out = lax.select(geom_sum <= count, num_geom + 1, num_geom)
    u = uniform(subkey, shape, prob.dtype)
    geom = jnp.ceil(jnp.log(u) / log1minusprob)
    geom_sum = geom_sum + geom
    return i + 1, num_geom_out, geom_sum, key
  def cond_fn(carry):
    i, geom_sum = carry[0], carry[2]
    return (geom_sum <= count).any() & (i < max_iters)
  num_geom_init = lax.full_like(prob, 0, prob.dtype, shape)
  geom_sum_init = lax.full_like(prob, 0, prob.dtype, shape)
  carry = (0, num_geom_init, geom_sum_init, key)
  k = lax.while_loop(cond_fn, body_fn, carry)[1]
  # Subtract the always-counted first round (see body_fn).
  return (k - 1).astype(dtype)
@partial(jit, static_argnums=(3, 4, 5), inline=True)
def _btrs(key, count, prob, shape, dtype, max_iters):
  """Binomial sampling by transformed rejection (the "BTRS" method).

  Each round draws two uniforms per element and forms a candidate ``k``.
  A candidate is accepted either by a cheap squeeze test (``accept1``) or,
  if ``k`` lies in ``[0, count]``, by comparing ``log(v)`` against a log bound
  built from ``_stirling_approx_tail`` corrections (``accept2``). Whenever an
  element accepts, its output is overwritten with that round's ``k``. The loop
  runs until every element has accepted at least once or ``max_iters`` rounds
  have passed. Elements that never accept keep the initial value -1.

  Args:
    key: PRNG key; split once per round.
    count: number of trials (floating point). Presumably already broadcast to
      ``shape`` by the caller (TODO confirm for other call sites).
    prob: success probability, in the same float dtype as ``count``.
    shape: static output shape.
    dtype: static output dtype.
    max_iters: static upper bound on the number of rounds.

  Returns:
    An array of shape ``shape`` cast to ``dtype``.
  """
  # transformed-rejection algorithm
  # https://www.tandfonline.com/doi/abs/10.1080/00949659308811496
  stddev = jnp.sqrt(count * prob * (1 - prob))
  b = 1.15 + 2.53 * stddev
  a = -0.0873 + 0.0248 * b + 0.01 * prob
  c = count * prob + 0.5
  v_r = 0.92 - 4.2 / b
  r = prob / (1 - prob)
  alpha = (2.83 + 5.1 / b) * stddev
  # floor((count + 1) * prob): the mode of the binomial distribution.
  m = jnp.floor((count + 1) * prob)
  def body_fn(carry):
    # carry = (round index, current output, accepted-at-least-once mask, key)
    i, k_out, accepted, key = carry
    key, subkey_0, subkey_1 = split(key, 3)
    u = uniform(subkey_0, shape, prob.dtype)
    v = uniform(subkey_1, shape, prob.dtype)
    u = u - 0.5
    us = 0.5 - jnp.abs(u)
    # Squeeze: quick acceptance without evaluating the log bound.
    accept1 = (us >= 0.07) & (v <= v_r)
    k = jnp.floor((2 * a / us + b) * u + c)
    # Candidates outside [0, count] are excluded from the bound test below.
    reject = (k < 0) | (k > count)
    v = jnp.log(v * alpha / (a / (us * us) + b))
    # Log acceptance bound with Stirling-series tail corrections.
    ub = (
        (m + 0.5) * jnp.log((m + 1) / (r * (count - m + 1)))
        + (count + 1) * jnp.log((count - m + 1) / (count - k + 1))
        + (k + 0.5) * jnp.log(r * (count - k + 1) / (k + 1))
        + _stirling_approx_tail(m)
        + _stirling_approx_tail(count - m)
        - _stirling_approx_tail(k)
        - _stirling_approx_tail(count - k)
    )
    accept2 = v <= ub
    accept = accept1 | (~reject & accept2)
    # Overwrite the output wherever this round accepts (including elements
    # that had already accepted in an earlier round).
    k_out = lax.select(accept, k, k_out)
    accepted |= accept
    return i + 1, k_out, accepted, key
  def cond_fn(carry):
    i, accepted = carry[0], carry[2]
    return (~accepted).any() & (i < max_iters)
  # -1 marks elements that have not accepted any candidate yet.
  k_init = lax.full_like(prob, -1, prob.dtype, shape)
  carry = (0, k_init, jnp.full(shape, False, jnp.bool_), key)
  return lax.while_loop(cond_fn, body_fn, carry)[1].astype(dtype)
@partial(jit, static_argnums=(3, 4), inline=True)
def _binomial(key, count, prob, shape, dtype) -> Array:
  """Jitted binomial sampler behind :func:`binomial`.

  Samples with ``q = min(prob, 1 - prob)`` and flips the result back
  (``count - sample``) when ``prob >= 0.5``. Elements with ``count * q <= 10``
  (or NaN/negative ``count``) take the inversion sampler's value; the rest
  take the BTRS value. Special cases: NaN or negative ``count`` and NaN or
  out-of-range ``prob`` give NaN; ``+inf`` ``count`` gives ``inf``.
  A ``shape`` of None means the broadcast shape of ``count`` and ``prob``.
  """
  # The implementation matches TensorFlow and TensorFlow Probability:
  # https://github.com/tensorflow/tensorflow/blob/v2.2.0-rc3/tensorflow/core/kernels/random_binomial_op.cc
  # and tensorflow_probability.substrates.jax.distributions.Binomial
  # For count * min(p, 1 - p) <= 10 we use the binomial inversion algorithm;
  # otherwise btrs.
  if shape is None:
    shape = jnp.broadcast_shapes(jnp.shape(count), jnp.shape(prob))
  else:
    _check_shape("binomial", shape, np.shape(count), np.shape(prob))
  (prob,) = promote_dtypes_inexact(prob)
  count = lax.convert_element_type(count, prob.dtype)
  count = jnp.broadcast_to(count, shape)
  prob = jnp.broadcast_to(prob, shape)
  # Sample with the smaller of p and 1 - p; the result is flipped back at the end.
  p_lt_half = prob < 0.5
  q = lax.select(p_lt_half, prob, 1.0 - prob)
  count_nan_or_neg = _isnan(count) | (count < 0.0)
  count_inf = jnp.isinf(count)
  q_is_nan = _isnan(q)
  q_l_0 = q < 0.0
  # Placeholder q for invalid entries so the samplers run; those entries are
  # overwritten with NaN below.
  q = lax.select(q_is_nan | q_l_0, lax.full_like(q, 0.01), q)
  use_inversion = count_nan_or_neg | (count * q <= 10.0)
  # consistent with np.random.binomial behavior for float count input
  count = jnp.floor(count)
  # Both samplers are evaluated on every element (with the same key) and one
  # result is picked by select. Each sampler gets fixed placeholder inputs for
  # the elements the other sampler handles.
  count_inv = lax.select(use_inversion, count, lax.full_like(count, 0.0))
  count_btrs = lax.select(use_inversion, lax.full_like(count, 1e4), count)
  q_btrs = lax.select(use_inversion, lax.full_like(q, 0.5), q)
  # Largest finite value of the output float dtype: effectively no iteration cap.
  max_iters = dtype.type(jnp.finfo(dtype).max)
  samples = lax.select(
      use_inversion,
      _binomial_inversion(key, count_inv, q, shape, dtype, max_iters),
      _btrs(key, count_btrs, q_btrs, shape, dtype, max_iters),
  )
  # ensure nan q always leads to nan output and nan or neg count leads to nan
  # as discussed in https://github.com/jax-ml/jax/pull/16134#pullrequestreview-1446642709
  invalid = (q_l_0 | q_is_nan | count_nan_or_neg)
  samples = lax.select(
      invalid,
      jnp.full_like(samples, jnp.nan, dtype),
      samples,
  )
  # +inf count leads to inf
  samples = lax.select(
      count_inf & (~invalid),
      jnp.full_like(samples, jnp.inf, dtype),
      samples,
  )
  # Undo the p >= 0.5 reflection, except where the special values set above apply.
  samples = lax.select(
      p_lt_half | count_nan_or_neg | q_is_nan | count_inf,
      samples,
      count.astype(dtype) - samples,
  )
  return samples
def binomial(
    key: Array,
    n: RealArray,
    p: RealArray,
    shape: Shape | None = None,
    dtype: DTypeLikeFloat = float,
) -> Array:
  r"""Sample Binomial random values with given shape and float dtype.

  The values are returned according to the probability mass function:

  .. math::
      f(k;n,p) = \binom{n}{k}p^k(1-p)^{n-k}

  on the domain :math:`0 < p < 1`, and where :math:`n` is a nonnegative integer
  representing the number of trials and :math:`p` is a float representing the
  probability of success of an individual trial.

  Args:
    key: a PRNG key used as the random key.
    n: a float or array of floats broadcast-compatible with ``shape``
      representing the number of trials.
    p: a float or array of floats broadcast-compatible with ``shape``
      representing the probability of success of an individual trial.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``n`` and ``p``.
      The default (None) produces a result shape equal to ``np.broadcast(n, p).shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape``
    if ``shape`` is not None, or else by ``np.broadcast(n, p).shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key("binomial", key)
  check_arraylike("binomial", n, p)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(
        f"dtype argument to `binomial` must be a float dtype, got {dtype}"
    )
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _binomial(key, n, p, shape, dtype)
# Functions related to key reuse checking
random_clone_p = core.Primitive("random_clone")


def _random_clone_abstract_eval(aval):
  # The clone has exactly the same abstract value as its input.
  return aval


def _random_clone_lowering(ctx, key):
  # Lowered as the identity: the input IR value is returned unchanged.
  del ctx
  return [key]


dispatch.simple_impl(random_clone_p)
random_clone_p.def_abstract_eval(_random_clone_abstract_eval)
batching.defvectorized(random_clone_p)
mlir.register_lowering(random_clone_p, _random_clone_lowering)
def multinomial(
    key: Array,
    n: RealArray,
    p: RealArray,
    *,
    shape: Shape | None = None,
    dtype: DTypeLikeFloat = float,
    unroll: int | bool = 1,
):
  r"""Sample from a multinomial distribution.

  The probability mass function is

  .. math::
      f(x;n,p) = \frac{n!}{x_1! \ldots x_k!} p_1^{x_1} \ldots p_k^{x_k}

  Args:
    key: PRNG key.
    n: number of trials. Must be broadcastable to ``shape[:-1]`` (that is,
      ``p.shape[:-1]`` when ``shape`` is None).
    p: probability of each outcome, with outcomes along the last axis. Must be
      broadcastable to ``shape``.
    shape: optional, a tuple of nonnegative integers specifying the full result
      shape, *including* the trailing outcomes axis. ``n`` is broadcast to
      ``shape[:-1]`` and ``p`` to ``shape``. The default (None) uses
      ``p.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    unroll: optional, unroll parameter passed to :func:`jax.lax.scan` inside the
      implementation of this function.

  Returns:
    An array of counts for each outcome with the specified dtype and with shape
    ``p.shape`` if ``shape`` is None, otherwise ``shape``.
  """
  key, _ = _check_prng_key("multinomial", key)
  check_arraylike("multinomial", n, p)
  n, p = promote_dtypes_inexact(n, p)

  if shape is None:
    shape = p.shape
  # ``shape`` is the full output shape; its last axis indexes the outcomes.
  n = jnp.broadcast_to(n, shape[:-1])
  p = jnp.broadcast_to(p, shape)

  def draw_outcome(remainder, ratio_key):
    # Allocate a binomial share of the trials still unassigned to this outcome.
    ratio, subkey = ratio_key
    count = binomial(subkey, remainder, ratio.clip(0, 1), dtype=remainder.dtype)
    return remainder - count, count

  # Scan over outcomes, so move the outcomes axis to the front.
  p = jnp.moveaxis(p, -1, 0)

  # Probability of outcome i conditional on the trial landing in outcomes i..k.
  remaining_probs = lax.cumsum(p, 0, reverse=True)
  ratios = p / jnp.where(remaining_probs == 0, 1, remaining_probs)

  keys = split(key, ratios.shape[0])
  # The final carry (trials left unassigned) is expected to be zero; discard it.
  _, counts = lax.scan(draw_outcome, n, (ratios, keys), unroll=unroll)

  return jnp.moveaxis(counts, 0, -1).astype(dtype)
def clone(key):
  """Return a copy of ``key`` that may be consumed again.

  When key reuse checking (see :mod:`jax.experimental.key_reuse`) is not
  active, this is simply the identity.

  Examples:
    >>> import jax
    >>> key = jax.random.key(0)
    >>> data = jax.random.uniform(key)
    >>> cloned_key = jax.random.clone(key)
    >>> same_data = jax.random.uniform(cloned_key)
    >>> assert data == same_data
  """
  cloned = random_clone_p.bind(key)
  return cloned