2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2018 The JAX Authors.
|
2020-11-19 09:22:31 -08:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
from functools import partial
|
2023-02-28 12:40:30 -08:00
|
|
|
import math
|
2022-04-29 14:20:50 -04:00
|
|
|
from operator import index
|
2023-02-28 12:40:30 -08:00
|
|
|
from typing import Optional, Sequence, Union
|
2020-11-19 09:22:31 -08:00
|
|
|
import warnings
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
2022-12-05 12:43:05 -08:00
|
|
|
import jax
|
2022-12-16 20:59:41 -08:00
|
|
|
import jax.numpy as jnp
|
2020-11-19 09:22:31 -08:00
|
|
|
from jax import lax
|
2022-12-16 20:59:41 -08:00
|
|
|
from jax.config import config
|
|
|
|
from jax.interpreters import mlir
|
|
|
|
from jax.numpy.linalg import cholesky, svd, eigh
|
2023-02-06 22:51:50 -08:00
|
|
|
|
2022-12-16 20:59:41 -08:00
|
|
|
from jax._src import core
|
2021-04-07 19:35:17 -07:00
|
|
|
from jax._src import dtypes
|
2021-06-08 11:16:33 -07:00
|
|
|
from jax._src import prng
|
2023-02-28 07:01:14 -08:00
|
|
|
from jax._src import xla_bridge
|
2021-04-13 09:42:54 -07:00
|
|
|
from jax._src.api import jit, vmap
|
2022-12-16 20:59:41 -08:00
|
|
|
from jax._src.core import NamedShape
|
2023-02-06 22:51:50 -08:00
|
|
|
from jax._src.interpreters import ad
|
2023-02-09 15:11:20 -08:00
|
|
|
from jax._src.interpreters import batching
|
2022-03-07 12:25:01 -08:00
|
|
|
from jax._src.lax import lax as lax_internal
|
2023-03-08 10:29:04 -08:00
|
|
|
from jax._src.numpy.lax_numpy import _convert_and_clip_integer
|
2023-03-13 12:18:36 -07:00
|
|
|
from jax._src.numpy.util import _arraylike, check_arraylike, promote_dtypes_inexact
|
2022-09-30 19:37:48 -07:00
|
|
|
from jax._src.typing import Array, ArrayLike, DTypeLike
|
2023-02-28 12:40:30 -08:00
|
|
|
from jax._src.util import canonicalize_axis
|
2020-11-19 09:22:31 -08:00
|
|
|
|
|
|
|
|
2022-09-30 19:37:48 -07:00
|
|
|
# Descriptive aliases for sampler signatures. They are all plain ``ArrayLike``
# / ``DTypeLike`` for now and exist to document the intended kind of input.
RealArray = IntegerArray = ArrayLike
# TODO: Import or define these to match
# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
DTypeLikeInt = DTypeLikeFloat = DTypeLike
Shape = Sequence[int]

# A key is either a raw array or a ``prng.PRNGKeyArray``.
# TODO(frostig): simplify once we always enable_custom_prng
KeyArray = Union[Array, prng.PRNGKeyArray]

# Indexed by bit width to obtain the unsigned integer dtype of that width;
# re-exported from the prng module.
UINT_DTYPES = prng.UINT_DTYPES


### utilities

# Shorthand for building a constant with the dtype of a reference array.
_lax_const = lax_internal._const
|
|
|
|
|
2022-09-30 19:37:48 -07:00
|
|
|
def _isnan(x: ArrayLike) -> Array:
  """Elementwise NaN test.

  NaN is the only value that compares unequal to itself, so ``x != x`` is
  True exactly at the NaN positions.
  """
  not_equal_to_self = lax.ne(x, x)
  return not_equal_to_self
|
|
|
|
|
2022-08-22 13:56:50 -07:00
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def _check_prng_key(key):
  """Normalize ``key`` to a ``prng.PRNGKeyArray``.

  Returns:
    A pair ``(key_array, was_wrapped)``. ``was_wrapped`` is True when ``key``
    was a raw array that had to be wrapped with the default PRNG
    implementation; callers pass it to ``_return_prng_keys`` so results can be
    handed back in the same form.

  Raises:
    TypeError: if ``key`` is neither a key array nor array-like.
  """
  # TODO(frostig): remove once we always enable_custom_prng
  if isinstance(key, prng.PRNGKeyArray):
    return key, False
  if not _arraylike(key):
    raise TypeError(f'unexpected PRNG key type {type(key)}')
  if config.jax_enable_custom_prng:
    # In custom-PRNG mode raw arrays are still accepted, but with a warning.
    warnings.warn(
        'Raw arrays as random keys to jax.random functions are deprecated. '
        'Assuming valid threefry2x32 key for now.',
        FutureWarning)
  wrapped = prng.random_wrap(key, impl=default_prng_impl())
  return wrapped, True
|
|
|
|
|
2022-08-22 13:56:50 -07:00
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def _return_prng_keys(was_wrapped, key):
  """Return ``key`` in the form the caller originally supplied.

  With ``jax_enable_custom_prng`` on, the key array is returned as-is.
  Otherwise it is unwrapped to its raw base array exactly when
  ``was_wrapped`` is set.
  """
  # TODO(frostig): remove once we always enable_custom_prng
  assert isinstance(key, prng.PRNGKeyArray)
  should_unwrap = (not config.jax_enable_custom_prng) and was_wrapped
  return prng.random_unwrap(key) if should_unwrap else key
|
|
|
|
|
2021-06-08 11:16:33 -07:00
|
|
|
|
2022-09-30 19:37:48 -07:00
|
|
|
def _random_bits(key: prng.PRNGKeyArray, bit_width, shape) -> Array:
  """Draw ``bit_width``-wide random bits with the given ``shape`` from ``key``."""
  assert isinstance(key, prng.PRNGKeyArray)
  drawn = prng.random_bits(key, bit_width=bit_width, shape=shape)
  return drawn
|
2021-06-08 11:16:33 -07:00
|
|
|
|
|
|
|
|
2021-10-07 19:15:43 -07:00
|
|
|
# Registry of PRNG implementations by name; ``default_prng_impl`` looks up
# ``config.jax_default_prng_impl`` in this table.
PRNG_IMPLS = dict(
    threefry2x32=prng.threefry_prng_impl,
    rbg=prng.rbg_prng_impl,
    unsafe_rbg=prng.unsafe_rbg_prng_impl,
)
|
|
|
|
|
2022-01-12 19:13:14 -08:00
|
|
|
def default_prng_impl():
  """Return the PRNG implementation selected by configuration.

  ``config.jax_default_prng_impl`` names the implementation. This function
  looks that name up in ``PRNG_IMPLS`` and returns the corresponding
  ``jax.prng.PRNGImpl`` instance.
  """
  name = config.jax_default_prng_impl
  assert name in PRNG_IMPLS, name
  impl = PRNG_IMPLS[name]
  return impl
|
|
|
|
|
|
|
|
|
2021-06-08 11:16:33 -07:00
|
|
|
### key operations
|
|
|
|
|
|
|
|
|
2023-03-14 08:32:21 -07:00
|
|
|
def PRNGKey(seed: Union[int, Array]) -> KeyArray:
  """Create a pseudo-random number generator (PRNG) key given an integer seed.

  The resulting key carries the default PRNG implementation, as
  determined by the ``jax_default_prng_impl`` config flag.

  Args:
    seed: a 64- or 32-bit integer used as the value of the key.

  Returns:
    A PRNG key, consumable by random functions as well as ``split``
    and ``fold_in``.

  Raises:
    TypeError: if ``seed`` is not a scalar.
  """
  impl = default_prng_impl()
  if np.ndim(seed):
    # The two literals are concatenated, so the first must end with a space.
    raise TypeError("PRNGKey accepts a scalar seed, but was given an array of "
                    f"shape {np.shape(seed)} != (). Use jax.vmap for batching.")
  key = prng.seed_with_impl(impl, seed)
  # The key was freshly wrapped here, so it is unwrapped to a raw array
  # whenever custom PRNG mode is off.
  return _return_prng_keys(True, key)
|
|
|
|
|
|
|
|
# TODO(frostig): remove once we always enable_custom_prng
def _check_default_impl_with_no_custom_prng(impl, name):
  """Ensure ``impl`` may be used under the current custom-PRNG setting.

  With ``jax_enable_custom_prng`` off, keys are handed back as raw arrays.
  Raw arrays passed back in are re-wrapped with the *default* implementation
  (see ``_check_prng_key``). Seeding with any other implementation is
  therefore only allowed when custom PRNGs are enabled.

  Args:
    impl: the PRNG implementation requested by the caller.
    name: its user-facing name, used in the error message.

  Raises:
    RuntimeError: if custom PRNGs are disabled and ``impl`` is not the default.
  """
  default_impl = default_prng_impl()
  default_name = config.jax_default_prng_impl
  if not config.jax_enable_custom_prng and default_impl is not impl:
    raise RuntimeError('jax_enable_custom_prng must be enabled in order '
                       f'to seed an RNG with an implementation "{name}" '
                       f'differing from the default "{default_name}".')
|
|
|
|
|
|
|
|
def threefry2x32_key(seed: int) -> KeyArray:
  """Creates a PRNG key from an integer seed using the threefry2x32 implementation."""
  threefry = prng.threefry_prng_impl
  _check_default_impl_with_no_custom_prng(threefry, 'threefry2x32')
  return _return_prng_keys(True, prng.seed_with_impl(threefry, seed))
|
|
|
|
|
|
|
|
def rbg_key(seed: int) -> KeyArray:
  """Creates a PRNG key from an integer seed using the RBG implementation."""
  rbg = prng.rbg_prng_impl
  _check_default_impl_with_no_custom_prng(rbg, 'rbg')
  return _return_prng_keys(True, prng.seed_with_impl(rbg, seed))
|
|
|
|
|
|
|
|
def unsafe_rbg_key(seed: int) -> KeyArray:
  """Creates a PRNG key from an integer seed using the unsafe RBG implementation."""
  unsafe_rbg = prng.unsafe_rbg_prng_impl
  _check_default_impl_with_no_custom_prng(unsafe_rbg, 'unsafe_rbg')
  return _return_prng_keys(True, prng.seed_with_impl(unsafe_rbg, seed))
|
2021-06-08 11:16:33 -07:00
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def _fold_in(key: KeyArray, data: int) -> KeyArray:
  """Fold ``data`` into a single key array; internal variant of ``fold_in``.

  Unlike ``fold_in``, ``key`` must already be a ``prng.PRNGKeyArray``. The
  result is returned as a key array, never unwrapped.

  Raises:
    TypeError: if ``key`` is not a single (0-d) key or ``data`` is not a
      scalar.
  """
  # Alternative to fold_in() to use within random samplers.
  # TODO(frostig): remove and use fold_in() once we always enable_custom_prng
  assert isinstance(key, prng.PRNGKeyArray)
  if key.ndim:
    raise TypeError("fold_in accepts a single key, but was given a key array of "
                    f"shape {key.shape} != (). Use jax.vmap for batching.")
  if np.ndim(data):
    raise TypeError("fold_in accepts a scalar, but was given an array of "
                    f"shape {np.shape(data)} != (). Use jax.vmap for batching.")
  # ``data`` is cast to uint32 before being folded in.
  return prng.random_fold_in(key, jnp.uint32(data))
|
2021-06-08 11:16:33 -07:00
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def fold_in(key: KeyArray, data: int) -> KeyArray:
  """Folds in data to a PRNG key to form a new PRNG key.

  Args:
    key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``).
    data: a 32-bit integer to fold into the key.

  Returns:
    A new PRNG key that depends deterministically on ``key`` and ``data`` and
    is statistically safe to use for producing a fresh stream of pseudo-random
    values.
  """
  checked_key, was_wrapped = _check_prng_key(key)
  new_key = _fold_in(checked_key, data)
  return _return_prng_keys(was_wrapped, new_key)
|
2021-06-08 11:16:33 -07:00
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def _split(key: KeyArray, num: int = 2) -> KeyArray:
  """Split a single key array into ``num`` keys; internal variant of ``split``.

  ``key`` must already be a ``prng.PRNGKeyArray``. The result is returned as
  a key array, never unwrapped.

  Raises:
    TypeError: if ``key`` is not a single (0-d) key.
  """
  # Alternative to split() to use within random samplers.
  # TODO(frostig): remove and use split(); we no longer need to wait
  # to always enable_custom_prng
  assert isinstance(key, prng.PRNGKeyArray)
  if key.ndim:
    raise TypeError("split accepts a single key, but was given a key array of "
                    f"shape {key.shape} != (). Use jax.vmap for batching.")
  return prng.random_split(key, count=num)
|
2021-06-08 11:16:33 -07:00
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def split(key: KeyArray, num: int = 2) -> KeyArray:
  """Splits a PRNG key into `num` new keys, stacked along a new leading axis.

  Args:
    key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``).
    num: optional, a positive integer giving how many keys to produce
      (default 2).

  Returns:
    An array-like object holding `num` new PRNG keys.
  """
  checked_key, was_wrapped = _check_prng_key(key)
  children = _split(checked_key, num)
  return _return_prng_keys(was_wrapped, children)
|
2021-06-08 11:16:33 -07:00
|
|
|
|
2022-09-30 19:37:48 -07:00
|
|
|
def _key_data(keys: KeyArray) -> Array:
  """Return the raw base array underlying the key array ``keys``."""
  assert isinstance(keys, prng.PRNGKeyArray)
  raw = prng.random_unwrap(keys)
  return raw
|
|
|
|
|
2022-09-30 19:37:48 -07:00
|
|
|
def key_data(keys: KeyArray) -> Array:
  """Return the raw data underlying ``keys``; accepts raw arrays or key arrays."""
  key_array = _check_prng_key(keys)[0]
  return _key_data(key_array)
|
|
|
|
|
2021-03-17 16:37:09 -04:00
|
|
|
|
2020-11-19 09:22:31 -08:00
|
|
|
### random samplers
|
|
|
|
|
|
|
|
|
2022-09-30 19:37:48 -07:00
|
|
|
def _check_shape(name: str, shape: Union[Shape, NamedShape], *param_shapes) -> None:
  """Validate that the parameter shapes broadcast exactly to ``shape``.

  Args:
    name: sampler name, used in the error message.
    shape: the requested output shape (positional or named).
    *param_shapes: shapes of the sampler's distribution parameters.

  Raises:
    ValueError: if broadcasting ``shape``'s positional part with
      ``param_shapes`` gives anything other than that positional part.
  """
  named = core.as_named_shape(shape)
  if not param_shapes:
    return
  broadcast = lax.broadcast_shapes(named.positional, *param_shapes)
  if named.positional == broadcast:
    return
  template = ("{} parameter shapes must be broadcast-compatible with shape "
              "argument, and the result of broadcasting the shapes must equal "
              "the shape argument, but got result {} for shape argument {}.")
  raise ValueError(template.format(name, broadcast, named))
|
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def uniform(key: KeyArray,
            shape: Union[Shape, NamedShape] = (),
            dtype: DTypeLikeFloat = dtypes.float_,
            minval: RealArray = 0.,
            maxval: RealArray = 1.) -> Array:
  """Sample uniform random values in [minval, maxval) with given shape/dtype.

  Args:
    key: a PRNG key used as the random key.
    shape: optional, a tuple of nonnegative integers giving the result shape.
      Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    minval: optional, the inclusive lower bound, broadcast-compatible with
      shape (default 0).
    maxval: optional, the exclusive upper bound, broadcast-compatible with
      shape (default 1).

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  checked_key = _check_prng_key(key)[0]
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `uniform` must be a float dtype, "
                     f"got {dtype}")
  float_dtype = dtypes.canonicalize_dtype(dtype)
  named_shape = core.as_named_shape(shape)
  return _uniform(checked_key, named_shape, float_dtype, minval, maxval)  # type: ignore
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(1, 2), inline=True)
def _uniform(key, shape, dtype, minval, maxval) -> Array:
  """Jitted core of ``uniform``; ``shape`` and ``dtype`` are static.

  Args:
    key: a ``prng.PRNGKeyArray``.
    shape: a ``NamedShape`` for the result.
    dtype: a canonical floating-point dtype.
    minval, maxval: bounds; converted to ``dtype`` and broadcast to the rank of
      ``shape``'s positional part.

  Raises:
    TypeError: if ``dtype`` is not floating point, or is not 16, 32 or 64
      bits wide.
  """
  _check_shape("uniform", shape)
  if not jnp.issubdtype(dtype, np.floating):
    raise TypeError("uniform only accepts floating point dtypes.")

  minval = lax.convert_element_type(minval, dtype)
  maxval = lax.convert_element_type(maxval, dtype)
  minval = lax.broadcast_to_rank(minval, shape.positional_rank)
  maxval = lax.broadcast_to_rank(maxval, shape.positional_rank)

  finfo = jnp.finfo(dtype)
  nbits, nmant = finfo.bits, finfo.nmant

  if nbits not in (16, 32, 64):
    # The message previously claimed only 32- and 64-bit dtypes were accepted,
    # contradicting the check above.
    raise TypeError(
        f"uniform only accepts 16-, 32-, or 64-bit dtypes, got {dtype}.")

  # Dtypes with fewer than 8 mantissa bits draw only 8 random bits. Those bits
  # are then widened to the dtype's unsigned width before the bit manipulation
  # below.
  rng_bits = nbits
  if nmant < 8:
    rng_bits = 8
  bits = _random_bits(key, rng_bits, shape)
  uint_dtype = UINT_DTYPES[nbits]
  if rng_bits != nbits:
    bits = lax.convert_element_type(bits, uint_dtype)

  # The strategy here is to randomize only the mantissa bits with an exponent of
  # 1 (after applying the bias), then shift and scale to the desired range. The
  # bit-level transformation we use relies on Numpy and XLA having bit-for-bit
  # equivalent float representations, which might not be true on all platforms.
  # Concretely: keep the top `nmant` random bits as the mantissa and OR them
  # into the bit pattern of 1.0, giving a float in [1, 2); subtracting 1 then
  # yields [0, 1).
  float_bits = lax.bitwise_or(
      lax.shift_right_logical(bits, np.array(rng_bits - nmant, uint_dtype)),
      np.array(1.0, dtype).view(uint_dtype),
  )
  floats = lax.bitcast_convert_type(float_bits, dtype) - np.array(1., dtype)
  # Affine map onto [minval, maxval). The outer max clamps the result from
  # below at minval (presumably guarding against rounding in the affine map).
  return lax.max(
      minval,
      lax.reshape(floats * (maxval - minval) + minval, shape.positional))
|
2020-11-19 09:22:31 -08:00
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def randint(key: KeyArray,
            shape: Shape,
            minval: IntegerArray,
            maxval: IntegerArray,
            dtype: DTypeLikeInt = dtypes.int_) -> Array:
  """Sample uniform random values in [minval, maxval) with given shape/dtype.

  Args:
    key: a PRNG key used as the random key.
    shape: a tuple of nonnegative integers representing the shape.
    minval: int or array of ints broadcast-compatible with ``shape``, the
      inclusive lower bound of the range.
    maxval: int or array of ints broadcast-compatible with ``shape``, the
      exclusive upper bound of the range.
    dtype: optional, an int dtype for the returned values (default int64 if
      jax_enable_x64 is true, otherwise int32).

  Returns:
    A random array with the specified shape and dtype.
  """
  checked_key, _ = _check_prng_key(key)
  int_dtype = dtypes.canonicalize_dtype(dtype)
  canonical_shape = core.canonicalize_shape(shape)
  return _randint(checked_key, canonical_shape, minval, maxval, int_dtype)
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(1, 4), inline=True)
def _randint(key, shape, minval, maxval, dtype) -> Array:
  """Jitted core of ``randint``; ``shape`` (arg 1) and ``dtype`` (arg 4) are static.

  Draws values in ``[minval, maxval)`` by reducing two ``nbits``-wide random
  words modulo the span ``maxval - minval``. Where ``maxval <= minval``,
  ``minval`` is returned.

  Raises:
    TypeError: if ``dtype`` is not an 8-, 16-, 32- or 64-bit integer dtype.
  """
  _check_shape("randint", shape, np.shape(minval), np.shape(maxval))
  if not jnp.issubdtype(dtype, np.integer):
    raise TypeError(f"randint only accepts integer dtypes, got {dtype}")

  check_arraylike("randint", minval, maxval)
  minval = jnp.asarray(minval)
  maxval = jnp.asarray(maxval)
  # Non-integer bounds are cast to the default Python int type.
  if not jnp.issubdtype(minval.dtype, np.integer):
    minval = minval.astype(int)
  if not jnp.issubdtype(maxval.dtype, np.integer):
    maxval = maxval.astype(int)

  # Flag where maxval is greater than the maximum value of dtype
  # in order to handle cases like randint(key, shape, 0, 256, 'uint8')
  maxval_out_of_range = lax.gt(
    maxval, _convert_and_clip_integer(jnp.array(jnp.iinfo(dtype).max, dtype), maxval.dtype))

  # Bounds are converted to `dtype`; presumably out-of-range values are
  # clipped into dtype's range, per the helper's name (TODO confirm).
  minval = _convert_and_clip_integer(minval, dtype)
  maxval = _convert_and_clip_integer(maxval, dtype)
  minval = lax.broadcast_to_rank(minval, len(shape))
  maxval = lax.broadcast_to_rank(maxval, len(shape))
  nbits = jnp.iinfo(dtype).bits

  if nbits not in (8, 16, 32, 64):
    raise TypeError(f"randint only accepts 8-, 16-, 32-, or 64-bit dtypes, got {dtype}")

  # This algorithm is biased whenever (maxval - minval) is not a power of 2.
  # We generate double the number of random bits required by the dtype so as to
  # reduce that bias.
  k1, k2 = _split(key)
  rbits = lambda key: _random_bits(key, nbits, shape)
  higher_bits, lower_bits = rbits(k1), rbits(k2)

  unsigned_dtype = UINT_DTYPES[nbits]
  span = lax.convert_element_type(maxval - minval, unsigned_dtype)

  # Ensure that span=1 when maxval <= minval, so minval is always returned;
  # https://github.com/google/jax/issues/222
  span = lax.select(maxval <= minval, lax.full_like(span, 1), span)

  # When maxval is out of range, the span has to be one larger.
  # If span is already the maximum representable value, this will wrap to zero,
  # causing remainders below to have no effect, which is the correct semantics.
  span = lax.select(
    maxval_out_of_range & (maxval > minval),
    lax.add(span, _lax_const(span, 1)),
    span)

  # To compute a remainder operation on an integer that might have twice as many
  # bits as we can represent in the native unsigned dtype, we compute a
  # multiplier equal to 2**nbits % span. To avoid overflow, we use the identity:
  # (a * b) % N = [(a % N) * (b % N)] % N
  multiplier = lax.rem(_lax_const(span, 2 ** (nbits // 2)), span)
  multiplier = lax.rem(lax.mul(multiplier, multiplier), span)

  # NOTE(review): both factors of the product below are < span. When
  # span > 2**(nbits // 2) the product (and the sum) can exceed the unsigned
  # range and wrap modulo 2**nbits. Confirm this only contributes to the
  # documented bias.
  random_offset = lax.add(lax.mul(lax.rem(higher_bits, span), multiplier),
                          lax.rem(lower_bits, span))
  random_offset = lax.rem(random_offset, span)
  return lax.add(minval, lax.convert_element_type(random_offset, dtype))
|
|
|
|
|
|
|
|
|
2022-09-30 19:37:48 -07:00
|
|
|
def shuffle(key: KeyArray, x: ArrayLike, axis: int = 0) -> Array:
  """Shuffle the elements of an array uniformly at random along an axis.

  Deprecated: use ``jax.random.permutation`` with ``independent=True``.

  Args:
    key: a PRNG key used as the random key.
    x: the array to be shuffled.
    axis: optional, an int axis along which to shuffle (default 0).

  Returns:
    A shuffled version of x.
  """
  warnings.warn(
      "jax.random.shuffle is deprecated and will be removed in a future release. "
      "Use jax.random.permutation with independent=True.",
      FutureWarning)
  checked_key = _check_prng_key(key)[0]
  return _shuffle(checked_key, x, axis)  # type: ignore
|
|
|
|
|
|
|
|
|
2021-10-11 12:00:43 -06:00
|
|
|
def permutation(key: KeyArray,
                x: Union[int, ArrayLike],
                axis: int = 0,
                independent: bool = False) -> Array:
  """Returns a randomly permuted array or range.

  Args:
    key: a PRNG key used as the random key.
    x: int or array. If x is an integer, randomly shuffle np.arange(x).
      If x is an array, randomly shuffle its elements.
    axis: int, optional. The axis along which x is shuffled. Default is 0.
    independent: bool, optional. If set to True, each individual vector along
      the given axis is shuffled independently. Default is False.

  Returns:
    A shuffled version of x or array range
  """
  checked_key = _check_prng_key(key)[0]
  check_arraylike("permutation", x)
  ndim = np.ndim(x)
  axis = canonicalize_axis(axis, ndim or 1)
  if ndim == 0:
    if not np.issubdtype(lax.dtype(x), np.integer):
      raise TypeError("x must be an integer or at least 1-dimensional")
    n = core.concrete_or_error(int, x, 'argument x of jax.random.permutation()')
    return _shuffle(checked_key, jnp.arange(n), axis)
  if independent or ndim == 1:
    return _shuffle(checked_key, x, axis)
  # Shuffle one index vector and gather whole slices with it, so the slices
  # along `axis` are permuted together rather than independently.
  perm = _shuffle(checked_key, jnp.arange(x.shape[axis]), 0)  # type: ignore[union-attr]
  return jnp.take(x, perm, axis, unique_indices=True)
|
2020-11-19 09:22:31 -08:00
|
|
|
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(2,), inline=True)
def _shuffle(key, x, axis) -> Array:
  """Randomly permute ``x`` along static ``axis`` by repeated random-key sorts.

  Each round sorts ``x`` along ``axis`` by fresh 32-bit random keys. The
  number of rounds is ``ceil(3 * log(max(1, x.size)) / log(2**32 - 1))``.
  """
  # On parallel architectures, Fisher-Yates is more expensive than doing
  # multiple sorts. This algorithm is based on one developed and analyzed by
  # tjablin@. We sort according to randomly-generated 32bit keys, but those keys
  # may have collisions. If we repeat the process, using fresh 32bit keys for
  # each sort, then whenever all pairs of elements have been assigned distinct
  # keys at some iteration (or equivalently when the strings formed by
  # concatenating the successive keys for each element are all distinct) then we
  # are guaranteed to have a perfect sample (assuming that either the sort is
  # stable or that any bias is not value-dependent). Since checking uniqueness
  # at runtime may be expensive, we use a heuristic static stop criterion
  # developed by tjablin@. See tensorflow/compiler/tf2xla/random_ops.cc for more
  # info, and for the original implementation of this algorithm. See also
  # Section 2 of http://people.csail.mit.edu/costis/6896sp11/lec5s.pdf for
  # another analysis (where the keys are generated one bit at a time).
  exponent = 3  # see tjablin@'s analysis for explanation of this parameter
  uint32max = jnp.iinfo(np.uint32).max
  element_count = max(1, x.size)
  rounds = int(np.ceil(exponent * np.log(element_count) / np.log(uint32max)))

  shuffled = x
  for _ in range(rounds):
    key, round_key = _split(key)
    random_keys = _random_bits(round_key, 32, shuffled.shape)
    _, shuffled = lax.sort_key_val(random_keys, shuffled, axis)
  return shuffled
|
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def choice(key: KeyArray,
           a: Union[int, ArrayLike],
           shape: Shape = (),
           replace: bool = True,
           p: Optional[RealArray] = None,
           axis: int = 0) -> Array:
  """Generates a random sample from a given array.

  .. warning::
    If ``p`` has fewer non-zero elements than the requested number of samples,
    as specified in ``shape``, and ``replace=False``, the output of this
    function is ill-defined. Please make sure to use appropriate inputs.

  Args:
    key: a PRNG key used as the random key.
    a : array or int. If an ndarray, a random sample is generated from
      its elements. If an int, the random sample is generated as if a were
      arange(a).
    shape : tuple of ints, optional. Output shape. If the given shape is,
      e.g., ``(m, n)``, then ``m * n`` samples are drawn. Default is (),
      in which case a single value is returned.
    replace : boolean. Whether the sample is with or without replacement.
      default is True.
    p : 1-D array-like, The probabilities associated with each entry in a.
      If not given the sample assumes a uniform distribution over all
      entries in a.
    axis: int, optional. The axis along which the selection is performed.
      The default, 0, selects by row.

  Returns:
    An array containing samples from `a`. If `a` is an int (0-d), its shape
    is `shape`. Otherwise it is `a.shape` with dimension `axis` replaced by
    `shape`. This holds even when no samples are drawn.

  Raises:
    TypeError: if ``shape`` is not a sequence.
    ValueError: if samples are requested from an empty population, if
      ``replace=False`` and more samples than population entries are
      requested, or if ``p`` does not match the population size.
  """
  key, _ = _check_prng_key(key)
  if not isinstance(shape, Sequence):
    raise TypeError("shape argument of jax.random.choice must be a sequence, "
                    f"got {shape}")
  check_arraylike("choice", a)
  arr = jnp.asarray(a)
  if arr.ndim == 0:
    n_inputs = core.concrete_or_error(int, a, "The error occurred in jax.random.choice()")
    # Sampling from arange(a): the output shape is exactly `shape`.
    out_shape = tuple(shape)
  else:
    axis = canonicalize_axis(axis, arr.ndim)
    n_inputs = arr.shape[axis]
    # Sampling slices of `arr`: dimension `axis` is replaced by `shape`.
    out_shape = (*arr.shape[:axis], *shape, *arr.shape[axis + 1:])
  n_draws = math.prod(shape)
  if n_draws == 0:
    # Use the same output shape as the non-empty path. Plain `shape` would
    # drop the non-sampled dimensions of a multi-dimensional `a`.
    return jnp.zeros(out_shape, dtype=arr.dtype)
  if n_inputs <= 0:
    raise ValueError("a must be greater than 0 unless no samples are taken")
  if not replace and n_draws > n_inputs:
    raise ValueError("Cannot take a larger sample than population when 'replace=False'")

  if p is None:
    if replace:
      ind = randint(key, shape, 0, n_inputs)
      result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis)
    else:
      slices = (slice(None),) * axis + (slice(n_draws),)
      result = permutation(key, n_inputs if arr.ndim == 0 else arr, axis)[slices]
  else:
    check_arraylike("choice", p)
    p_arr, = promote_dtypes_inexact(p)
    if p_arr.shape != (n_inputs,):
      raise ValueError("p must be None or match the shape of a")
    if replace:
      # Inverse-CDF sampling. uniform() is in [0, 1), so r lies in
      # (0, p_cuml[-1]); presumably this keeps r == 0 from selecting a leading
      # zero-probability entry.
      p_cuml = jnp.cumsum(p_arr)
      r = p_cuml[-1] * (1 - uniform(key, shape, dtype=p_cuml.dtype))
      ind = jnp.searchsorted(p_cuml, r)
    else:
      # Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
      g = -gumbel(key, (n_inputs,), dtype=p_arr.dtype) - jnp.log(p_arr)
      ind = jnp.argsort(g)[:n_draws]
    result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis)

  return result.reshape(out_shape)
|
2020-11-19 09:22:31 -08:00
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def normal(key: KeyArray,
           shape: Union[Shape, NamedShape] = (),
           dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  r"""Sample standard normal random values with given shape and float dtype.

  The values are returned according to the probability density function:

  .. math::
     f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2}

  on the domain :math:`-\infty < x < \infty`

  Complex dtypes are also accepted; in that case the real and imaginary parts
  are drawn independently and the result is scaled by :math:`1/\sqrt{2}`.

  Args:
    key: a PRNG key used as the random key.
    shape: optional, a tuple of nonnegative integers representing the result
      shape. Default ().
    dtype: optional, a float (or complex) dtype for the returned values
      (default float64 if jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a float or complex dtype.
  """
  key, _ = _check_prng_key(key)
  if dtypes.issubdtype(dtype, np.inexact):
    canonical_dtype = dtypes.canonicalize_dtype(dtype)
    named_shape = core.as_named_shape(shape)
    return _normal(key, named_shape, canonical_dtype)  # type: ignore
  raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, "
                   f"got {dtype}")
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(1, 2), inline=True)
def _normal(key, shape, dtype) -> Array:
  """Jitted core of `normal`: dispatch between real and complex sampling.

  For complex ``dtype``, two independent real normal draws (from two subkeys)
  form the real and imaginary parts, divided by sqrt(2).
  """
  if not dtypes.issubdtype(dtype, np.complexfloating):
    return _normal_real(key, shape, dtype)  # type: ignore
  # Real component dtype matching the requested complex dtype.
  component_dtype = np.array(0, dtype).real.dtype
  real_key, imag_key = _split(key)
  real_part = _normal_real(real_key, shape, component_dtype).astype(dtype)
  imag_part = _normal_real(imag_key, shape, component_dtype).astype(dtype)
  scale = np.array(np.sqrt(2), dtype)
  return (real_part + 1j * imag_part) / scale
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(1, 2), inline=True)
def _normal_real(key, shape, dtype) -> Array:
  """Sample real standard normals by inverse-CDF: sqrt(2) * erf_inv(u).

  ``u`` is uniform on [nextafter(-1, 0), 1), so ``erf_inv`` is never
  evaluated at exactly -1.
  """
  _check_shape("normal", shape)
  upper = np.array(1., dtype)
  lower = np.nextafter(np.array(-1., dtype), np.array(0., dtype), dtype=dtype)
  u = uniform(key, shape, dtype, lower, upper)  # type: ignore[arg-type]
  sqrt_two = np.array(np.sqrt(2), dtype)
  return lax.mul(sqrt_two, lax.erf_inv(u))
|
2020-11-19 09:22:31 -08:00
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def multivariate_normal(key: KeyArray,
                        mean: RealArray,
                        cov: RealArray,
                        shape: Optional[Shape] = None,
                        dtype: DTypeLikeFloat = None,
                        method: str = 'cholesky') -> Array:
  r"""Sample multivariate normal random values with given mean and covariance.

  The values are returned according to the probability density function:

  .. math::
     f(x;\mu, \Sigma) = (2\pi)^{-k/2} \det(\Sigma)^{-1/2}e^{-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)}

  where :math:`k` is the dimension, :math:`\mu` is the mean (given by ``mean``) and
  :math:`\Sigma` is the covariance matrix (given by ``cov``).

  Args:
    key: a PRNG key used as the random key.
    mean: a mean vector of shape ``(..., n)``.
    cov: a positive definite covariance matrix of shape ``(..., n, n)``. The
      batch shape ``...`` must be broadcast-compatible with that of ``mean``.
    shape: optional, a tuple of nonnegative integers specifying the result
      batch shape; that is, the prefix of the result shape excluding the last
      axis. Must be broadcast-compatible with ``mean.shape[:-1]`` and
      ``cov.shape[:-2]``. The default (None) produces a result batch shape by
      broadcasting together the batch shapes of ``mean`` and ``cov``.
    dtype: optional, a float dtype for the returned values. The default (None)
      uses the common inexact dtype that ``mean`` and ``cov`` are promoted to.
    method: optional, a method to compute the factor of ``cov``.
      Must be one of 'svd', 'eigh', and 'cholesky'. Default 'cholesky'. For
      singular covariance matrices, use 'svd' or 'eigh'.

  Returns:
    A random array with the specified dtype and shape given by
    ``shape + mean.shape[-1:]`` if ``shape`` is not None, or else
    ``broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]``.

  Raises:
    ValueError: if ``method`` is not one of the supported values, or if the
      resulting ``dtype`` is not a floating dtype.
  """
  key, _ = _check_prng_key(key)
  # Promote both inputs to a shared inexact dtype (ints become floats).
  mean, cov = promote_dtypes_inexact(mean, cov)
  if method not in {'svd', 'eigh', 'cholesky'}:
    raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
  if dtype is None:
    dtype = mean.dtype
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `multivariate_normal` must be a float "
                     f"dtype, got {dtype}")
  # Canonicalize like the other samplers so e.g. float64 and float32 (with x64
  # disabled) share one static-arg cache entry of the jitted helper.
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _multivariate_normal(key, mean, cov, shape, dtype, method)  # type: ignore
|
2020-11-19 09:22:31 -08:00
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(3, 4, 5), inline=True)
def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array:
  """Jitted core of `multivariate_normal`.

  Validates ranks/shapes, factors ``cov`` as ``factor @ factor.T`` using the
  requested ``method``, then returns ``mean + factor @ z`` for standard normal
  ``z`` of shape ``shape + (n,)``.
  """
  mean_ndim = np.ndim(mean)
  if mean_ndim < 1:
    raise ValueError(
        f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {mean_ndim}")
  cov_ndim = np.ndim(cov)
  if cov_ndim < 2:
    raise ValueError(
        f"multivariate_normal requires cov.ndim >= 2, got cov.ndim == {cov_ndim}")
  n = mean.shape[-1]
  cov_shape = np.shape(cov)
  if cov_shape[-2:] != (n, n):
    raise ValueError(
        f"multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
        f"but got cov.shape == {cov_shape}.")

  if shape is None:
    shape = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
  else:
    _check_shape("normal", shape, mean.shape[:-1], cov.shape[:-2])

  if method in ('svd', 'eigh'):
    # Spectral factorizations tolerate singular (PSD) covariances.
    if method == 'svd':
      basis, spectrum, _ = svd(cov)
    else:
      spectrum, basis = eigh(cov)
    factor = basis * jnp.sqrt(spectrum[..., None, :])
  else:  # 'cholesky'
    factor = cholesky(cov)

  noise = normal(key, shape + mean.shape[-1:], dtype)
  # `mean` may have fewer batch dims than the samples; allow rank promotion.
  with jax.numpy_rank_promotion('allow'):
    return mean + jnp.einsum('...ij,...j->...i', factor, noise)
|
2020-11-19 09:22:31 -08:00
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def truncated_normal(key: KeyArray,
                     lower: RealArray,
                     upper: RealArray,
                     shape: Optional[Union[Shape, NamedShape]] = None,
                     dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  r"""Sample truncated standard normal random values with given shape and dtype.

  The values are returned according to the probability density function:

  .. math::
     f(x) \propto e^{-x^2/2}

  on the domain :math:`\rm{lower} < x < \rm{upper}`.

  Args:
    key: a PRNG key used as the random key.
    lower: a float or array of floats representing the lower bound for
      truncation. Must be broadcast-compatible with ``upper``.
    upper: a float or array of floats representing the upper bound for
      truncation. Must be broadcast-compatible with ``lower``.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``lower`` and ``upper``. The
      default (None) produces a result shape by broadcasting ``lower`` and
      ``upper``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and shape given by ``shape`` if
    ``shape`` is not None, or else by broadcasting ``lower`` and ``upper``.
    Returns values in the open interval ``(lower, upper)``.

  Raises:
    ValueError: if ``dtype`` is not a floating dtype.
  """
  key, _ = _check_prng_key(key)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `truncated_normal` must be a float "
                     f"dtype, got {dtype}")
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  named_shape = None if shape is None else core.as_named_shape(shape)
  return _truncated_normal(key, lower, upper, named_shape, canonical_dtype)  # type: ignore
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(3, 4), inline=True)
def _truncated_normal(key, lower, upper, shape, dtype) -> Array:
  """Jitted core of `truncated_normal` using inverse-CDF sampling.

  Draws ``u`` uniformly between ``erf(lower/sqrt2)`` and ``erf(upper/sqrt2)``
  and maps it back through ``sqrt2 * erf_inv(u)``, then clamps the result into
  the open interval ``(lower, upper)``.

  Raises:
    TypeError: if ``dtype`` is not a floating dtype.
  """
  if shape is None:
    shape = lax.broadcast_shapes(np.shape(lower), np.shape(upper))
  else:
    _check_shape("truncated_normal", shape, np.shape(lower), np.shape(upper))

  # Validate before tracing any arithmetic with `dtype`.
  if not jnp.issubdtype(dtype, np.floating):
    raise TypeError("truncated_normal only accepts floating point dtypes.")

  sqrt2 = np.array(np.sqrt(2), dtype)
  lower = lax.convert_element_type(lower, dtype)
  upper = lax.convert_element_type(upper, dtype)
  a = lax.erf(lower / sqrt2)
  b = lax.erf(upper / sqrt2)
  u = uniform(key, shape, dtype, minval=a, maxval=b)
  out = sqrt2 * lax.erf_inv(u)
  # Clamp the value to the open interval (lower, upper) to make sure that
  # rounding (or if we chose `a` for `u`) doesn't push us outside of the range.
  return jnp.clip(
      out,
      lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
      lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype)))
|
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def bernoulli(key: KeyArray,
              p: RealArray = np.float32(0.5),
              shape: Optional[Union[Shape, NamedShape]] = None) -> Array:
  r"""Sample Bernoulli random values with given shape and mean.

  The values are distributed according to the probability mass function:

  .. math::
     f(k; p) = p^k(1 - p)^{1 - k}

  where :math:`k \in \{0, 1\}` and :math:`0 \le p \le 1`.

  Args:
    key: a PRNG key used as the random key.
    p: optional, a float or array of floats for the mean of the random
      variables. Must be broadcast-compatible with ``shape``. Default 0.5.
    shape: optional, a tuple of nonnegative integers representing the result
      shape. Must be broadcast-compatible with ``p.shape``. The default (None)
      produces a result shape equal to ``p.shape``.

  Returns:
    A random array with boolean dtype and shape given by ``shape`` if ``shape``
    is not None, or else ``p.shape``.

  Raises:
    TypeError: if ``p`` does not have a floating dtype.
  """
  key, _ = _check_prng_key(key)
  p_dtype = dtypes.canonicalize_dtype(lax.dtype(p))
  named_shape = None if shape is None else core.as_named_shape(shape)
  if not jnp.issubdtype(p_dtype, np.floating):
    raise TypeError(
        f"bernoulli probability `p` must have a floating dtype, got {p_dtype}.")
  p_converted = lax.convert_element_type(p, p_dtype)
  return _bernoulli(key, p_converted, named_shape)  # type: ignore
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(2,), inline=True)
def _bernoulli(key, p, shape) -> Array:
  """Jitted core of `bernoulli`: compare uniform draws against ``p``."""
  p_shape = np.shape(p)
  if shape is not None:
    _check_shape("bernoulli", shape, p_shape)
  else:
    # TODO: Use the named part of `p` as well
    shape = p_shape
  draws = uniform(key, shape, lax.dtype(p))
  return draws < p
|
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def beta(key: KeyArray,
         a: RealArray,
         b: RealArray,
         shape: Optional[Shape] = None,
         dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  r"""Sample Beta random values with given shape and float dtype.

  The values are distributed according to the probability density function:

  .. math::
     f(x;a,b) \propto x^{a - 1}(1 - x)^{b - 1}

  on the domain :math:`0 \le x \le 1`.

  Args:
    key: a PRNG key used as the random key.
    a: a float or array of floats broadcast-compatible with ``shape``
      representing the first parameter "alpha".
    b: a float or array of floats broadcast-compatible with ``shape``
      representing the second parameter "beta".
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``a`` and ``b``. The default
      (None) produces a result shape by broadcasting ``a`` and ``b``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and shape given by ``shape`` if
    ``shape`` is not None, or else by broadcasting ``a`` and ``b``.

  Raises:
    ValueError: if ``dtype`` is not a floating dtype.
  """
  key, _ = _check_prng_key(key)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `beta` must be a float "
                     f"dtype, got {dtype}")
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  batch_shape = None if shape is None else core.canonicalize_shape(shape)
  return _beta(key, a, b, batch_shape, canonical_dtype)
|
|
|
|
|
2022-03-21 08:33:11 -07:00
|
|
|
|
2022-09-30 19:37:48 -07:00
|
|
|
def _beta(key, a, b, shape, dtype) -> Array:
  """Core of `beta`: Beta(a, b) = G_a / (G_a + G_b) for Gamma draws G.

  The gamma variates are drawn in log space and rescaled by their elementwise
  maximum before exponentiating, so tiny ``a``/``b`` do not underflow to 0/0.
  """
  a_shape, b_shape = np.shape(a), np.shape(b)
  if shape is None:
    shape = lax.broadcast_shapes(a_shape, b_shape)
  else:
    _check_shape("beta", shape, a_shape, b_shape)

  key_a, key_b = _split(key)
  a = jnp.broadcast_to(lax.convert_element_type(a, dtype), shape)
  b = jnp.broadcast_to(lax.convert_element_type(b, dtype), shape)

  log_ga = loggamma(key_a, a, shape, dtype)
  log_gb = loggamma(key_b, b, shape, dtype)
  shift = lax.max(log_ga, log_gb)
  ga = jnp.exp(log_ga - shift)
  gb = jnp.exp(log_gb - shift)
  return ga / (ga + gb)
|
2020-11-19 09:22:31 -08:00
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def cauchy(key: KeyArray,
           shape: Shape = (),
           dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  r"""Sample Cauchy random values with given shape and float dtype.

  The values are distributed according to the probability density function:

  .. math::
     f(x) \propto \frac{1}{x^2 + 1}

  on the domain :math:`-\infty < x < \infty`

  Args:
    key: a PRNG key used as the random key.
    shape: optional, a tuple of nonnegative integers representing the result
      shape. Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a floating dtype.
  """
  key, _ = _check_prng_key(key)
  if dtypes.issubdtype(dtype, np.floating):
    canonical_dtype = dtypes.canonicalize_dtype(dtype)
    canonical_shape = core.canonicalize_shape(shape)
    return _cauchy(key, canonical_shape, canonical_dtype)
  raise ValueError(f"dtype argument to `cauchy` must be a float "
                   f"dtype, got {dtype}")
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(1, 2), inline=True)
def _cauchy(key, shape, dtype) -> Array:
  """Jitted core of `cauchy`: x = tan(pi * (u - 1/2)), u ~ U[eps, 1)."""
  _check_shape("cauchy", shape)
  # Lower bound eps (not 0) keeps the argument of tan away from -pi/2.
  u = uniform(key, shape, dtype, minval=jnp.finfo(dtype).eps, maxval=1.)
  centered = lax.sub(u, _lax_const(u, 0.5))
  return lax.tan(lax.mul(_lax_const(u, np.pi), centered))
|
2020-11-19 09:22:31 -08:00
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def dirichlet(key: KeyArray,
              alpha: RealArray,
              shape: Optional[Shape] = None,
              dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  r"""Sample Dirichlet random values with given shape and float dtype.

  The values are distributed according to the probability density function:

  .. math::
     f(\{x_i\}; \{\alpha_i\}) \propto \prod_{i=1}^k x_i^{\alpha_i - 1}

  Where :math:`k` is the dimension, and :math:`\{x_i\}` satisfies

  .. math::
     \sum_{i=1}^k x_i = 1

  and :math:`0 \le x_i \le 1` for all :math:`x_i`.

  Args:
    key: a PRNG key used as the random key.
    alpha: an array of shape ``(..., n)`` used as the concentration
      parameter of the random variables.
    shape: optional, a tuple of nonnegative integers specifying the result
      batch shape; that is, the prefix of the result shape excluding the last
      axis (of size ``n``). Must be broadcast-compatible with
      ``alpha.shape[:-1]``. The default (None) produces a result shape equal to
      ``alpha.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and shape given by
    ``shape + (alpha.shape[-1],)`` if ``shape`` is not None, or else
    ``alpha.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating dtype, or if ``alpha`` has
      fewer than one dimension.
  """
  key, _ = _check_prng_key(key)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `dirichlet` must be a float "
                     f"dtype, got {dtype}")
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _dirichlet(key, alpha, shape, dtype)
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(2, 3), inline=True)
def _dirichlet(key, alpha, shape, dtype) -> Array:
  """Jitted core of `dirichlet`: softmax over log-gamma draws.

  Normalizing Gamma(alpha_i) draws gives a Dirichlet sample; doing it as a
  softmax of log-gammas avoids underflow for small ``alpha``.
  """
  alpha_ndim = np.ndim(alpha)
  if alpha_ndim < 1:
    raise ValueError(
        f"dirichlet requires alpha.ndim >= 1, got alpha.ndim == {alpha_ndim}")

  batch_shape = np.shape(alpha)[:-1]
  if shape is None:
    shape = batch_shape
  else:
    _check_shape("dirichlet", shape, batch_shape)

  alpha = lax.convert_element_type(alpha, dtype)
  sample_shape = shape + np.shape(alpha)[-1:]
  log_gammas = loggamma(key, alpha, sample_shape, dtype)
  return _softmax(log_gammas, -1)
|
|
|
|
|
|
|
|
|
2022-09-30 19:37:48 -07:00
|
|
|
def _softmax(x, axis) -> Array:
  """Utility to compute the softmax of x along a given axis.

  The per-slice maximum is subtracted first (under ``stop_gradient``) so
  ``exp`` cannot overflow.

  Raises:
    TypeError: if ``x`` does not have a floating dtype.
  """
  if dtypes.issubdtype(x.dtype, np.floating):
    peak = lax.stop_gradient(jnp.max(x, axis, keepdims=True))
    weights = jnp.exp(x - peak)
    return weights / weights.sum(axis, keepdims=True)
  raise TypeError(f"_softmax only accepts floating dtypes, got {x.dtype}")
|
2020-11-19 09:22:31 -08:00
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def exponential(key: KeyArray,
                shape: Shape = (),
                dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  r"""Sample Exponential random values with given shape and float dtype.

  The values are distributed according to the probability density function:

  .. math::
     f(x) = e^{-x}

  on the domain :math:`0 \le x < \infty`.

  Args:
    key: a PRNG key used as the random key.
    shape: optional, a tuple of nonnegative integers representing the result
      shape. Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a floating dtype.
  """
  key, _ = _check_prng_key(key)
  if dtypes.issubdtype(dtype, np.floating):
    canonical_dtype = dtypes.canonicalize_dtype(dtype)
    canonical_shape = core.canonicalize_shape(shape)
    return _exponential(key, canonical_shape, canonical_dtype)
  raise ValueError(f"dtype argument to `exponential` must be a float "
                   f"dtype, got {dtype}")
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(1, 2), inline=True)
def _exponential(key, shape, dtype) -> Array:
  """Jitted core of `exponential`: x = -log(1 - u) with u ~ U[0, 1)."""
  _check_shape("exponential", shape)
  u = uniform(key, shape, dtype)
  # log1p(-u) == log(1 - u); 1 - u lies in (0, 1], so the log stays finite.
  log_one_minus_u = lax.log1p(lax.neg(u))
  return lax.neg(log_one_minus_u)
|
|
|
|
|
|
|
|
|
2022-09-30 19:37:48 -07:00
|
|
|
def _gamma_one(key: KeyArray, alpha, log_space) -> Array:
  """Draw a single Gamma(alpha, 1) variate, or its log when ``log_space``.

  Uses the Marsaglia-Tsang rejection sampler. For ``alpha < 1`` the sampler is
  run at ``alpha + 1`` and the result is multiplied by ``u ** (1 / alpha)``
  (added as ``log(u) / alpha`` in log space).
  """
  # Ref: A simple method for generating gamma variables, George Marsaglia and
  # Wai Wan Tsang. The algorithm can also be found in:
  # https://en.wikipedia.org/wiki/Gamma_distribution#Generating_gamma-distributed_random_variables
  dtype = lax.dtype(alpha)
  zero = _lax_const(alpha, 0)
  one = _lax_const(alpha, 1)
  minus_one = _lax_const(alpha, -1)
  half = _lax_const(alpha, 0.5)
  third = _lax_const(alpha, 1. / 3.)
  squeeze_const = _lax_const(alpha, 0.0331)

  # For alpha < 1, boost alpha to alpha + 1 and use
  #   Gamma(alpha) ~ Gamma(alpha + 1) * Uniform()^(1 / alpha).
  # For very small alpha the boost factor can underflow, so when `log_space`
  # is set it is applied as
  #   log[Gamma(alpha)] ~ log[Gamma(alpha + 1)] + log[Uniform()] / alpha.
  # log[Uniform()] could be drawn via exponential(), but that uses
  # log[1 - Uniform()]; the uniform draw is kept here so that log_space=True and
  # log_space=False produce matching sequences (backward compatibility).
  # TODO(jakevdp) should we change the convention to avoid -inf in log-space?
  no_boost = lax.ge(alpha, one)
  alpha_orig = alpha
  alpha = lax.select(no_boost, alpha, lax.add(alpha, one))

  d = lax.sub(alpha, third)
  c = lax.div(third, lax.sqrt(d))

  def _should_retry(state):
    # TODO: use lax.cond when its batching rule is supported, to avoid
    # evaluating the second (log + log) test when the squeeze test passes.
    _, x_sq, v_cubed, u = state
    squeeze_fails = lax.ge(
        u, lax.sub(one, lax.mul(squeeze_const, lax.mul(x_sq, x_sq))))
    log_test_fails = lax.ge(
        lax.log(u),
        lax.add(lax.mul(x_sq, half),
                lax.mul(d, lax.add(lax.sub(one, v_cubed), lax.log(v_cubed)))))
    return lax.bitwise_and(squeeze_fails, log_test_fails)

  def _draw_candidate(state):
    # Redraw x ~ N(0, 1) until v = 1 + c * x is positive.
    def _redraw(inner):
      inner_key, normal_key = _split(inner[0])
      x = normal(normal_key, (), dtype=dtype)
      return inner_key, x, lax.add(one, lax.mul(x, c))

    next_key, x_key, u_key = _split(state[0], 3)
    _, x, v = lax.while_loop(lambda inner: lax.le(inner[2], zero), _redraw,
                             (x_key, zero, minus_one))
    x_sq = lax.mul(x, x)
    v_cubed = lax.mul(lax.mul(v, v), v)
    u = uniform(u_key, (), dtype=dtype)
    return next_key, x_sq, v_cubed, u

  key, boost_key = _split(key)
  u_boost = uniform(boost_key, (), dtype=dtype)
  # Initial state (x_sq=0, v_cubed=1, u=2) makes _should_retry return True, so
  # at least one candidate is always drawn.
  initial_state = (key, zero, one, _lax_const(alpha, 2))
  _, _, v_cubed, _ = lax.while_loop(_should_retry, _draw_candidate, initial_state)

  if log_space:
    # TODO(jakevdp): there are negative infinities here due to issues mentioned
    # above. How should we handle those?
    log_boost = lax.select(
        no_boost, zero, lax.mul(lax.log(u_boost), lax.div(one, alpha_orig)))
    return lax.add(lax.add(lax.log(d), lax.log(v_cubed)), log_boost)

  boost = lax.select(no_boost, one, lax.pow(u_boost, lax.div(one, alpha_orig)))
  z = lax.mul(lax.mul(d, v_cubed), boost)
  # Replace exact zeros (underflow) with the smallest normal float.
  return lax.select(lax.eq(z, zero), jnp.finfo(z.dtype).tiny, z)
|
2020-11-19 09:22:31 -08:00
|
|
|
|
|
|
|
|
2022-08-22 13:56:50 -07:00
|
|
|
def _gamma_grad(sample, a, *, log_space):
  """Derivative of gamma sample(s) with respect to the shape parameter ``a``.

  When ``log_space`` is set, ``sample`` holds log-samples and the derivative of
  log(sample) is returned, i.e. d[sample] / sample.
  """
  alphas = jnp.reshape(a, -1)
  samples = jnp.reshape(sample, -1)

  if not log_space:
    grad_fn = lax.random_gamma_grad
  else:
    # exp(log_sample) may round to zero; substitute the smallest normal float,
    # mirroring the zero-correction used by the sampler itself.
    samples = lax.exp(samples)
    zero = lax_internal._const(sample, 0)
    tiny = lax.full_like(samples, jnp.finfo(samples.dtype).tiny)
    samples = lax.select(lax.eq(samples, zero), tiny, samples)

    def grad_fn(alpha, value):
      return lax.random_gamma_grad(alpha, value) / value

  on_cpu = xla_bridge.get_backend().platform == 'cpu'
  if on_cpu:
    grads = lax.map(lambda pair: grad_fn(*pair), (alphas, samples))
  else:
    grads = vmap(grad_fn)(alphas, samples)
  return grads.reshape(np.shape(a))
|
|
|
|
|
2022-08-22 13:56:50 -07:00
|
|
|
def _gamma_impl(key, a, *, log_space, use_vmap=False):
  """Reference implementation of `random_gamma_p`.

  Splits ``key`` so that every element of ``a`` gets its own subkey, samples
  each element with `_gamma_one` (via ``vmap`` or sequential ``lax.map``), and
  reshapes the result back to ``a``'s shape.
  """
  a_shape = jnp.shape(a)
  # Number of subkeys needed per key: the dims of `a` beyond key's own dims.
  per_key_count = math.prod(a_shape[key.ndim:])
  split_keys = vmap(_split, in_axes=(0, None))(key.flatten(), per_key_count)
  element_keys = split_keys.flatten()
  flat_alphas = a.flatten()

  sample_one = partial(_gamma_one, log_space=log_space)
  if use_vmap:
    flat_samples = vmap(sample_one)(element_keys, flat_alphas)
  else:
    flat_samples = lax.map(lambda pair: sample_one(*pair),
                           (element_keys, flat_alphas))
  return jnp.reshape(flat_samples, a_shape)
|
|
|
|
|
2022-08-22 13:56:50 -07:00
|
|
|
def _gamma_batching_rule(batched_args, batch_dims, *, log_space):
  """Batching rule for `random_gamma_p`.

  Moves the batch dimension of both operands to axis 0 (materializing it for
  an unbatched operand) and re-binds the primitive; the output is batched
  along axis 0.
  """
  keys, alphas = batched_args
  keys_bdim, alphas_bdim = batch_dims
  batch_size = next(arg.shape[bdim]
                    for arg, bdim in zip(batched_args, batch_dims)
                    if bdim is not None)
  keys = batching.bdim_at_front(keys, keys_bdim, batch_size)
  alphas = batching.bdim_at_front(alphas, alphas_bdim, batch_size)
  return random_gamma_p.bind(keys, alphas, log_space=log_space), 0
|
2020-11-19 09:22:31 -08:00
|
|
|
|
|
|
|
# Primitive that draws one Gamma(a, 1) sample per element of `a`; the
# `log_space` param selects whether log-samples are produced.
random_gamma_p = core.Primitive('random_gamma')
# Eager evaluation uses _gamma_impl with its default use_vmap=False (lax.map).
random_gamma_p.def_impl(_gamma_impl)
# The output has the same abstract shape/dtype as `a`.
random_gamma_p.def_abstract_eval(lambda key, a, **_: core.raise_to_shaped(a))
# JVP: no derivative w.r.t. the key (None); w.r.t. `a`, scale the tangent by
# _gamma_grad evaluated at the primal output.
ad.defjvp2(
    random_gamma_p, None,
    lambda tangent, ans, key, a, **kwds: tangent * _gamma_grad(ans, a, **kwds))
# Default lowering vectorizes the per-element sampler with vmap ...
mlir.register_lowering(random_gamma_p, mlir.lower_fun(
    partial(_gamma_impl, use_vmap=True),
    multiple_results=False))
# ... while the CPU-specific lowering uses sequential lax.map instead.
mlir.register_lowering(random_gamma_p, mlir.lower_fun(
    partial(_gamma_impl, use_vmap=False),
    multiple_results=False), platform='cpu')
batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def gamma(key: KeyArray,
          a: RealArray,
          shape: Optional[Shape] = None,
          dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  r"""Sample Gamma random values with given shape and float dtype.

  The values are distributed according to the probability density function:

  .. math::
     f(x;a) \propto x^{a - 1} e^{-x}

  on the domain :math:`0 \le x < \infty`, with :math:`a > 0`.

  This is the standard gamma density, with a unit scale/rate parameter.
  Dividing the sample output by the rate is equivalent to sampling from
  *gamma(a, rate)*, and multiplying the sample output by the scale is equivalent
  to sampling from *gamma(a, scale)*.

  Args:
    key: a PRNG key used as the random key.
    a: a float or array of floats broadcast-compatible with ``shape``
      representing the parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``a``. The default (None)
      produces a result shape equal to ``a.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape`` if
    ``shape`` is not None, or else by ``a.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.

  See Also:
    loggamma : sample gamma values in log-space, which can provide improved
      accuracy for small values of ``a``.
  """
  key, _ = _check_prng_key(key)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `gamma` must be a float "
                     f"dtype, got {dtype}")
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _gamma(key, a, shape=shape, dtype=dtype)
|
2020-11-19 09:22:31 -08:00
|
|
|
|
2022-03-21 08:33:11 -07:00
|
|
|
|
|
|
|
def loggamma(key: KeyArray,
             a: RealArray,
             shape: Optional[Shape] = None,
             dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  """Sample log-gamma random values with given shape and float dtype.

  This function is implemented such that the following will hold for a
  dtype-appropriate tolerance::

    np.testing.assert_allclose(jnp.exp(loggamma(*args)), gamma(*args), rtol=rtol)

  The benefit of log-gamma is that for samples very close to zero (which occur
  frequently when ``a << 1``) sampling in log space provides better precision.

  Args:
    key: a PRNG key used as the random key.
    a: a float or array of floats broadcast-compatible with ``shape``
      representing the parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``a``. The default (None)
      produces a result shape equal to ``a.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape`` if
    ``shape`` is not None, or else by ``a.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.

  See Also:
    gamma : standard gamma sampler.
  """
  key, _ = _check_prng_key(key)
  if not dtypes.issubdtype(dtype, np.floating):
    # Name this function (not `gamma`) so the error points at the real caller.
    raise ValueError(f"dtype argument to `loggamma` must be a float "
                     f"dtype, got {dtype}")
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _gamma(key, a, shape=shape, dtype=dtype, log_space=True)
|
|
|
|
|
|
|
|
|
|
|
|
@partial(jit, static_argnames=('shape', 'dtype', 'log_space'), inline=True)
def _gamma(key, a, shape, dtype, log_space=False) -> Array:
  """Shared jitted core of `gamma` and `loggamma`.

  Resolves the output shape (defaulting to ``np.shape(a)``, otherwise checking
  broadcast compatibility), casts ``a`` to ``dtype``, broadcasts it to the
  output shape when needed, and binds ``random_gamma_p``.
  """
  if shape is not None:
    _check_shape("gamma", shape, np.shape(a))
  else:
    shape = np.shape(a)

  a = lax.convert_element_type(a, dtype)
  # The primitive's output takes the shape of ``a``, so expand ``a`` to the
  # requested output shape first.
  needs_broadcast = np.shape(a) != shape
  if needs_broadcast:
    a = jnp.broadcast_to(a, shape)
  return random_gamma_p.bind(key, a, log_space=log_space)
|
2020-11-19 09:22:31 -08:00
|
|
|
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(2, 3, 4), inline=True)
def _poisson_knuth(key, lam, shape, dtype, max_iters) -> Array:
  """Sample Poisson variates with Knuth's multiplicative algorithm.

  Uniform draws are accumulated in log space until their running sum falls
  to ``-lam`` or below; the number of steps taken while still above ``-lam``,
  minus one, is the sample. All elements step in lockstep, and an element's
  counter stops growing once its running sum has fallen below the threshold.

  Reference:
    https://en.wikipedia.org/wiki/Poisson_distribution#Generating_Poisson-distributed_random_variables
  """
  def still_running(log_prod):
    return log_prod > -lam

  def step(state):
    it, count, loop_key, log_prod = state
    loop_key, draw_key = _split(loop_key)
    count = lax.select(still_running(log_prod), count + 1, count)
    log_u = jnp.log(uniform(draw_key, shape, np.float32))
    return it + 1, count, loop_key, log_prod + log_u

  def keep_going(state):
    it, _, _, log_prod = state
    return still_running(log_prod).any() & (it < max_iters)

  init_state = (0,
                lax.full_like(lam, 0, dtype, shape),
                key,
                lax.full_like(lam, 0, np.float32, shape))
  _, count, _, _ = lax.while_loop(keep_going, step, init_state)
  return (count - 1).astype(dtype)
|
|
|
|
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(2, 3, 4), inline=True)
def _poisson_rejection(key, lam, shape, dtype, max_iters) -> Array:
  """Sample Poisson variates with Hörmann's transformed rejection (PTRS).

  ``_poisson`` uses this sampler for non-NaN rates ``lam >= 10``. Every loop
  iteration draws one candidate per element; the loop runs until every
  element has accepted at least once or ``max_iters`` iterations have run.

  Args:
    key: PRNG key, split once per iteration.
    lam: float Poisson rates; ``poisson`` broadcasts them to ``shape`` before
      they reach this function.
    shape: static output shape.
    dtype: static output dtype.
    max_iters: static cap on the number of loop iterations.

  Returns:
    An array of ``dtype`` holding accepted candidates. Elements that never
    accept within ``max_iters`` iterations keep the initial value -1.
  """
  # Transformed rejection due to Hormann.
  # Reference: W. Hörmann, "The transformed rejection method for generating
  # Poisson random variables", Insurance: Mathematics and Economics 12 (1993).
  # http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.48.3054
  # Per-element constants of the algorithm, all derived from lam.
  log_lam = lax.log(lam)
  b = 0.931 + 2.53 * lax.sqrt(lam)
  a = -0.059 + 0.02483 * b
  inv_alpha = 1.1239 + 1.1328 / (b - 3.4)
  v_r = 0.9277 - 3.6224 / (b - 2)

  def body_fn(carry):
    i, k_out, accepted, key = carry
    key, subkey_0, subkey_1 = _split(key, 3)

    # u is a uniform draw shifted by -0.5 (centered on 0); v is a plain
    # uniform draw.
    u = uniform(subkey_0, shape, lam.dtype) - 0.5
    v = uniform(subkey_1, shape, lam.dtype)
    u_shifted = 0.5 - abs(u)

    # Candidate k from the transformed uniform; s is compared against
    # t = -lam + k*log(lam) - log(k!), the log of the Poisson pmf at k.
    k = lax.floor((2 * a / u_shifted + b) * u + lam + 0.43)
    s = lax.log(v * inv_alpha / (a / (u_shifted * u_shifted) + b))
    t = -lam + k * log_lam - lax.lgamma(k + 1)

    # Quick-accept region that skips the pmf comparison.
    accept1 = (u_shifted >= 0.07) & (v <= v_r)
    # Candidates rejected outright: negative k, or the tail region where
    # u_shifted < 0.013 and v > u_shifted.
    reject = (k < 0) | ((u_shifted < 0.013) & (v > u_shifted))
    accept2 = s <= t
    accept = accept1 | (~reject & accept2)

    # Store the candidate wherever it was accepted this iteration. Note this
    # also overwrites values accepted in earlier iterations.
    k_out = lax.select(accept, k, k_out)
    accepted |= accept

    return i + 1, k_out, accepted, key

  def cond_fn(carry):
    i, k_out, accepted, key = carry
    # Keep looping while any element is still unaccepted, up to max_iters.
    return (~accepted).any() & (i < max_iters)

  # -1 marks "not yet accepted"; it survives only if max_iters is exhausted.
  k_init = lax.full_like(lam, -1, lam.dtype, shape)
  accepted = lax.full_like(lam, False, jnp.bool_, shape)
  k = lax.while_loop(cond_fn, body_fn, (0, k_init, accepted, key))[1]
  return k.astype(dtype)
|
|
|
|
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(2, 3), inline=True)
def _poisson(key, lam, shape, dtype) -> Array:
  """Draw Poisson samples, choosing a sampler per element by rate.

  The implementation matches TensorFlow and NumPy:
  https://github.com/tensorflow/tensorflow/blob/v2.2.0-rc3/tensorflow/core/kernels/random_poisson_op.cc
  https://github.com/numpy/numpy/blob/v1.18.3/numpy/random/src/distributions/distributions.c#L574

  Rates below 10 (and NaN rates) go to Knuth's algorithm; all others go to
  transformed rejection sampling. Both samplers run on every element, and the
  result is chosen elementwise.
  """
  small_rate = _isnan(lam) | (lam < 10)
  # Each sampler gets a placeholder rate on the elements it will not serve.
  # Knuth's placeholder 0 makes its loop exit immediately.
  knuth_rate = lax.select(small_rate, lam, lax.full_like(lam, 0.0))
  # The acceptance probability for rejection sampling maxes out at 89% as
  # λ -> ∞, so pick some arbitrary large value.
  rejection_rate = lax.select(small_rate, lax.full_like(lam, 1e5), lam)
  iteration_cap = dtype.type(jnp.iinfo(dtype).max)  # insanely conservative

  from_knuth = _poisson_knuth(key, knuth_rate, shape, dtype, iteration_cap)
  from_rejection = _poisson_rejection(
      key, rejection_rate, shape, dtype, iteration_cap)
  samples = lax.select(small_rate, from_knuth, from_rejection)
  # Knuth's loop yields -1 when lam == 0; pin those elements to 0.
  return lax.select(lam == 0, jnp.zeros_like(samples), samples)
|
2020-11-19 09:22:31 -08:00
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def poisson(key: KeyArray,
            lam: RealArray,
            shape: Optional[Shape] = None,
            dtype: DTypeLikeInt = dtypes.int_) -> Array:
  r"""Sample Poisson random values with given shape and integer dtype.

  The values are distributed according to the probability mass function:

  .. math::
     f(k; \lambda) = \frac{\lambda^k e^{-\lambda}}{k!}

  where :math:`k` is a non-negative integer and :math:`\lambda \ge 0`.

  Args:
    key: a PRNG key used as the random key.
    lam: rate parameter (mean of the distribution), must be >= 0. Must be
      broadcast-compatible with ``shape``.
    shape: optional, a tuple of nonnegative integers representing the result
      shape. Default (None) produces a result shape equal to ``lam.shape``.
    dtype: optional, an integer dtype for the returned values (default int64 if
      jax_enable_x64 is true, otherwise int32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape`` if
    ``shape`` is not None, or else by ``lam.shape``.

  Raises:
    NotImplementedError: if ``key`` does not use the threefry2x32 PRNG
      implementation.
  """
  key, _ = _check_prng_key(key)
  # TODO(frostig): generalize underlying poisson implementation and
  # remove this check
  key_impl = key.dtype.impl  # type: ignore[union-attr]
  if key_impl is not prng.threefry_prng_impl:
    raise NotImplementedError(
        '`poisson` is only implemented for the threefry2x32 RNG, '
        f'not {key_impl}')
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  else:
    shape = np.shape(lam)
  # The samplers work elementwise on a float32 rate array of the output shape.
  lam = jnp.broadcast_to(lam, shape)
  lam = lax.convert_element_type(lam, np.float32)
  return _poisson(key, lam, shape, dtype)
|
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def gumbel(key: KeyArray,
           shape: Shape = (),
           dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  """Draw standard Gumbel samples of a given shape and float dtype.

  The samples follow the probability density function:

  .. math::
     f(x) = e^{-(x + e^{-x})}

  Args:
    key: a PRNG key used as the random key.
    shape: optional, a tuple of nonnegative integers giving the result shape.
      Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the requested shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key(key)
  if dtypes.issubdtype(dtype, np.floating):
    float_dtype = dtypes.canonicalize_dtype(dtype)
    out_shape = core.canonicalize_shape(shape)
    return _gumbel(key, out_shape, float_dtype)
  raise ValueError(f"dtype argument to `gumbel` must be a float "
                   f"dtype, got {dtype}")
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(1, 2), inline=True)
def _gumbel(key, shape, dtype) -> Array:
  """Gumbel samples computed as -log(-log(u)) from a uniform draw ``u``."""
  _check_shape("gumbel", shape)
  # The lower bound is the smallest positive normal value of ``dtype``, so
  # the inner log never sees zero.
  lower = jnp.finfo(dtype).tiny
  u = uniform(key, shape, dtype, minval=lower, maxval=1.)
  return -jnp.log(-jnp.log(u))
|
2020-11-19 09:22:31 -08:00
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def categorical(key: KeyArray,
                logits: RealArray,
                axis: int = -1,
                shape: Optional[Shape] = None) -> Array:
  """Draw category indices from one or more categorical distributions.

  Sampling uses the Gumbel-max trick: Gumbel noise is added to ``logits``
  and the argmax along ``axis`` is taken.

  Args:
    key: a PRNG key used as the random key.
    logits: Unnormalized log probabilities of the categorical distribution(s)
      to sample from, so that ``softmax(logits, axis)`` gives the
      corresponding probabilities.
    axis: Axis along which logits belong to the same categorical distribution.
    shape: Optional, a tuple of nonnegative integers giving the result shape.
      Must be broadcast-compatible with ``np.delete(logits.shape, axis)``.
      The default (None) produces a result shape equal to
      ``np.delete(logits.shape, axis)``.

  Returns:
    A random array with int dtype and shape given by ``shape`` if ``shape``
    is not None, or else ``np.delete(logits.shape, axis)``.
  """
  key, _ = _check_prng_key(key)
  check_arraylike("categorical", logits)
  logits_arr = jnp.asarray(logits)
  ndim = len(logits_arr.shape)

  # Use a negative axis so it still names the same dimension after leading
  # sample dimensions are prepended below.
  if axis >= 0:
    axis -= ndim

  batch_shape = tuple(np.delete(logits_arr.shape, axis))
  if shape is not None:
    shape = tuple(shape)
    _check_shape("categorical", shape, batch_shape)
  else:
    shape = batch_shape

  # Leading dims of ``shape`` beyond the batch dims are extra independent draws.
  n_prefix = len(shape) - len(batch_shape)
  sample_prefix = shape[:n_prefix]
  per_draw_shape = list(shape[n_prefix:])
  per_draw_shape.insert(axis % ndim, logits_arr.shape[axis])

  noise = gumbel(key, (*sample_prefix, *per_draw_shape), logits_arr.dtype)
  broadcast_logits = lax.expand_dims(
      logits_arr, tuple(range(len(sample_prefix))))
  return jnp.argmax(noise + broadcast_logits, axis=axis)
|
2020-11-19 09:22:31 -08:00
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def laplace(key: KeyArray,
            shape: Shape = (),
            dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  r"""Draw standard Laplace samples of a given shape and float dtype.

  The samples follow the probability density function:

  .. math::
     f(x) = \frac{1}{2}e^{-|x|}

  Args:
    key: a PRNG key used as the random key.
    shape: optional, a tuple of nonnegative integers giving the result shape.
      Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the requested shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key(key)
  if dtypes.issubdtype(dtype, np.floating):
    float_dtype = dtypes.canonicalize_dtype(dtype)
    out_shape = core.canonicalize_shape(shape)
    return _laplace(key, out_shape, float_dtype)
  raise ValueError(f"dtype argument to `laplace` must be a float "
                   f"dtype, got {dtype}")
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(1, 2), inline=True)
def _laplace(key, shape, dtype) -> Array:
  """Laplace samples computed as sign(u) * log(1 - |u|) from a uniform ``u``."""
  _check_shape("laplace", shape)
  # Nudge the lower bound above -1 so that 1 - |u| stays positive.
  lower = -1. + jnp.finfo(dtype).epsneg
  u = uniform(key, shape, dtype, minval=lower, maxval=1.)
  magnitude = lax.log1p(lax.neg(lax.abs(u)))
  return lax.mul(lax.sign(u), magnitude)
|
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def logistic(key: KeyArray,
             shape: Shape = (),
             dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  r"""Draw standard logistic samples of a given shape and float dtype.

  The samples follow the probability density function:

  .. math::
     f(x) = \frac{e^{-x}}{(1 + e^{-x})^2}

  Args:
    key: a PRNG key used as the random key.
    shape: optional, a tuple of nonnegative integers giving the result shape.
      Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the requested shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key(key)
  if dtypes.issubdtype(dtype, np.floating):
    float_dtype = dtypes.canonicalize_dtype(dtype)
    out_shape = core.canonicalize_shape(shape)
    return _logistic(key, out_shape, float_dtype)
  raise ValueError(f"dtype argument to `logistic` must be a float "
                   f"dtype, got {dtype}")
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(1, 2), inline=True)
def _logistic(key, shape, dtype):
  """Logistic samples computed as logit(u) = log(u / (1 - u)) from a uniform."""
  _check_shape("logistic", shape)
  # A lower bound of eps keeps log(u / (1 - u)) away from log(0).
  u = uniform(key, shape, dtype, minval=jnp.finfo(dtype).eps, maxval=1.)
  one = _lax_const(u, 1)
  odds = lax.div(u, lax.sub(one, u))
  return lax.log(odds)
|
2020-11-19 09:22:31 -08:00
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def pareto(key: KeyArray,
           b: RealArray,
           shape: Optional[Shape] = None,
           dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  r"""Sample Pareto random values with given shape and float dtype.

  The values are distributed according to the probability density function:

  .. math::
     f(x; b) = b / x^{b + 1}

  on the domain :math:`1 \le x < \infty` with :math:`b > 0` (samples are
  generated as ``exp(e / b)`` with ``e`` exponentially distributed, so they
  are never below 1).

  Args:
    key: a PRNG key used as the random key.
    b: a float or array of floats broadcast-compatible with ``shape``
      representing the parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``b``. The default (None)
      produces a result shape equal to ``b.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape`` if
    ``shape`` is not None, or else by ``b.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key(key)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `pareto` must be a float "
                     f"dtype, got {dtype}")
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _pareto(key, b, shape, dtype)
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(2, 3), inline=True)
def _pareto(key, b, shape, dtype) -> Array:
  """Jitted core of `pareto`: returns ``exp(e / b)`` with ``e`` exponential.

  If ``shape`` is None the output takes ``np.shape(b)``. Otherwise ``shape`` is
  checked against ``b``'s shape, as `_gamma`, `_t`, `_chisquare` and `_f` do
  for their parameters. Without that check, a ``b`` with more or larger
  dimensions would silently broadcast ``e / b`` to a result larger than the
  requested ``shape``.
  """
  if shape is None:
    shape = np.shape(b)
  else:
    _check_shape("pareto", shape, np.shape(b))

  b = lax.convert_element_type(b, dtype)
  e = exponential(key, shape, dtype)
  return lax.exp(e / b)
|
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def t(key: KeyArray,
      df: RealArray,
      shape: Optional[Shape] = None,
      dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  r"""Sample Student's t random values with given shape and float dtype.

  The values are distributed according to the probability density function:

  .. math::
     f(t; \nu) \propto \left(1 + \frac{t^2}{\nu}\right)^{-(\nu + 1)/2}

  where :math:`\nu > 0` is the degrees of freedom, given by the parameter ``df``.

  Args:
    key: a PRNG key used as the random key.
    df: a float or array of floats broadcast-compatible with ``shape``
      representing the degrees of freedom parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``df``. The default (None)
      produces a result shape equal to ``df.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape`` if
    ``shape`` is not None, or else by ``df.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key(key)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `t` must be a float "
                     f"dtype, got {dtype}")
  dtype = dtypes.canonicalize_dtype(dtype)
  # Honor the documented None default. `_t` already infers the shape from
  # ``df`` when it receives None; previously the default was () and None was
  # rejected by canonicalize_shape.
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _t(key, df, shape, dtype)
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(2, 3), inline=True)
def _t(key, df, shape, dtype) -> Array:
  """Jitted core of `t`.

  Builds a Student's t sample as ``z * sqrt((df / 2) / g)``, with ``z`` a
  standard normal draw and ``g`` a Gamma(df / 2) draw. This is equivalent to
  ``z / sqrt(v / df)`` for ``v = 2 * g``.
  """
  if shape is not None:
    _check_shape("t", shape, np.shape(df))
  else:
    shape = np.shape(df)

  df = lax.convert_element_type(df, dtype)
  normal_key, gamma_key = _split(key)
  z = normal(normal_key, shape, dtype)
  half_df = lax.div(df, _lax_const(z, 2))
  g = gamma(gamma_key, half_df, shape, dtype)
  return z * jnp.sqrt(half_df / g)
|
|
|
|
|
|
|
|
|
2023-02-24 10:00:05 +08:00
|
|
|
def chisquare(key: KeyArray,
              df: RealArray,
              shape: Optional[Shape] = None,
              dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  r"""Sample chi-square random values with given shape and float dtype.

  The values are distributed according to the probability density function:

  .. math::
     f(x; \nu) \propto x^{\nu/2 - 1}e^{-x/2}

  on the domain :math:`0 < x < \infty`, where :math:`\nu > 0` represents the
  degrees of freedom, given by the parameter ``df``.

  Args:
    key: a PRNG key used as the random key.
    df: a float or array of floats broadcast-compatible with ``shape``
      representing the parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``df``. The default (None)
      produces a result shape equal to ``df.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape`` if
    ``shape`` is not None, or else by ``df.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key(key)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError("dtype argument to `chisquare` must be a float "
                     f"dtype, got {dtype}")
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _chisquare(key, df, shape, dtype)
|
|
|
|
|
|
|
|
@partial(jit, static_argnums=(2, 3), inline=True)
def _chisquare(key, df, shape, dtype) -> Array:
  """Jitted core of `chisquare`: returns ``2 * Gamma(df / 2)``.

  The gamma variate is drawn with `loggamma` and then exponentiated.
  """
  if shape is not None:
    _check_shape("chisquare", shape, np.shape(df))
  else:
    shape = np.shape(df)

  df = lax.convert_element_type(df, dtype)
  two = _lax_const(df, 2)
  log_gamma_sample = loggamma(key, a=lax.div(df, two), shape=shape, dtype=dtype)
  return lax.mul(jnp.exp(log_gamma_sample), two)
|
|
|
|
|
|
|
|
|
|
|
|
def f(key: KeyArray,
      dfnum: RealArray,
      dfden: RealArray,
      shape: Optional[Shape] = None,
      dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  r"""Sample F-distribution random values with given shape and float dtype.

  The values are distributed according to the probability density function:

  .. math::
     f(x; \nu_1, \nu_2) \propto x^{\nu_1/2 - 1}\left(1 + \frac{\nu_1}{\nu_2}x\right)^{
       -(\nu_1 + \nu_2) / 2}

  on the domain :math:`0 < x < \infty`. Here :math:`\nu_1` is the degrees of
  freedom of the numerator (``dfnum``), and :math:`\nu_2` is the degrees of
  freedom of the denominator (``dfden``).

  Args:
    key: a PRNG key used as the random key.
    dfnum: a float or array of floats broadcast-compatible with ``shape``
      representing the numerator's ``df`` of the distribution.
    dfden: a float or array of floats broadcast-compatible with ``shape``
      representing the denominator's ``df`` of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``dfnum`` and ``dfden``.
      The default (None) produces a result shape equal to the broadcast of
      ``dfnum.shape`` and ``dfden.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape`` if
    ``shape`` is not None, or else by the broadcast of ``dfnum.shape`` and
    ``dfden.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key(key)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError("dtype argument to `f` must be a float "
                     f"dtype, got {dtype}")
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _f(key, dfnum, dfden, shape, dtype)
|
|
|
|
|
|
|
|
@partial(jit, static_argnums=(3, 4), inline=True)
def _f(key, dfnum, dfden, shape, dtype) -> Array:
  """Jitted core of `f`: returns ``(X1 / dfnum) / (X2 / dfden)``.

  ``X1`` and ``X2`` are independent chi-square draws with ``dfnum`` and
  ``dfden`` degrees of freedom, respectively.
  """
  if shape is not None:
    _check_shape("f", shape, np.shape(dfden), np.shape(dfnum))
  else:
    shape = lax.broadcast_shapes(np.shape(dfden), np.shape(dfnum))

  dfden = lax.convert_element_type(dfden, dtype)
  dfnum = lax.convert_element_type(dfnum, dtype)
  key_den, key_num = _split(key)
  # lax.div needs equal shapes, so the dfs are broadcast to the output shape.
  numerator = lax.div(chisquare(key_num, dfnum, shape, dtype),
                      jnp.broadcast_to(dfnum, shape))
  denominator = lax.div(chisquare(key_den, dfden, shape, dtype),
                        jnp.broadcast_to(dfden, shape))
  return lax.div(numerator, denominator)
|
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def rademacher(key: KeyArray,
               shape: Shape,
               dtype: DTypeLikeInt = dtypes.int_) -> Array:
  r"""Sample from a Rademacher distribution.

  The values are distributed according to the probability mass function:

  .. math::
     f(k) = \frac{1}{2}(\delta(k - 1) + \delta(k + 1))

  on the domain :math:`k \in \{-1, 1\}`, where :math:`\delta(x)` is the
  Kronecker delta (1 at :math:`x = 0`, 0 elsewhere).

  Args:
    key: a PRNG key.
    shape: The shape of the returned samples.
    dtype: The type used for samples.

  Returns:
    A jnp.array of samples, of shape ``shape``. Each element in the output has
    a 50% chance of being 1 or -1.
  """
  key, _ = _check_prng_key(key)
  dtype = dtypes.canonicalize_dtype(dtype)
  shape = core.canonicalize_shape(shape)
  return _rademacher(key, shape, dtype)
|
|
|
|
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(1, 2), inline=True)
def _rademacher(key, shape, dtype) -> Array:
  """Map fair Bernoulli draws in {0, 1} to {-1, 1} via ``2 * b - 1``."""
  coin_flips = bernoulli(key=key, p=0.5, shape=shape).astype(dtype)
  signs = 2 * coin_flips - 1
  return signs.astype(dtype)
|
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def maxwell(key: KeyArray,
            shape: Shape = (),
            dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  r"""Draw samples from a one-sided Maxwell distribution.

  Samples follow the probability density function

  .. math::
     f(x) \propto x^2 e^{-x^2 / 2}

  supported on :math:`0 \le x < \infty`.

  Args:
    key: a PRNG key.
    shape: The shape of the returned samples.
    dtype: The floating-point type used for samples.

  Returns:
    A jnp.array of samples with shape `shape`.

  Raises:
    ValueError: if `dtype` is not a floating-point dtype.
  """
  # Samples are built as sqrt(X^2 + Y^2 + Z^2) with X, Y, Z ~ N(0, 1).
  prng_key, _ = _check_prng_key(key)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(
        f"dtype argument to `maxwell` must be a float dtype, got {dtype}")
  float_dtype = dtypes.canonicalize_dtype(dtype)
  sample_shape = core.canonicalize_shape(shape)
  return _maxwell(prng_key, sample_shape, float_dtype)
|
|
|
|
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(1, 2), inline=True)
def _maxwell(key, shape, dtype) -> Array:
  """Jitted core of `maxwell`: Euclidean norm of three i.i.d. standard normals."""
  # A trailing axis of length 3 holds the three normal components per sample.
  components = normal(key=key, shape=(*shape, 3), dtype=dtype)
  return jnp.linalg.norm(components, axis=-1)
|
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def double_sided_maxwell(key: KeyArray,
                         loc: RealArray,
                         scale: RealArray,
                         shape: Shape = (),
                         dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  r"""Draw samples from a double-sided Maxwell distribution.

  Samples follow the probability density function

  .. math::
     f(x;\mu,\sigma) \propto z^2 e^{-z^2 / 2}

  where :math:`z = (x - \mu) / \sigma`, with the center :math:`\mu` given by
  ``loc`` and the scale :math:`\sigma` given by ``scale``.

  Args:
    key: a PRNG key.
    loc: The location parameter of the distribution.
    scale: The scale parameter of the distribution.
    shape: Batch shape prepended to the broadcast shape of ``loc`` and
      ``scale``.
    dtype: The floating-point type used for samples.

  Returns:
    A jnp.array of samples.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  prng_key, _ = _check_prng_key(key)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(
        f"dtype argument to `double_sided_maxwell` must be a float dtype, got {dtype}")
  float_dtype = dtypes.canonicalize_dtype(dtype)
  batch_shape = core.canonicalize_shape(shape)
  return _double_sided_maxwell(prng_key, loc, scale, batch_shape, float_dtype)
|
|
|
|
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(3, 4), inline=True)
def _double_sided_maxwell(key, loc, scale, shape, dtype) -> Array:
  """Jitted core of `double_sided_maxwell`.

  Samples have shape ``shape + broadcast_shapes(np.shape(loc), np.shape(scale))``,
  so the default ``shape=()`` yields exactly the broadcast parameter shape.
  """
  params_shapes = lax.broadcast_shapes(np.shape(loc), np.shape(scale))
  # No special case for an empty `shape`: () + params_shapes already equals
  # params_shapes. Substituting params_shapes first and then concatenating
  # would duplicate the parameter dimensions.
  shape = shape + params_shapes

  maxwell_key, rademacher_key = _split(key)
  maxwell_rvs = maxwell(maxwell_key, shape=shape, dtype=dtype)
  # Generate random signs for the symmetric variates.
  random_sign = rademacher(rademacher_key, shape=shape, dtype=dtype)
  assert random_sign.shape == maxwell_rvs.shape

  return random_sign * maxwell_rvs * scale + loc
|
|
|
|
|
|
|
|
|
2021-08-15 08:09:30 -07:00
|
|
|
def weibull_min(key: KeyArray,
                scale: RealArray,
                concentration: RealArray,
                shape: Shape = (),
                dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  r"""Sample from a Weibull distribution.

  The values are distributed according to the probability density function:

  .. math::
     f(x;\sigma,c) \propto x^{c - 1} \exp(-(x / \sigma)^c)

  on the domain :math:`0 < x < \infty`, where :math:`c > 0` is the concentration
  parameter, and :math:`\sigma > 0` is the scale parameter.

  Args:
    key: a PRNG key.
    scale: The scale parameter of the distribution.
    concentration: The concentration parameter of the distribution.
    shape: The shape of the underlying uniform draws. The result has the
      broadcast of ``shape`` with the shapes of ``scale`` and
      ``concentration``.
    dtype: The type used for samples.

  Returns:
    A jnp.array of samples.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key(key)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `weibull_min` must be a float "
                     f"dtype, got {dtype}")
  dtype = dtypes.canonicalize_dtype(dtype)
  shape = core.canonicalize_shape(shape)
  return _weibull_min(key, scale, concentration, shape, dtype)
|
|
|
|
|
|
|
|
|
2021-08-20 13:43:38 -07:00
|
|
|
@partial(jit, static_argnums=(3, 4), inline=True)
def _weibull_min(key, scale, concentration, shape, dtype) -> Array:
  """Jitted core of `weibull_min`: inverse-CDF transform of uniform draws."""
  u = uniform(key=key, shape=shape, minval=0, maxval=1, dtype=dtype)
  # Weibull CDF is F(x) = 1 - exp(-(x / scale)**c), so
  # F^{-1}(u) = scale * (-log(1 - u))**(1 / c).
  neg_log_survival = -jnp.log1p(-u)
  return jnp.power(neg_log_survival, 1.0/concentration) * scale
|
2021-08-15 09:04:25 -07:00
|
|
|
|
|
|
|
|
|
|
|
# TODO(frostig): remove these aliases

# Backward-compatible alias of the primitive that lives in `prng`.
threefry2x32_p = prng.threefry2x32_p


def threefry_2x32(keypair, count):
  """Deprecated alias of ``prng.threefry_2x32`` that emits a FutureWarning.

  Args:
    keypair: forwarded unchanged to ``prng.threefry_2x32``.
    count: forwarded unchanged to ``prng.threefry_2x32``.

  Returns:
    The result of ``prng.threefry_2x32(keypair, count)``.
  """
  deprecation_msg = ('jax.random.threefry_2x32 has moved to jax.prng.threefry_2x32 '
                     'and will be removed from `random` module.')
  warnings.warn(deprecation_msg, FutureWarning)
  return prng.threefry_2x32(keypair, count)
|
2022-04-29 14:20:50 -04:00
|
|
|
|
|
|
|
def orthogonal(
  key: KeyArray,
  n: int,
  shape: Shape = (),
  dtype: DTypeLikeFloat = dtypes.float_
) -> Array:
  """Draw a uniformly random matrix from the orthogonal group O(n).

  For complex dtypes, the matrix is drawn uniformly from the unitary group U(n).

  Args:
    key: a PRNG key used as the random key.
    n: an integer giving the size of the square result matrices.
    shape: optional, the batch dimensions of the result. Default ().
    dtype: optional, a float or complex dtype for the returned values
      (default float64 if jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array of shape `(*shape, n, n)` and the requested dtype.
  """
  key, _ = _check_prng_key(key)
  _check_shape("orthogonal", shape)
  n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
  gaussian = normal(key, (*shape, n, n), dtype)
  q, r = jnp.linalg.qr(gaussian)
  # QR leaves the signs (phases) of Q's columns arbitrary. Rescaling each column
  # by the phase of the matching diagonal entry of R removes that ambiguity.
  diag = jnp.diagonal(r, offset=0, axis1=-2, axis2=-1)
  phase = lax.div(diag, abs(diag).astype(diag.dtype))
  return lax.mul(q, lax.expand_dims(phase, [-2]))
|
2022-06-03 15:11:29 -04:00
|
|
|
|
|
|
|
def generalized_normal(
  key: KeyArray,
  p: float,
  shape: Shape = (),
  dtype: DTypeLikeFloat = dtypes.float_
) -> Array:
  r"""Draw samples from the generalized normal distribution.

  Samples follow the probability density function

  .. math::
     f(x;p) \propto e^{-|x|^p}

  supported on :math:`-\infty < x < \infty`, where :math:`p > 0` is the
  shape parameter.

  Args:
    key: a PRNG key used as the random key.
    p: a float giving the shape parameter.
    shape: optional, the batch dimensions of the result. Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the requested shape and dtype.
  """
  key, _ = _check_prng_key(key)
  _check_shape("generalized_normal", shape)
  gamma_key, sign_key = split(key)
  # |X|^p ~ Gamma(1/p), so the magnitude is a Gamma draw raised to 1/p; a
  # random sign makes the result symmetric about zero.
  gamma_draws = gamma(gamma_key, 1/p, shape, dtype)
  signs = rademacher(sign_key, shape, dtype)
  magnitude = gamma_draws ** (1 / p)
  return signs * magnitude
|
|
|
|
|
|
|
|
def ball(
  key: KeyArray,
  d: int,
  p: float = 2,
  shape: Shape = (),
  dtype: DTypeLikeFloat = dtypes.float_
):
  """Draw points uniformly from the unit Lp ball.

  Reference: https://arxiv.org/abs/math/0503650.

  Args:
    key: a PRNG key used as the random key.
    d: a nonnegative int giving the dimensionality of the ball.
    p: a float giving the p parameter of the Lp norm.
    shape: optional, the batch dimensions of the result. Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array of shape `(*shape, d)` and the requested dtype.
  """
  key, _ = _check_prng_key(key)
  _check_shape("ball", shape)
  d = core.concrete_or_error(index, d, "The error occurred in jax.random.ball()")
  normal_key, exp_key = split(key)
  g = generalized_normal(normal_key, p, (*shape, d), dtype)
  e = exponential(exp_key, shape, dtype)
  # Construction from the reference above: g / (sum |g_i|^p + e)^(1/p).
  p_norm_plus_e = (jnp.abs(g) ** p).sum(-1) + e
  denom = p_norm_plus_e ** (1 / p)
  return g / denom[..., None]
|
2023-03-10 16:34:29 +08:00
|
|
|
|
|
|
|
|
|
|
|
def rayleigh(key: KeyArray,
             scale: RealArray,
             shape: Optional[Shape] = None,
             dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  r"""Sample Rayleigh random values with given shape and float dtype.

  The values are returned according to the probability density function:

  .. math::
     f(x;\sigma) \propto xe^{-x^2/(2\sigma^2)}

  on the domain :math:`0 \le x < \infty`, and where :math:`\sigma > 0` is the
  scale parameter of the distribution.

  Args:
    key: a PRNG key used as the random key.
    scale: a float or array of floats broadcast-compatible with ``shape``
      representing the parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``scale``. The default (None)
      produces a result shape equal to ``scale.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape`` if
    ``shape`` is not None, or else by ``scale.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key(key)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError("dtype argument to `rayleigh` must be a float "
                     f"dtype, got {dtype}")
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _rayleigh(key, scale, shape, dtype)
|
|
|
|
|
|
|
|
@partial(jit, static_argnums=(2, 3), inline=True)
def _rayleigh(key, scale, shape, dtype) -> Array:
  """Jitted core of `rayleigh`: inverse-CDF transform of uniform draws.

  Uses F^{-1}(u) = scale * sqrt(-2 * log(1 - u)).
  """
  if shape is None:
    shape = np.shape(scale)
  else:
    _check_shape("rayleigh", shape, np.shape(scale))
  u = uniform(key, shape, dtype)
  scale = scale.astype(dtype)
  # `uniform` returns values in [0, 1), so 1 - u lies in (0, 1] and
  # log1p(-u) is always finite. The previous lax.log(u) returned -inf when
  # u == 0, giving an infinite sample. This also matches _weibull_min.
  log_u = lax.log1p(lax.neg(u))
  n_two = _lax_const(scale, -2)
  sqrt_u = lax.sqrt(lax.mul(log_u, n_two))
  ray = lax.mul(scale, sqrt_u)
  return ray
|
2023-03-21 11:13:40 +08:00
|
|
|
|
|
|
|
def wald(key: KeyArray,
         mean: RealArray,
         shape: Optional[Shape] = None,
         dtype: DTypeLikeFloat = dtypes.float_) -> Array:
  r"""Sample Wald random values with given shape and float dtype.

  The values are returned according to the probability density function:

  .. math::
     f(x;\mu) = \frac{1}{\sqrt{2\pi x^3}} \exp\left(-\frac{(x - \mu)^2}{2\mu^2 x}\right)

  on the domain :math:`0 < x < \infty`, and where :math:`\mu > 0` is the location
  parameter of the distribution.

  Args:
    key: a PRNG key used as the random key.
    mean: a float or array of floats broadcast-compatible with ``shape``
      representing the mean parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``mean``. The default
      (None) produces a result shape equal to ``np.shape(mean)``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape`` if
    ``shape`` is not None, or else by ``np.shape(mean)``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  key, _ = _check_prng_key(key)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError("dtype argument to `wald` must be a float "
                     f"dtype, got {dtype}")
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _wald(key, mean, shape, dtype)
|
2023-03-21 11:13:40 +08:00
|
|
|
|
2023-03-22 10:44:50 -07:00
|
|
|
@partial(jit, static_argnums=(2, 3), inline=True)
def _wald(key, mean, shape, dtype) -> Array:
  """Jitted core of `wald`.

  Squares a standard normal draw (y = v**2) and forms the smaller root x of the
  inverse-Gaussian quadratic. It then keeps x when a uniform draw z satisfies
  z <= mean / (mean + x), and returns mean**2 / x otherwise.
  """
  if shape is not None:
    _check_shape("wald", shape, np.shape(mean))
  else:
    shape = np.shape(mean)
  normal_key, uniform_key = _split(key, 2)
  mu = jnp.broadcast_to(mean.astype(dtype), shape)
  v = normal(normal_key, shape, dtype)
  z = uniform(uniform_key, shape, dtype)
  y = lax.integer_pow(v, 2)
  mu_sq = lax.integer_pow(mu, 2)
  root = lax.sqrt(4 * mu * y + mu_sq * lax.integer_pow(y, 2))
  candidate = mu + mu_sq * y / 2 - mu / 2 * root
  keep_candidate = lax.le(z, mu / (mu + candidate))
  return lax.select(keep_candidate, candidate, mu_sq / candidate)
|