2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2019-02-13 19:42:47 -08:00
|
|
|
"""JAX pseudo-random number generators (PRNGs).
|
|
|
|
|
2020-05-26 09:38:28 +02:00
|
|
|
Example usage:
|
2019-02-13 19:42:47 -08:00
|
|
|
|
2020-05-29 15:41:28 -07:00
|
|
|
>>> rng = jax.random.PRNGKey(seed)
|
|
|
|
>>> for i in range(num_steps):
|
|
|
|
... rng, rng_input = jax.random.split(rng)
|
2020-05-29 16:00:20 -07:00
|
|
|
... params = compiled_update(rng_input, params, next(batches))
|
2020-05-26 09:38:28 +02:00
|
|
|
|
|
|
|
Context:
|
|
|
|
|
|
|
|
Among other requirements, the JAX PRNG aims to:
|
|
|
|
(a) ensure reproducibility,
|
|
|
|
(b) parallelize well, both in terms of vectorization (generating array values)
|
2020-06-02 17:37:20 -07:00
|
|
|
and multi-replica, multi-core computation. In particular it should not use
|
2020-05-26 09:38:28 +02:00
|
|
|
sequencing constraints between random function calls.
|
|
|
|
|
|
|
|
The approach is based on:
|
|
|
|
1. "Parallel random numbers: as easy as 1, 2, 3" (Salmon et al. 2011)
|
|
|
|
2. "Splittable pseudorandom number generators using cryptographic hashing"
|
|
|
|
(Claessen et al. 2013)
|
|
|
|
|
|
|
|
See also https://github.com/google/jax/blob/master/design_notes/prng.md
|
|
|
|
for the design and its motivation.
|
2019-02-13 19:42:47 -08:00
|
|
|
"""
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2018-11-21 18:31:13 -08:00
|
|
|
from functools import partial
|
2020-04-12 09:14:54 +01:00
|
|
|
from typing import Optional, Sequence, Union
|
2020-05-01 15:18:24 -07:00
|
|
|
import warnings
|
2018-11-21 18:31:13 -08:00
|
|
|
|
2020-05-21 18:12:18 -03:00
|
|
|
import numpy as np
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
from . import lax
|
2020-05-21 18:12:18 -03:00
|
|
|
from . import numpy as jnp
|
2019-11-15 10:02:51 -05:00
|
|
|
from . import dtypes
|
2020-01-15 15:00:38 -08:00
|
|
|
from .api import jit, vmap
|
2020-04-12 09:14:54 +01:00
|
|
|
from .numpy.lax_numpy import _constant_like, asarray
|
2018-12-06 21:35:03 -05:00
|
|
|
from jax.lib import xla_bridge
|
2020-04-23 18:30:47 -04:00
|
|
|
from jax.lib import xla_client
|
2019-11-24 13:06:23 -05:00
|
|
|
from jax.lib import cuda_prng
|
2018-12-15 19:14:05 -08:00
|
|
|
from jax import core
|
2019-11-24 13:06:23 -05:00
|
|
|
from jax import abstract_arrays
|
2019-12-17 13:14:10 -08:00
|
|
|
from jax.numpy.linalg import cholesky
|
2019-12-01 09:44:45 -05:00
|
|
|
from jax.interpreters import ad
|
2019-11-24 13:06:23 -05:00
|
|
|
from jax.interpreters import batching
|
|
|
|
from jax.interpreters import xla
|
2019-12-26 22:43:06 -05:00
|
|
|
from jax.util import prod
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-30 21:42:55 -08:00
|
|
|
|
2020-05-21 18:12:18 -03:00
|
|
|
# Lookup table from a bit width (8, 16, 32 or 64) to the unsigned-integer
# dtype of that width.
_UINT_DTYPES = {8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64}
|
2020-05-15 19:09:43 -07:00
|
|
|
|
|
|
|
|
2020-05-21 18:12:18 -03:00
|
|
|
def PRNGKey(seed: int) -> jnp.ndarray:
  """Build a pseudo-random number generator (PRNG) key from an integer seed.

  Args:
    seed: a scalar 64- or 32-bit integer used as the value of the key.

  Returns:
    A PRNG key: an array of shape (2,) and dtype uint32. Element 0 holds the
    upper 32 bits of the seed and element 1 the lower 32 bits, so a 32-bit
    seed ends up padded with zeros in the upper word.

  Raises:
    TypeError: if ``seed`` is not a scalar.
  """
  if np.shape(seed):
    raise TypeError("PRNGKey seed must be a scalar.")

  def _as_uint32_vector(word):
    # Cast one word to uint32 and give it shape (1,) so it can be concatenated.
    return lax.reshape(lax.convert_element_type(word, np.uint32), [1])

  if not isinstance(seed, (int, np.ndarray)):
    high_word = lax.shift_right_logical(seed, lax._const(seed, 32))
  else:
    # Raw integer values may be 64-bit even when jax_enable_x64=False; the
    # upper word is extracted with NumPy so the top 32 bits are not dropped.
    high_word = np.bitwise_and(np.right_shift(seed, 32), 0xFFFFFFFF)
  low_word = jnp.bitwise_and(seed, 0xFFFFFFFF)

  return lax.concatenate(
      [_as_uint32_vector(high_word), _as_uint32_vector(low_word)], 0)
|
|
|
|
|
2020-05-21 18:12:18 -03:00
|
|
|
def _is_prng_key(key: jnp.ndarray) -> bool:
  """Report whether ``key`` looks like a PRNG key.

  A key is accepted when it exposes ``shape == (2,)`` and a uint32 ``dtype``.
  Objects lacking either attribute are rejected rather than raising.
  """
  try:
    has_key_shape = key.shape == (2,)
    return has_key_shape and key.dtype == np.uint32
  except AttributeError:
    # Not array-like at all (e.g. a plain Python int).
    return False
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
### utilities
|
|
|
|
|
|
|
|
|
2020-08-26 10:21:56 -07:00
|
|
|
# TODO(mattjj,jakevdp): add more info to error message, use this utility more
|
|
|
|
def _asarray(x):
  """A stricter jnp.asarray that rejects anything but JAX arrays and np.ndarrays.

  Raises:
    TypeError: if ``x`` is neither an ``np.ndarray`` nor a ``jnp.ndarray``.
  """
  if isinstance(x, (np.ndarray, jnp.ndarray)):
    return jnp.asarray(x)
  raise TypeError(f"Function requires array input, got {x} of type {type(x)}.")
|
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def _make_rotate_left(dtype):
  """Build a bitwise rotate-left function specialised to an integer ``dtype``.

  Args:
    dtype: the integer dtype whose bit width defines the rotation.

  Returns:
    A function ``_rotate_left(x, d)`` that rotates ``x`` left by ``d`` bits.
    Both arguments are converted to ``dtype`` first when they differ from it.

  Raises:
    TypeError: if ``dtype`` is not an integer dtype.
  """
  if not jnp.issubdtype(dtype, np.integer):
    raise TypeError("_rotate_left only accepts integer dtypes.")
  width = np.array(jnp.iinfo(dtype).bits, dtype)

  def _coerce(value):
    # Only insert a conversion when the operand is not already of `dtype`.
    if lax.dtype(value) == dtype:
      return value
    return lax.convert_element_type(value, dtype)

  def _rotate_left(x, d):
    d = _coerce(d)
    x = _coerce(x)
    # Bits shifted out on the left re-enter on the right.
    shifted_up = lax.shift_left(x, d)
    wrapped_around = lax.shift_right_logical(x, width - d)
    return shifted_up | wrapped_around

  return _rotate_left
|
|
|
|
|
|
|
|
|
|
|
|
def _bit_stats(bits):
  """Debugging helper: per-position mean of the 64-bit binary digits of ``bits``.

  Each value is expanded to its 64-character binary representation; the
  result is the mean of every bit position taken across all values.
  """
  digit_rows = [[int(digit) for digit in np.binary_repr(value, 64)]
                for value in bits]
  return np.array(digit_rows).mean(0)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
### hash function and split
|
|
|
|
|
2019-11-24 13:06:23 -05:00
|
|
|
def _threefry2x32_abstract_eval(*args):
  """Abstract-evaluation rule for the threefry2x32 primitive.

  Args:
    *args: abstract values of the primitive's operands; every one must have
      dtype uint32.

  Returns:
    A pair of identical uint32 abstract values. When every operand is a
    ShapedArray the result carries the broadcast shape of the operands,
    otherwise it is an UnshapedArray.

  Raises:
    TypeError: if any operand does not have dtype uint32.
  """
  for operand in args:
    if operand.dtype != jnp.uint32:
      raise TypeError("Arguments to threefry2x32 must have uint32 type, got {}"
                      .format(args))

  out_dtype = jnp.dtype(jnp.uint32)
  fully_shaped = all(
      isinstance(operand, abstract_arrays.ShapedArray) for operand in args)
  if not fully_shaped:
    out_aval = abstract_arrays.UnshapedArray(out_dtype)
  else:
    out_shape = lax._broadcasting_shape_rule(*args)
    out_aval = abstract_arrays.ShapedArray(out_shape, out_dtype)
  return (out_aval, out_aval)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-05-21 18:12:18 -03:00
|
|
|
# Rotate-left specialised to uint32; used by apply_round below.
rotate_left = _make_rotate_left(np.uint32)
|
2020-05-12 23:03:22 +01:00
|
|
|
|
|
|
|
def apply_round(v, rot):
  """Apply one add-rotate-xor mixing round to the pair ``v``.

  Args:
    v: a two-element list of words; it is copied, not modified in place.
    rot: the left-rotation amount applied to the second word.

  Returns:
    A new list ``[v[0] + v[1], (v[0] + v[1]) ^ rotate_left(v[1], rot)]``.
  """
  mixed = v[:]
  total = mixed[0] + mixed[1]
  rotated = rotate_left(mixed[1], rot)
  mixed[0] = total
  mixed[1] = total ^ rotated
  return mixed
|
|
|
|
|
|
|
|
def rotate_list(xs):
  """Return ``xs`` rotated left by one position (first element moves to the end)."""
  head, tail = xs[:1], xs[1:]
  return tail + head
|
|
|
|
|
|
|
|
def rolled_loop_step(i, state):
  """Body of the rolled Threefry loop: one group of rounds plus a key injection.

  Args:
    i: the loop index; ``i + 1`` is added (as uint32) to the second word.
    state: a triple ``(x, ks, rotations)`` of the current word pair, the key
      schedule and the rotation schedule.

  Returns:
    The next state: the updated word pair, and both schedules rotated left by
    one so the following iteration reads its entries at index 0.
  """
  words, key_schedule, rotation_schedule = state
  for amount in rotation_schedule[0]:
    words = apply_round(words, amount)
  injected = [
      words[0] + key_schedule[0],
      words[1] + key_schedule[1] + asarray(i + 1, dtype=np.uint32),
  ]
  return injected, rotate_list(key_schedule), rotate_list(rotation_schedule)
|
|
|
|
|
2019-11-24 13:06:23 -05:00
|
|
|
def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
  """Apply the Threefry 2x32 hash to a pair of count words.

  Args:
    key1: first word of the key. The primitive's abstract-evaluation rule
      only admits uint32 operands, so uint32 is expected here.
    key2: second word of the key.
    x1: first array of counts to hash.
    x2: second array of counts to hash.
    use_rolled_loops: if True, the five round groups are run through
      ``lax.fori_loop`` with ``rolled_loop_step``; if False they are unrolled
      in Python. Both paths perform the same sequence of operations.

  Returns:
    A tuple ``(y1, y2)`` holding the two hashed words.
  """
  x = [x1, x2]

  # Two alternating groups of four rotation amounts.
  rotations = [np.array([13, 15, 26, 6], dtype=np.uint32),
               np.array([17, 29, 16, 24], dtype=np.uint32)]
  # Key schedule: the two key words plus their xor with the constant
  # 0x1BD11BDA (presumably the Threefry key-schedule parity constant).
  ks = [key1, key2, key1 ^ key2 ^ np.uint32(0x1BD11BDA)]

  # Initial key injection.
  x[0] = x[0] + ks[0]
  x[1] = x[1] + ks[1]

  if use_rolled_loops:
    x, _, _ = lax.fori_loop(0, 5, rolled_loop_step, (x, rotate_list(ks), rotations))
  else:
    # Five groups of four rounds. Group `step` uses rotation set `step % 2`
    # and is followed by a key injection that walks cyclically through the
    # key schedule and adds the 1-based group number to the second word.
    for step in range(5):
      for r in rotations[step % 2]:
        x = apply_round(x, r)
      x[0] = x[0] + ks[(step + 1) % 3]
      x[1] = x[1] + ks[(step + 2) % 3] + np.uint32(step + 1)

  return tuple(x)
|
|
|
|
|
|
|
|
|
|
|
|
def _threefry2x32_gpu_translation_rule(c, k1, k2, x1, x2):
  """GPU translation rule for threefry2x32, delegating to ``cuda_prng``.

  All four operands are first broadcast to their common shape, with each
  operand's own dimensions mapped onto the trailing dimensions of the result.

  Args:
    c: the builder used to query operand shapes and emit the op.
    k1: first key operand.
    k2: second key operand.
    x1: first count operand.
    x2: second count operand.

  Returns:
    Whatever ``cuda_prng.threefry2x32`` returns for the broadcast operands.
  """
  operands = (k1, k2, x1, x2)
  out_shape = lax.broadcast_shapes(
      *(c.get_shape(operand).dimensions() for operand in operands))
  out_rank = len(out_shape)

  def _expand(operand):
    operand_rank = c.get_shape(operand).rank()
    trailing_dims = tuple(range(out_rank - operand_rank, out_rank))
    return xla_client.ops.BroadcastInDim(operand, out_shape, trailing_dims)

  keys = (_expand(k1), _expand(k2))
  data = (_expand(x1), _expand(x2))
  return cuda_prng.threefry2x32(c, keys, data)
|
2019-11-24 13:06:23 -05:00
|
|
|
|
|
|
|
# The threefry2x32 primitive: two outputs, evaluated eagerly through XLA,
# with elementwise-broadcasting batching semantics.
threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)

# Default lowering uses the unrolled rounds; the CPU backend uses the rolled
# (fori_loop) variant instead.
xla.translations[threefry2x32_p] = xla.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=False),
    multiple_results=True)
xla.backend_specific_translations['cpu'][threefry2x32_p] = xla.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=True),
    multiple_results=True)

# Register the GPU rule only when `cuda_prng` is truthy.
if cuda_prng:
  xla.backend_specific_translations['gpu'][threefry2x32_p] = (
      _threefry2x32_gpu_translation_rule)
|
|
|
|
|
|
|
|
@jit
def threefry_2x32(keypair, count):
  """Apply the Threefry 2x32 hash.

  Args:
    keypair: a pair of 32bit unsigned integers used for the key.
    count: an array of dtype uint32 used for the counts.

  Returns:
    An array of dtype uint32 with the same shape as `count`.

  Raises:
    TypeError: if either key word or `count` is not of dtype uint32.
  """
  key1, key2 = keypair
  arg_dtypes = [lax.dtype(arg) for arg in [key1, key2, count]]
  if any(arg_dtype != np.uint32 for arg_dtype in arg_dtypes):
    msg = "threefry_2x32 requires uint32 arguments, got {}"
    raise TypeError(msg.format(arg_dtypes))

  # The primitive hashes two equally sized halves, so an odd number of
  # counts is padded with a single zero that is dropped again afterwards.
  flat_counts = count.ravel()
  needs_padding = count.size % 2
  if needs_padding:
    flat_counts = jnp.concatenate([flat_counts, np.uint32([0])])
  halves = list(jnp.split(flat_counts, 2))

  hashed = threefry2x32_p.bind(key1, key2, halves[0], halves[1])
  out = jnp.concatenate(hashed)
  assert out.dtype == np.uint32
  if needs_padding:
    out = out[:-1]
  return lax.reshape(out, count.shape)
|
|
|
|
|
|
|
|
|
2020-05-21 18:12:18 -03:00
|
|
|
def split(key: jnp.ndarray, num: int = 2) -> jnp.ndarray:
  """Split a PRNG key into `num` new keys stacked along a new leading axis.

  Args:
    key: a PRNGKey (an array with shape (2,) and dtype uint32).
    num: optional, a positive integer giving the number of keys to produce
      (default 2). It is converted with ``int`` before use.

  Returns:
    An array with shape (num, 2) and dtype uint32 representing `num` new keys.
  """
  num_keys = int(num)
  return _split(key, num_keys)  # type: ignore
|
2019-04-10 22:09:14 -07:00
|
|
|
|
|
|
|
@partial(jit, static_argnums=(1,))
def _split(key, num) -> jnp.ndarray:
  """Jitted implementation of `split`; `num` is a static argument.

  Hashes the counts ``0 .. 2*num - 1`` with `key` and views the result as
  `num` rows of two words each.
  """
  counts = lax.iota(np.uint32, num * 2)
  hashed = threefry_2x32(key, counts)
  return lax.reshape(hashed, (num, 2))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-02-13 09:55:36 -08:00
|
|
|
def fold_in(key, data):
  """Fold `data` into a PRNG key, producing a new PRNG key.

  Args:
    key: a PRNGKey (an array with shape (2,) and dtype uint32).
    data: a 32bit integer representing data to be folded in to the key.

  Returns:
    A new PRNGKey that is a deterministic function of the inputs and is
    statistically safe for producing a stream of new pseudo-random values.
  """
  folded_key = _fold_in(key, data)
  return folded_key
|
|
|
|
|
2019-07-19 12:04:33 -07:00
|
|
|
@jit
def _fold_in(key, data):
  """Jitted implementation of `fold_in`.

  `data` is turned into a key-shaped pair of words with `PRNGKey`, and that
  pair is hashed under `key`.
  """
  data_words = PRNGKey(data)
  return threefry_2x32(key, data_words)
|
2019-02-13 09:55:36 -08:00
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def _random_bits(key, bit_width, shape):
  """Sample uniform random bits of given width and shape using PRNG key.

  Args:
    key: a PRNG key (checked with `_is_prng_key`).
    bit_width: width in bits of each sample; one of 8, 16, 32 or 64.
    shape: the shape of the returned array.

  Returns:
    An array of shape `shape` whose dtype is the unsigned integer type of
    `bit_width` bits.

  Raises:
    TypeError: if `key` is not a valid PRNG key or `bit_width` is unsupported.
  """
  if not _is_prng_key(key):
    raise TypeError("_random_bits got invalid prng key.")
  if bit_width not in (8, 16, 32, 64):
    raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")

  size = prod(shape)
  # Number of 32-bit words needed to cover `size` samples of `bit_width` bits.
  max_count = int(np.ceil(bit_width * size / 32))

  # Hash at most uint32-max counts per key; anything beyond that is produced
  # in additional blocks, each under its own subkey.
  block_len = jnp.iinfo(np.uint32).max
  nblocks, rem = divmod(max_count, block_len)
  if nblocks:
    *subkeys, last_key = split(key, nblocks + 1)
    blocks = [threefry_2x32(subkey, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
              for subkey in subkeys]
    last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
    bits = lax.concatenate(blocks + [last], 0)
  else:
    bits = threefry_2x32(key, lax.iota(np.uint32, rem))

  dtype = _UINT_DTYPES[bit_width]
  if bit_width == 64:
    # Pair up the two halves of the 32-bit words: first half supplies the
    # upper 32 bits, second half the lower 32 bits.
    upper, lower = [lax.convert_element_type(half, dtype)
                    for half in jnp.split(bits, 2)]
    bits = lax.shift_left(upper, dtype(32)) | lower
  elif bit_width in [8, 16]:
    # this is essentially bits.view(dtype)[:size]
    fields_per_word = 32 // bit_width
    field_offsets = lax.mul(
        np.uint32(bit_width),
        lax.broadcasted_iota(np.uint32, (fields_per_word, 1), 0))
    shifted = lax.shift_right_logical(lax.broadcast(bits, (1,)), field_offsets)
    field_mask = np.uint32(np.iinfo(dtype).max)
    bits = lax.bitwise_and(field_mask, shifted)
    bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width),), (1, 0))
    bits = lax.convert_element_type(bits, dtype)[:size]
  return lax.reshape(bits, shape)
|
|
|
|
|
|
|
|
|
|
|
|
### random samplers
|
|
|
|
|
|
|
|
|
2019-10-17 20:36:51 +00:00
|
|
|
def _check_shape(name, shape, *param_shapes):
  """Validate `shape` and, if given, that parameter shapes broadcast to it.

  Args:
    name: name of the calling sampler, used in the error message.
    shape: the requested output shape; it is canonicalized first.
    *param_shapes: shapes of the sampler's parameters. When present, the
      result of broadcasting them together with `shape` must equal `shape`.

  Raises:
    ValueError: if broadcasting the shapes yields something other than `shape`.
  """
  shape = abstract_arrays.canonicalize_shape(shape)
  if not param_shapes:
    return

  broadcasted = lax.broadcast_shapes(shape, *param_shapes)
  if shape == broadcasted:
    return

  msg = ("{} parameter shapes must be broadcast-compatible with shape "
         "argument, and the result of broadcasting the shapes must equal "
         "the shape argument, but got result {} for shape argument {}.")
  raise ValueError(msg.format(name, broadcasted, shape))
|
2019-05-09 11:40:19 -07:00
|
|
|
|
|
|
|
|
2020-05-21 18:12:18 -03:00
|
|
|
def uniform(key: jnp.ndarray,
            shape: Sequence[int] = (),
            dtype: np.dtype = dtypes.float_,
            minval: Union[float, jnp.ndarray] = 0.,
            maxval: Union[float, jnp.ndarray] = 1.) -> jnp.ndarray:
  """Sample uniform random values in [minval, maxval) with given shape/dtype.

  Args:
    key: a PRNGKey used as the random key.
    shape: optional, a tuple of nonnegative integers representing the result
      shape. Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    minval: optional, a minimum (inclusive) value broadcast-compatible with
      shape for the range (default 0).
    maxval: optional, a maximum (exclusive) value broadcast-compatible with
      shape for the range (default 1).

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if `dtype` is not a floating-point dtype.
  """
  is_float_dtype = dtypes.issubdtype(dtype, np.floating)
  if not is_float_dtype:
    raise ValueError(f"dtype argument to `uniform` must be a float dtype, "
                     f"got {dtype}")
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  canonical_shape = abstract_arrays.canonicalize_shape(shape)
  return _uniform(key, canonical_shape, canonical_dtype, minval, maxval)  # type: ignore
|
2019-04-10 22:09:14 -07:00
|
|
|
|
|
|
|
@partial(jit, static_argnums=(1, 2))
def _uniform(key, shape, dtype, minval, maxval) -> jnp.ndarray:
  """Jitted implementation of `uniform`; `shape` and `dtype` are static.

  Args:
    key: a PRNGKey used as the random key.
    shape: the result shape.
    dtype: a 16-, 32- or 64-bit floating-point dtype.
    minval: lower bound of the range; converted to `dtype`.
    maxval: upper bound of the range; converted to `dtype`.

  Returns:
    An array of the given shape and dtype, never smaller than `minval`.

  Raises:
    TypeError: if `dtype` is not floating point, or not 16, 32 or 64 bits wide.
  """
  _check_shape("uniform", shape)
  if not jnp.issubdtype(dtype, np.floating):
    raise TypeError("uniform only accepts floating point dtypes.")

  minval = lax.convert_element_type(minval, dtype)
  maxval = lax.convert_element_type(maxval, dtype)
  minval = lax.broadcast_to_rank(minval, len(shape))
  maxval = lax.broadcast_to_rank(maxval, len(shape))

  finfo = jnp.finfo(dtype)
  nbits, nmant = finfo.bits, finfo.nmant

  if nbits not in (16, 32, 64):
    # The message lists exactly the widths accepted by the check above.
    raise TypeError("uniform only accepts 16-, 32-, or 64-bit dtypes.")

  bits = _random_bits(key, nbits, shape)

  # The strategy here is to randomize only the mantissa bits with an exponent of
  # 1 (after applying the bias), then shift and scale to the desired range. The
  # bit-level transformation we use relies on Numpy and XLA having bit-for-bit
  # equivalent float representations, which might not be true on all platforms.
  float_bits = lax.bitwise_or(
      lax.shift_right_logical(bits, np.array(nbits - nmant, lax.dtype(bits))),
      np.array(1., dtype).view(_UINT_DTYPES[nbits]))
  # `float_bits` encodes a value in [1, 2); subtracting 1 moves it to [0, 1).
  floats = lax.bitcast_convert_type(float_bits, dtype) - np.array(1., dtype)
  # Clamp from below so the scaled result is never less than minval.
  return lax.max(
      minval,
      lax.reshape(floats * (maxval - minval) + minval, shape))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-05-21 18:12:18 -03:00
|
|
|
def randint(key: jnp.ndarray,
            shape: Sequence[int],
            minval: Union[int, jnp.ndarray],
            maxval: Union[int, jnp.ndarray],
            dtype: np.dtype = dtypes.int_):
  """Sample uniform random values in [minval, maxval) with given shape/dtype.

  Args:
    key: a PRNGKey used as the random key.
    shape: a tuple of nonnegative integers representing the shape.
    minval: int or array of ints broadcast-compatible with ``shape``, a minimum
      (inclusive) value for the range.
    maxval: int or array of ints broadcast-compatible with ``shape``, a maximum
      (exclusive) value for the range.
    dtype: optional, an int dtype for the returned values (default int64 if
      jax_enable_x64 is true, otherwise int32).

  Returns:
    A random array with the specified shape and dtype.
  """
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  canonical_shape = abstract_arrays.canonicalize_shape(shape)
  return _randint(key, canonical_shape, minval, maxval, canonical_dtype)
|
|
|
|
|
|
|
|
@partial(jit, static_argnums=(1, 4))
def _randint(key, shape, minval, maxval, dtype):
  """Jitted implementation of `randint`; `shape` and `dtype` are static.

  Raises:
    TypeError: if `dtype` is not an 8-, 16-, 32- or 64-bit integer dtype.
    ValueError: (from `_check_shape`) if the bounds do not broadcast to `shape`.
  """
  _check_shape("randint", shape, np.shape(minval), np.shape(maxval))
  if not jnp.issubdtype(dtype, np.integer):
    raise TypeError("randint only accepts integer dtypes.")

  rank = len(shape)
  lower = lax.broadcast_to_rank(lax.convert_element_type(minval, dtype), rank)
  upper = lax.broadcast_to_rank(lax.convert_element_type(maxval, dtype), rank)

  nbits = jnp.iinfo(dtype).bits
  if nbits not in (8, 16, 32, 64):
    raise TypeError("randint only accepts 8-, 16-, 32-, or 64-bit dtypes.")

  # if we don't have minval < maxval, just always return minval
  # https://github.com/google/jax/issues/222
  upper = lax.max(lax.add(lower, np.array(1, dtype)), upper)

  # This algorithm is biased whenever (maxval - minval) is not a power of 2.
  # We generate double the number of random bits required by the dtype so as to
  # reduce that bias.
  high_key, low_key = split(key)
  higher_bits = _random_bits(high_key, nbits, shape)
  lower_bits = _random_bits(low_key, nbits, shape)

  span = lax.convert_element_type(upper - lower, _UINT_DTYPES[nbits])

  # To compute a remainder operation on an integer that might have twice as many
  # bits as we can represent in the native unsigned dtype, we compute a
  # multiplier equal to 2**nbits % span. To avoid overflow, we use the identity:
  #  (a * b) % N = [(a % N) * (b % N)] % N
  half_power = lax.rem(lax._const(span, 2 ** (nbits // 2)), span)
  multiplier = lax.rem(lax.mul(half_power, half_power), span)

  high_part = lax.mul(lax.rem(higher_bits, span), multiplier)
  low_part = lax.rem(lower_bits, span)
  random_offset = lax.rem(lax.add(high_part, low_part), span)
  return lax.add(lower, lax.convert_element_type(random_offset, dtype))
|
|
|
|
|
|
|
|
|
2020-05-21 18:12:18 -03:00
|
|
|
def shuffle(key: jnp.ndarray, x: jnp.ndarray, axis: int = 0) -> jnp.ndarray:
  """Shuffle the elements of an array uniformly at random along an axis.

  Deprecated: calling this emits a FutureWarning pointing to
  ``jax.random.permutation``.

  Args:
    key: a PRNGKey used as the random key.
    x: the array to be shuffled.
    axis: optional, an int axis along which to shuffle (default 0).

  Returns:
    A shuffled version of x.
  """
  warnings.warn(
      "jax.random.shuffle is deprecated and will be removed in a future release. "
      "Use jax.random.permutation",
      FutureWarning)
  return _shuffle(key, x, axis)  # type: ignore
|
2019-04-10 22:09:14 -07:00
|
|
|
|
2020-04-24 07:40:33 +02:00
|
|
|
|
|
|
|
def permutation(key, x):
  """Permute elements of an array along its first axis or return a permuted range.

  If `x` is a multi-dimensional array, it is only shuffled along its
  first index.

  Args:
    key: a PRNGKey used as the random key.
    x: the array or integer range to be shuffled.

  Returns:
    A shuffled version of x or array range

  Raises:
    TypeError: if `x` is a scalar that is not of integer dtype.
  """
  ndim = np.ndim(x)
  if ndim == 0:
    # Scalar case: must be a concrete integer, shuffled as arange(x).
    if not np.issubdtype(lax.dtype(x), np.integer):
      raise TypeError("x must be an integer or at least 1-dimensional")
    return _shuffle(key, jnp.arange(int(x)), 0)
  if ndim == 1:
    return _shuffle(key, x, 0)
  # Higher rank: shuffle the indices of the leading axis, then gather rows.
  ind = _shuffle(key, jnp.arange(x.shape[0]), 0)  # type: ignore[attribute-error]
  return x[ind]
|
2020-04-24 07:40:33 +02:00
|
|
|
|
|
|
|
|
2019-04-10 22:09:14 -07:00
|
|
|
@partial(jit, static_argnums=(2,))
def _shuffle(key, x, axis) -> jnp.ndarray:
  """Shuffle `x` along `axis` by repeatedly sorting on fresh random keys.

  Args:
    key: a PRNGKey used as the random key.
    x: the array to shuffle; may be empty.
    axis: the (static) axis along which to shuffle.

  Returns:
    The shuffled array.
  """
  # On parallel architectures, Fisher-Yates is more expensive than doing
  # multiple sorts. This algorithm is based on one developed and analyzed by
  # tjablin@. We sort according to randomly-generated 32bit keys, but those keys
  # may have collisions. If we repeat the process, using fresh 32bit keys for
  # each sort, then whenever all pairs of elements have been assigned distinct
  # keys at some iteration (or equivalently when the strings formed by
  # concatenating the successive keys for each element are all distinct) then we
  # are guaranteed to have a perfect sample (assuming that either the sort is
  # stable or that any bias is not value-dependent). Since checking uniqueness
  # at runtime may be expensive, we use a heuristic static stop criterion
  # developed by tjablin@. See tensorflow/compiler/tf2xla/random_ops.cc for more
  # info, and for the original implementation of this algorithm. See also
  # Section 2 of http://people.csail.mit.edu/costis/6896sp11/lec5s.pdf for
  # another analysis (where the keys are generated one bit at a time).
  exponent = 3  # see tjablin@'s analysis for explanation of this parameter
  uint32max = jnp.iinfo(np.uint32).max
  # Clamp the size to at least 1: log(0) is -inf, and int(ceil(-inf)) raises
  # OverflowError, which made shuffling an empty array crash. With the clamp
  # an empty (or single-element) input needs zero rounds and is returned as is.
  num_rounds = int(np.ceil(exponent * np.log(max(1, x.size)) / np.log(uint32max)))

  for _ in range(num_rounds):
    key, subkey = split(key)
    sort_keys = _random_bits(subkey, 32, x.shape)
    _, x = lax.sort_key_val(sort_keys, x, axis)

  return x
|
|
|
|
|
|
|
|
|
2020-06-19 16:04:42 -07:00
|
|
|
def choice(key, a, shape=(), replace=True, p=None):
  """Generates a random sample from a given 1-D array.

  Args:
    key: a PRNGKey used as the random key.
    a : 1D array or int. If an ndarray, a random sample is generated from
      its elements. If an int, the random sample is generated as if a were
      arange(a).
    shape : tuple of ints, optional. Output shape.  If the given shape is,
      e.g., ``(m, n)``, then ``m * n`` samples are drawn.  Default is (),
      in which case a single value is returned.
    replace : boolean.  Whether the sample is with or without replacement.
      default is True.
    p : 1-D array-like, The probabilities associated with each entry in a.
      If not given the sample assumes a uniform distribution over all
      entries in a. Lists and tuples are accepted and converted to an array.

  Returns:
    An array of shape `shape` containing samples from `a`.

  Raises:
    TypeError: if `shape` is not a sequence.
    ValueError: if `a` is neither a scalar nor 1-dimensional, if the
      population is empty while samples are requested, if more samples than
      population members are requested with ``replace=False``, or if the
      shape of `p` is not ``(n_inputs,)``.
  """
  if not isinstance(shape, Sequence):
    raise TypeError("shape argument of jax.random.choice must be a sequence, "
                    f"got {shape}")
  if np.ndim(a) not in [0, 1]:
    raise ValueError("a must be an integer or 1-dimensional")

  # A scalar `a` stands for the population arange(a); indices are then the
  # samples themselves and no gather from `a` is needed.
  a_is_scalar = np.ndim(a) == 0
  if a_is_scalar:
    a = int(a)
    n_inputs = a
  else:
    a = _asarray(a)
    n_inputs = len(a)

  n_draws = prod(shape)
  if n_draws == 0:
    # Nothing to sample: return an empty result before validating the
    # population size.
    return jnp.zeros(shape, dtype=lax.dtype(a))
  if n_inputs <= 0:
    raise ValueError("a must be greater than 0 unless no samples are taken")
  if not replace and n_draws > n_inputs:
    raise ValueError("Cannot take a larger sample than population when 'replace=False'")

  if p is None:
    if not replace:
      # Uniform sampling without replacement: prefix of a random permutation.
      return permutation(key, a)[:n_draws].reshape(shape)
    ind = randint(key, shape, 0, n_inputs)
  else:
    # `p` is documented as array-like; convert it so that `.shape` is
    # available for lists and tuples as well as arrays.
    p = jnp.asarray(p)
    if p.shape != (n_inputs,):
      raise ValueError("p must be None or match the shape of a")
    if replace:
      # Inverse-CDF sampling; scaling by p_cuml[-1] uses the actual total of
      # `p`, and 1 - uniform keeps r strictly positive.
      p_cuml = jnp.cumsum(p)
      r = p_cuml[-1] * (1 - uniform(key, shape))
      ind = jnp.searchsorted(p_cuml, r)
    else:
      # Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
      g = -gumbel(key, (n_inputs,)) - jnp.log(p)
      ind = jnp.argsort(g)[:n_draws]

  result = ind if a_is_scalar else a[ind]
  return result.reshape(shape)
|
|
|
|
|
|
|
|
|
2020-05-21 18:12:18 -03:00
|
|
|
def normal(key: jnp.ndarray,
           shape: Sequence[int] = (),
           dtype: np.dtype = dtypes.float_) -> jnp.ndarray:
  """Draw samples from the standard normal distribution.

  Args:
    key: a PRNGKey used as the random key.
    shape: optional, a tuple of nonnegative integers giving the shape of the
      result. Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  dtype_is_float = dtypes.issubdtype(dtype, np.floating)
  if not dtype_is_float:
    message = (f"dtype argument to `normal` must be a float dtype, "
               f"got {dtype}")
    raise ValueError(message)
  # Canonicalize both static arguments before handing off to the jitted
  # implementation.
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  canonical_shape = abstract_arrays.canonicalize_shape(shape)
  return _normal(key, canonical_shape, canonical_dtype)  # type: ignore
|
2019-04-10 22:09:14 -07:00
|
|
|
|
|
|
|
@partial(jit, static_argnums=(1, 2))
def _normal(key, shape, dtype) -> jnp.ndarray:
  """Jitted implementation of ``normal``; ``shape`` and ``dtype`` are static."""
  _check_shape("normal", shape)
  # Sample u uniformly with minval just above -1 (the next representable
  # value towards 0) and maxval 1, then map it through sqrt(2) * erf_inv(u).
  minus_one = np.array(-1., dtype)
  lower_bound = np.nextafter(minus_one, 0., dtype=dtype)
  upper_bound = np.array(1., dtype)
  uniform_draws = uniform(key, shape, dtype, lower_bound, upper_bound)
  root_two = np.array(np.sqrt(2), dtype)
  return root_two * lax.erf_inv(uniform_draws)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-05-21 18:12:18 -03:00
|
|
|
def multivariate_normal(key: jnp.ndarray,
                        mean: jnp.ndarray,
                        cov: jnp.ndarray,
                        shape: Optional[Sequence[int]] = None,
                        dtype: np.dtype = dtypes.float_) -> jnp.ndarray:
  """Draw samples from a multivariate normal with the given mean and covariance.

  Args:
    key: a PRNGKey used as the random key.
    mean: a mean vector of shape ``(..., n)``.
    cov: a positive definite covariance matrix of shape ``(..., n, n)``. Its
      batch shape ``...`` must be broadcast-compatible with that of ``mean``.
    shape: optional, a tuple of nonnegative integers specifying the result
      batch shape, i.e. the result shape without its last axis. Must be
      broadcast-compatible with ``mean.shape[:-1]`` and ``cov.shape[:-2]``.
      With the default (None) the batch shape is obtained by broadcasting the
      batch shapes of ``mean`` and ``cov`` together.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and shape given by
    ``shape + mean.shape[-1:]`` if ``shape`` is not None, or else
    ``broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  dtype_is_float = dtypes.issubdtype(dtype, np.floating)
  if not dtype_is_float:
    message = (f"dtype argument to `multivariate_normal` must be a float "
               f"dtype, got {dtype}")
    raise ValueError(message)
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  # None is forwarded untouched so the implementation can infer the shape.
  batch_shape = (None if shape is None
                 else abstract_arrays.canonicalize_shape(shape))
  return _multivariate_normal(key, mean, cov, batch_shape, canonical_dtype)  # type: ignore
|
2019-10-17 20:36:51 +00:00
|
|
|
|
|
|
|
@partial(jit, static_argnums=(3, 4))
def _multivariate_normal(key, mean, cov, shape, dtype) -> jnp.ndarray:
  """Jitted implementation of ``multivariate_normal``.

  ``shape`` and ``dtype`` are static arguments. The samples are built as
  ``mean + L @ z`` where ``L`` is the Cholesky factor of ``cov`` and ``z`` is
  drawn with ``normal``.

  Raises:
    ValueError: if ``mean`` has fewer than 1 dimension, ``cov`` has fewer than
      2 dimensions, or the trailing dimensions of ``cov`` are not ``(n, n)``
      with ``n = mean.shape[-1]``.
  """
  if not np.ndim(mean) >= 1:
    msg = "multivariate_normal requires mean.ndim >= 1, got mean.ndim == {}"
    raise ValueError(msg.format(np.ndim(mean)))
  if not np.ndim(cov) >= 2:
    msg = "multivariate_normal requires cov.ndim >= 2, got cov.ndim == {}"
    raise ValueError(msg.format(np.ndim(cov)))
  n = mean.shape[-1]
  if np.shape(cov)[-2:] != (n, n):
    msg = ("multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
           "but got cov.shape == {shape}.")
    raise ValueError(msg.format(n=n, shape=np.shape(cov)))

  if shape is None:
    shape = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
  else:
    # Pass this sampler's own public name, as every other sampler in this
    # module does (previously "normal" was passed here).
    _check_shape("multivariate_normal", shape, mean.shape[:-1], cov.shape[:-2])

  chol_factor = cholesky(cov)
  # One standard-normal vector of length n per batch entry.
  normal_samples = normal(key, shape + mean.shape[-1:], dtype)
  return mean + jnp.einsum('...ij,...j->...i', chol_factor, normal_samples)
|
2019-09-23 16:15:41 -04:00
|
|
|
|
|
|
|
|
2020-05-21 18:12:18 -03:00
|
|
|
def truncated_normal(key: jnp.ndarray,
                     lower: Union[float, jnp.ndarray],
                     upper: Union[float, jnp.ndarray],
                     shape: Optional[Sequence[int]] = None,
                     dtype: np.dtype = dtypes.float_) -> jnp.ndarray:
  """Draw samples from a standard normal truncated to ``[lower, upper]``.

  Args:
    key: a PRNGKey used as the random key.
    lower: a float or array of floats giving the lower truncation bound. Must
      be broadcast-compatible with ``upper``.
    upper: a float or array of floats giving the upper truncation bound. Must
      be broadcast-compatible with ``lower``.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``lower`` and ``upper``. With
      the default (None) the result shape is obtained by broadcasting
      ``lower`` and ``upper``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and shape given by ``shape`` if
    ``shape`` is not None, or else by broadcasting ``lower`` and ``upper``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  dtype_is_float = dtypes.issubdtype(dtype, np.floating)
  if not dtype_is_float:
    message = (f"dtype argument to `truncated_normal` must be a float "
               f"dtype, got {dtype}")
    raise ValueError(message)
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  # None is forwarded untouched so the implementation can infer the shape.
  result_shape = (None if shape is None
                  else abstract_arrays.canonicalize_shape(shape))
  return _truncated_normal(key, lower, upper, result_shape, canonical_dtype)  # type: ignore
|
2019-08-16 17:02:20 -07:00
|
|
|
|
|
|
|
@partial(jit, static_argnums=(3, 4))
def _truncated_normal(key, lower, upper, shape, dtype) -> jnp.ndarray:
  """Jitted implementation of ``truncated_normal``.

  ``shape`` and ``dtype`` are static arguments. The bounds are mapped through
  ``erf``, a uniform value is interpolated between the two images, and the
  result is mapped back with ``erf_inv``.
  """
  lower_shape = np.shape(lower)
  upper_shape = np.shape(upper)
  if shape is not None:
    _check_shape("truncated_normal", shape, lower_shape, upper_shape)
  else:
    shape = lax.broadcast_shapes(lower_shape, upper_shape)

  root_two = np.array(np.sqrt(2), dtype)
  erf_lower = lax.erf(lax.convert_element_type(lower, dtype) / root_two)
  erf_upper = lax.erf(lax.convert_element_type(upper, dtype) / root_two)
  dtype_is_float = jnp.issubdtype(dtype, np.floating)
  if not dtype_is_float:
    raise TypeError("truncated_normal only accepts floating point dtypes.")
  # minval is the smallest positive normal number of the dtype, so the
  # uniform draw is never exactly zero.
  smallest_positive = jnp.finfo(dtype).tiny
  fraction = uniform(key, shape, dtype, minval=smallest_positive)
  return root_two * lax.erf_inv(erf_lower + fraction * (erf_upper - erf_lower))
|
|
|
|
|
|
|
|
|
2020-05-21 18:12:18 -03:00
|
|
|
def bernoulli(key: jnp.ndarray,
              p: jnp.ndarray = np.float32(0.5),
              shape: Optional[Sequence[int]] = None) -> jnp.ndarray:
  """Draw Bernoulli samples with the given mean.

  Args:
    key: a PRNGKey used as the random key.
    p: optional, a float or array of floats giving the mean of the random
      variables. Must be broadcast-compatible with ``shape``. Default 0.5.
    shape: optional, a tuple of nonnegative integers giving the result shape.
      Must be broadcast-compatible with ``p.shape``. With the default (None)
      the result shape equals ``p.shape``.

  Returns:
    A random array with boolean dtype and shape given by ``shape`` if ``shape``
    is not None, or else ``p.shape``.

  Raises:
    TypeError: if ``p`` does not have a floating-point dtype.
  """
  p_dtype = dtypes.canonicalize_dtype(lax.dtype(p))
  result_shape = (None if shape is None
                  else abstract_arrays.canonicalize_shape(shape))
  p_is_float = jnp.issubdtype(p_dtype, np.floating)
  if not p_is_float:
    template = "bernoulli probability `p` must have a floating dtype, got {}."
    raise TypeError(template.format(p_dtype))
  probabilities = lax.convert_element_type(p, p_dtype)
  return _bernoulli(key, probabilities, result_shape)  # type: ignore
|
2019-04-10 22:09:14 -07:00
|
|
|
|
|
|
|
@partial(jit, static_argnums=(2,))
def _bernoulli(key, p, shape) -> jnp.ndarray:
  """Jitted implementation of ``bernoulli``; ``shape`` is static."""
  p_shape = np.shape(p)
  if shape is not None:
    _check_shape("bernoulli", shape, p_shape)
    result_shape = shape
  else:
    result_shape = p_shape

  # A uniform draw is below p with probability p.
  draws = uniform(key, result_shape, lax.dtype(p))
  return draws < p
|
2019-03-28 17:59:42 -04:00
|
|
|
|
|
|
|
|
2020-05-21 18:12:18 -03:00
|
|
|
def beta(key: jnp.ndarray,
         a: Union[float, jnp.ndarray],
         b: Union[float, jnp.ndarray],
         shape: Optional[Sequence[int]] = None,
         dtype: np.dtype = dtypes.float_) -> jnp.ndarray:
  """Draw samples from the Beta distribution.

  Args:
    key: a PRNGKey used as the random key.
    a: a float or array of floats broadcast-compatible with ``shape``,
      the first parameter "alpha".
    b: a float or array of floats broadcast-compatible with ``shape``,
      the second parameter "beta".
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``a`` and ``b``. With the
      default (None) the result shape is obtained by broadcasting ``a`` and
      ``b``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and shape given by ``shape`` if
    ``shape`` is not None, or else by broadcasting ``a`` and ``b``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  dtype_is_float = dtypes.issubdtype(dtype, np.floating)
  if not dtype_is_float:
    message = (f"dtype argument to `beta` must be a float "
               f"dtype, got {dtype}")
    raise ValueError(message)
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  # None is forwarded untouched so the implementation can infer the shape.
  result_shape = (None if shape is None
                  else abstract_arrays.canonicalize_shape(shape))
  return _beta(key, a, b, result_shape, canonical_dtype)
|
|
|
|
|
|
|
|
def _beta(key, a, b, shape, dtype):
  """Implementation of ``beta`` as a ratio of two gamma samples.

  Returns ``g_a / (g_a + g_b)`` where ``g_a`` and ``g_b`` are drawn with
  ``gamma`` using parameters ``a`` and ``b`` and two keys split from ``key``.
  """
  a_shape = np.shape(a)
  b_shape = np.shape(b)
  if shape is not None:
    _check_shape("beta", shape, a_shape, b_shape)
  else:
    shape = lax.broadcast_shapes(a_shape, b_shape)

  alpha = lax.convert_element_type(a, dtype)
  beta_param = lax.convert_element_type(b, dtype)
  alpha_key, beta_key = split(key)
  alpha = jnp.broadcast_to(alpha, shape)
  beta_param = jnp.broadcast_to(beta_param, shape)
  alpha_draws = gamma(alpha_key, alpha, shape, dtype)
  beta_draws = gamma(beta_key, beta_param, shape, dtype)
  total = alpha_draws + beta_draws
  return alpha_draws / total
|
2019-03-28 17:59:42 -04:00
|
|
|
|
|
|
|
|
2020-07-26 08:58:37 -07:00
|
|
|
def cauchy(key, shape=(), dtype=dtypes.float_):
  """Draw samples from the Cauchy distribution.

  Args:
    key: a PRNGKey used as the random key.
    shape: optional, a tuple of nonnegative integers giving the shape of the
      result. Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  dtype_is_float = dtypes.issubdtype(dtype, np.floating)
  if not dtype_is_float:
    message = (f"dtype argument to `cauchy` must be a float "
               f"dtype, got {dtype}")
    raise ValueError(message)
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  canonical_shape = abstract_arrays.canonicalize_shape(shape)
  return _cauchy(key, canonical_shape, canonical_dtype)
|
|
|
|
|
|
|
|
@partial(jit, static_argnums=(1, 2))
def _cauchy(key, shape, dtype):
  """Jitted implementation of ``cauchy``; ``shape`` and ``dtype`` are static."""
  _check_shape("cauchy", shape)
  # Sample u uniformly with minval eps and maxval 1, then apply
  # tan(pi * (u - 0.5)).
  machine_eps = jnp.finfo(dtype).eps
  u = uniform(key, shape, dtype, minval=machine_eps, maxval=1.)
  pi = _constant_like(u, np.pi)
  half = _constant_like(u, 0.5)
  centered = lax.sub(u, half)
  angle = lax.mul(pi, centered)
  return lax.tan(angle)
|
|
|
|
|
|
|
|
|
2020-07-26 08:58:37 -07:00
|
|
|
def dirichlet(key, alpha, shape=None, dtype=dtypes.float_):
  """Draw samples from the Dirichlet distribution.

  Args:
    key: a PRNGKey used as the random key.
    alpha: an array of shape ``(..., n)``, the concentration parameter of the
      random variables.
    shape: optional, a tuple of nonnegative integers specifying the result
      batch shape, i.e. the result shape without its last element of value
      ``n``. Must be broadcast-compatible with ``alpha.shape[:-1]``. With the
      default (None) the result shape equals ``alpha.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and shape given by
    ``shape + (alpha.shape[-1],)`` if ``shape`` is not None, or else
    ``alpha.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  dtype_is_float = dtypes.issubdtype(dtype, np.floating)
  if not dtype_is_float:
    message = (f"dtype argument to `dirichlet` must be a float "
               f"dtype, got {dtype}")
    raise ValueError(message)
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  # None is forwarded untouched so the implementation can infer the shape.
  batch_shape = (None if shape is None
                 else abstract_arrays.canonicalize_shape(shape))
  return _dirichlet(key, alpha, batch_shape, canonical_dtype)
|
|
|
|
|
|
|
|
@partial(jit, static_argnums=(2, 3))
def _dirichlet(key, alpha, shape, dtype):
  """Jitted implementation of ``dirichlet``; ``shape`` and ``dtype`` are static.

  Draws gamma samples with concentration ``alpha`` and normalizes them so
  that they sum to one along the last axis.
  """
  alpha_ndim = np.ndim(alpha)
  if alpha_ndim < 1:
    template = "dirichlet requires alpha.ndim >= 1, got alpha.ndim == {}"
    raise ValueError(template.format(alpha_ndim))

  alpha_batch_shape = np.shape(alpha)[:-1]
  if shape is not None:
    _check_shape("dirichlet", shape, alpha_batch_shape)
  else:
    shape = alpha_batch_shape

  alpha = lax.convert_element_type(alpha, dtype)
  # Full sample shape: requested batch shape plus the event dimension.
  sample_shape = shape + np.shape(alpha)[-1:]
  unnormalized = gamma(key, alpha, sample_shape, dtype)
  normalizer = jnp.sum(unnormalized, axis=-1, keepdims=True)
  return unnormalized / normalizer
|
2019-04-22 11:55:02 -04:00
|
|
|
|
|
|
|
|
2020-07-26 08:58:37 -07:00
|
|
|
def exponential(key, shape=(), dtype=dtypes.float_):
  """Draw samples from the Exponential distribution.

  Args:
    key: a PRNGKey used as the random key.
    shape: optional, a tuple of nonnegative integers giving the shape of the
      result. Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  dtype_is_float = dtypes.issubdtype(dtype, np.floating)
  if not dtype_is_float:
    message = (f"dtype argument to `exponential` must be a float "
               f"dtype, got {dtype}")
    raise ValueError(message)
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  canonical_shape = abstract_arrays.canonicalize_shape(shape)
  return _exponential(key, canonical_shape, canonical_dtype)
|
|
|
|
|
|
|
|
@partial(jit, static_argnums=(1, 2))
def _exponential(key, shape, dtype):
  """Jitted implementation of ``exponential``; ``shape`` and ``dtype`` are static."""
  _check_shape("exponential", shape)
  uniform_draws = uniform(key, shape, dtype)
  # log1p(-u) is log(1 - u); using 1 - u moves the domain of log to (0, 1]
  # instead of [0, 1).
  log_one_minus_u = lax.log1p(lax.neg(uniform_draws))
  return lax.neg(log_one_minus_u)
|
2019-03-28 23:57:00 -04:00
|
|
|
|
|
|
|
|
2019-03-30 18:07:34 -04:00
|
|
|
def _gamma_one(key, alpha):
  """Draw a single Gamma(alpha) sample by rejection sampling.

  The loop state and the draws use shape ``()``, and the constants are built
  from ``alpha`` with ``_constant_like``.
  """
  # Ref: A simple method for generating gamma variables, George Marsaglia and Wai Wan Tsang
  # The algorithm can also be found in:
  # https://en.wikipedia.org/wiki/Gamma_distribution#Generating_gamma-distributed_random_variables
  zero = _constant_like(alpha, 0)
  one = _constant_like(alpha, 1)
  minus_one = _constant_like(alpha, -1)
  one_over_two = _constant_like(alpha, 0.5)
  one_over_three = _constant_like(alpha, 1. / 3.)
  squeeze_const = _constant_like(alpha, 0.0331)
  dtype = lax.dtype(alpha)

  loop_key, boost_key = split(key)
  # for alpha < 1, we boost alpha to alpha + 1 and get a sample according to
  #   Gamma(alpha) ~ Gamma(alpha+1) * Uniform()^(1 / alpha)
  alpha_at_least_one = lax.ge(alpha, one)
  boost_draw = lax.pow(uniform(boost_key, (), dtype=dtype), lax.div(one, alpha))
  boost = lax.select(alpha_at_least_one, one, boost_draw)
  boosted_alpha = lax.select(lax.ge(alpha, one), alpha, lax.add(alpha, one))

  d = lax.sub(boosted_alpha, one_over_three)
  c = lax.div(one_over_three, lax.pow(d, one_over_two))

  def _keep_rejecting(state):
    # state = (key, x**2, v**3, u); returns True while the candidate has to
    # be rejected and a new one drawn.
    _, x_squared, v_cubed, u = state
    # TODO: use lax.cond when its batching rule is supported
    # The reason is to avoid evaluating second condition which involves log+log
    # if the first condition is satisfied
    squeeze_bound = lax.sub(one, lax.mul(squeeze_const,
                                         lax.mul(x_squared, x_squared)))
    fails_squeeze = lax.ge(u, squeeze_bound)
    log_bound = lax.add(lax.mul(x_squared, one_over_two),
                        lax.mul(d, lax.add(lax.sub(one, v_cubed),
                                           lax.log(v_cubed))))
    fails_log_test = lax.ge(lax.log(u), log_bound)
    return lax.bitwise_and(fails_squeeze, fails_log_test)

  def _propose(state):
    def _draw_normal(inner_state):
      # Draw x ~ normal and v = 1 + c * x, threading the key through.
      inner_key, inner_subkey = split(inner_state[0])
      x_new = normal(inner_subkey, (), dtype=dtype)
      v_new = lax.add(one, lax.mul(x_new, c))
      return inner_key, x_new, v_new

    def _v_not_positive(inner_state):
      return lax.le(inner_state[2], zero)

    next_key, normal_key, uniform_key = split(state[0], 3)
    # The initial v of -1 guarantees that at least one candidate is drawn;
    # the loop repeats until v > 0.
    _, x, v = lax.while_loop(_v_not_positive, _draw_normal,
                             (normal_key, zero, minus_one))
    x_squared = lax.mul(x, x)
    v_cubed = lax.mul(lax.mul(v, v), v)
    u = uniform(uniform_key, (), dtype=dtype)
    return next_key, x_squared, v_cubed, u

  # initial state is chosen such that _keep_rejecting will return True
  initial_state = (loop_key, zero, one, _constant_like(boosted_alpha, 2))
  _, _, accepted_v_cubed, _ = lax.while_loop(_keep_rejecting, _propose,
                                             initial_state)
  z = lax.mul(lax.mul(d, accepted_v_cubed), boost)
  # Replace an exact zero by the dtype's smallest positive normal number.
  return lax.select(lax.eq(z, zero), jnp.finfo(z.dtype).tiny, z)
|
2019-03-30 18:07:34 -04:00
|
|
|
|
2019-06-20 18:41:52 -04:00
|
|
|
|
|
|
|
def _gamma_grad(sample, a):
|
2020-05-21 18:12:18 -03:00
|
|
|
samples = jnp.reshape(sample, -1)
|
|
|
|
alphas = jnp.reshape(a, -1)
|
2019-12-23 22:52:15 -05:00
|
|
|
if xla_bridge.get_backend().platform == 'cpu':
|
2020-06-19 06:34:18 -07:00
|
|
|
grads = lax.map(lambda args: lax.random_gamma_grad(*args), (alphas, samples))
|
2019-12-23 22:52:15 -05:00
|
|
|
else:
|
2020-06-19 06:34:18 -07:00
|
|
|
grads = vmap(lax.random_gamma_grad)(alphas, samples)
|
2020-05-21 18:12:18 -03:00
|
|
|
return grads.reshape(np.shape(a))
|
2019-06-20 18:41:52 -04:00
|
|
|
|
2020-09-03 14:18:35 +03:00
|
|
|
def _gamma_impl(key, a, use_vmap=False):
  """Draw one gamma sample per element of ``a``.

  Args:
    key: a PRNG key, or an array of keys whose last axis has size 2.
    a: array of gamma parameters; the result has the same shape.
    use_vmap: if True, apply ``_gamma_one`` over all elements with ``vmap``;
      otherwise use ``lax.map``.

  Returns:
    An array of samples with the shape of ``a``.
  """
  result_shape = jnp.shape(a)
  # split key to match the shape of a: every incoming key is split into one
  # subkey per element of the dimensions of `a` not covered by key batch dims.
  num_key_batch_dims = jnp.ndim(key) - 1
  subkeys_per_key = prod(result_shape[num_key_batch_dims:])
  split_each_key = vmap(split, in_axes=(0, None))
  stacked_keys = split_each_key(jnp.reshape(key, (-1, 2)), subkeys_per_key)
  flat_keys = jnp.reshape(stacked_keys, (-1, 2))
  flat_alphas = jnp.reshape(a, -1)
  if not use_vmap:
    flat_samples = lax.map(lambda pair: _gamma_one(*pair),
                           (flat_keys, flat_alphas))
  else:
    flat_samples = vmap(_gamma_one)(flat_keys, flat_alphas)

  return jnp.reshape(flat_samples, result_shape)
|
2019-12-01 09:44:45 -05:00
|
|
|
|
|
|
|
def _gamma_batching_rule(batched_args, batch_dims):
  """Batching rule for ``random_gamma_p``.

  Moves the batch dimension of both the key and the parameter to the front
  and binds the primitive again; the output batch dimension is 0.
  """
  key_arg, alpha_arg = batched_args
  key_bdim, alpha_bdim = batch_dims
  # Batch size comes from the first argument that actually is batched.
  batch_size = next(arg.shape[bdim]
                    for arg, bdim in zip(batched_args, batch_dims)
                    if bdim is not None)
  key_front = batching.bdim_at_front(key_arg, key_bdim, batch_size)
  alpha_front = batching.bdim_at_front(alpha_arg, alpha_bdim, batch_size)
  batched_out = random_gamma_p.bind(key_front, alpha_front)
  return batched_out, 0
|
2019-12-01 09:44:45 -05:00
|
|
|
|
|
|
|
# Primitive used by `_gamma` to draw gamma samples; its operands are
# (key, a).
random_gamma_p = core.Primitive('random_gamma')
# Implementation used when the primitive is evaluated directly
# (use_vmap defaults to False).
random_gamma_p.def_impl(_gamma_impl)
# The abstract output is derived from `a` alone; `key` does not influence it.
random_gamma_p.def_abstract_eval(lambda key, a: abstract_arrays.raise_to_shaped(a))
# JVP rules per operand: None for `key` (presumably meaning no tangent rule
# for the key -- confirm against ad.defjvp2), and for `a` the incoming
# tangent scaled by the gradient of the sample with respect to `a`.
ad.defjvp2(random_gamma_p, None, lambda tangent, ans, key, a: tangent * _gamma_grad(ans, a))
# Default XLA lowering: the vmap-based implementation.
xla.translations[random_gamma_p] = xla.lower_fun(
    partial(_gamma_impl, use_vmap=True),
    multiple_results=False)
# CPU-specific XLA lowering: the lax.map-based implementation.
xla.backend_specific_translations['cpu'][random_gamma_p] = xla.lower_fun(
    partial(_gamma_impl, use_vmap=False),
    multiple_results=False)
batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule
|
2019-06-20 18:41:52 -04:00
|
|
|
|
2020-07-26 08:58:37 -07:00
|
|
|
def gamma(key, a, shape=None, dtype=dtypes.float_):
  """Draw samples from the Gamma distribution.

  Args:
    key: a PRNGKey used as the random key.
    a: a float or array of floats broadcast-compatible with ``shape``,
      the parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``a``. With the default (None)
      the result shape equals ``a.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape`` if
    ``shape`` is not None, or else by ``a.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  dtype_is_float = dtypes.issubdtype(dtype, np.floating)
  if not dtype_is_float:
    message = (f"dtype argument to `gamma` must be a float "
               f"dtype, got {dtype}")
    raise ValueError(message)
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  # None is forwarded untouched so the implementation can infer the shape.
  result_shape = (None if shape is None
                  else abstract_arrays.canonicalize_shape(shape))
  return _gamma(key, a, result_shape, canonical_dtype)
|
2019-04-10 13:00:31 -07:00
|
|
|
|
2019-04-12 07:15:41 -07:00
|
|
|
@partial(jit, static_argnums=(2, 3))
def _gamma(key, a, shape, dtype):
  """Jitted implementation of ``gamma``; ``shape`` and ``dtype`` are static."""
  param_shape = np.shape(a)
  if shape is not None:
    _check_shape("gamma", shape, param_shape)
  else:
    shape = param_shape

  alpha = lax.convert_element_type(a, dtype)
  # The primitive returns one sample per parameter element, so the parameter
  # is broadcast to the requested shape first.
  needs_broadcast = np.shape(alpha) != shape
  if needs_broadcast:
    alpha = jnp.broadcast_to(alpha, shape)
  return random_gamma_p.bind(key, alpha)
|
2019-03-30 18:07:34 -04:00
|
|
|
|
|
|
|
|
2020-05-02 08:24:59 -07:00
|
|
|
@partial(jit, static_argnums=(2, 3, 4))
def _poisson_knuth(key, lam, shape, dtype, max_iters):
  """Poisson sampler based on Knuth's algorithm.

  ``shape``, ``dtype`` and ``max_iters`` are static arguments. The loop stops
  once no element is still accumulating or after ``max_iters`` iterations.
  """
  # Knuth's algorithm for generating Poisson random variates.
  # Reference:
  # https://en.wikipedia.org/wiki/Poisson_distribution#Generating_Poisson-distributed_random_variables

  def _still_running(state):
    iteration = state[0]
    log_product = state[3]
    any_active = (log_product > -lam).any()
    return any_active & (iteration < max_iters)

  def _step(state):
    iteration, counts, rng, log_product = state
    rng, draw_key = split(rng)
    # Only elements whose running log-product is still above -lam count
    # this iteration.
    active = log_product > -lam
    counts = lax.select(active, counts + 1, counts)
    draws = uniform(draw_key, shape, np.float32)
    return iteration + 1, counts, rng, log_product + jnp.log(draws)

  initial_counts = lax.full_like(lam, 0, dtype, shape)
  initial_log_product = lax.full_like(lam, 0, np.float32, shape)
  initial_state = (0, initial_counts, key, initial_log_product)
  _, final_counts, _, _ = lax.while_loop(_still_running, _step, initial_state)
  return (final_counts - 1).astype(dtype)
|
|
|
|
|
|
|
|
|
|
|
|
@partial(jit, static_argnums=(2, 3, 4))
def _poisson_rejection(key, lam, shape, dtype, max_iters):
  """Poisson sampler based on transformed rejection.

  ``shape``, ``dtype`` and ``max_iters`` are static arguments. Elements that
  were never accepted within ``max_iters`` iterations keep the initial value
  of -1 (before the final cast to ``dtype``).
  """
  # Transformed rejection due to Hormann.
  # Reference:
  # http://citeseer.ist.psu.edu/viewdoc/citations;jsessionid=1BEB35946CC807879F55D42512E5490C?doi=10.1.1.48.3054.
  log_lam = lax.log(lam)
  b = 0.931 + 2.53 * lax.sqrt(lam)
  a = -0.059 + 0.02483 * b
  inv_alpha = 1.1239 + 1.1328 / (b - 3.4)
  v_r = 0.9277 - 3.6224 / (b - 2)

  def _not_done(state):
    iteration, _, is_accepted, _ = state
    any_pending = (~is_accepted).any()
    return any_pending & (iteration < max_iters)

  def _attempt(state):
    iteration, result, is_accepted, rng = state
    rng, u_key, v_key = split(rng, 3)

    u = uniform(u_key, shape, lam.dtype) - 0.5
    v = uniform(v_key, shape, lam.dtype)
    u_shifted = 0.5 - abs(u)

    candidate = lax.floor((2 * a / u_shifted + b) * u + lam + 0.43)
    s = lax.log(v * inv_alpha / (a / (u_shifted * u_shifted) + b))
    t = -lam + candidate * log_lam - lax.lgamma(candidate + 1)

    # Cheap acceptance test first, then hard rejection, then the full test.
    quick_accept = (u_shifted >= 0.07) & (v <= v_r)
    hard_reject = (candidate < 0) | ((u_shifted < 0.013) & (v > u_shifted))
    full_accept = s <= t
    accept_now = quick_accept | (~hard_reject & full_accept)

    result = lax.select(accept_now, candidate, result)
    is_accepted = is_accepted | accept_now

    return iteration + 1, result, is_accepted, rng

  initial_result = lax.full_like(lam, -1, lam.dtype, shape)
  initial_accepted = lax.full_like(lam, False, jnp.bool_, shape)
  initial_state = (0, initial_result, initial_accepted, key)
  _, final_result, _, _ = lax.while_loop(_not_done, _attempt, initial_state)
  return final_result.astype(dtype)
|
|
|
|
|
|
|
|
|
|
|
|
@partial(jit, static_argnums=(2, 3))
def _poisson(key, lam, shape, dtype):
  """Jit-compiled Poisson sampler choosing an algorithm per element.

  Elements with ``lam < 10`` take their value from ``_poisson_knuth``; all
  others take it from ``_poisson_rejection``. Both samplers are evaluated over
  the full array with the same ``key``, and each is fed a placeholder rate for
  the elements it is not responsible for.

  The implementation matches TensorFlow and NumPy:
  https://github.com/tensorflow/tensorflow/blob/v2.2.0-rc3/tensorflow/core/kernels/random_poisson_op.cc
  https://github.com/numpy/numpy/blob/v1.18.3/numpy/random/src/distributions/distributions.c#L574
  """
  small_rate = lam < 10

  # Knuth sees a rate of 0 wherever rejection sampling is in charge.
  knuth_rate = lax.select(small_rate, lam, lax.full_like(lam, 0.0))

  # The acceptance probability for rejection sampling maxes out at 89% as
  # λ -> ∞, so pick some arbitrary large value.
  rejection_rate = lax.select(small_rate, lax.full_like(lam, 1e5), lam)

  iter_cap = dtype.type(jnp.iinfo(dtype).max)  # insanely conservative

  knuth_draws = _poisson_knuth(key, knuth_rate, shape, dtype, iter_cap)
  rejection_draws = _poisson_rejection(
      key, rejection_rate, shape, dtype, iter_cap)
  return lax.select(small_rate, knuth_draws, rejection_draws)
|
|
|
|
|
|
|
|
|
2020-07-26 08:58:37 -07:00
|
|
|
def poisson(key, lam, shape=(), dtype=dtypes.int_):
  """Sample Poisson random values with given shape and integer dtype.

  Args:
    key: a PRNGKey used as the random key.
    lam: rate parameter (mean of the distribution), must be >= 0. It is
      broadcast to ``shape`` and converted to float32 before sampling.
    shape: optional, a tuple of nonnegative integers representing the result
      shape. Default ().
    dtype: optional, an integer dtype for the returned values (default int64
      if jax_enable_x64 is true, otherwise int32).

  Returns:
    A random array with the specified shape and dtype.
  """
  out_dtype = dtypes.canonicalize_dtype(dtype)
  out_shape = abstract_arrays.canonicalize_shape(shape)
  rate = lam if np.shape(lam) == out_shape else jnp.broadcast_to(lam, out_shape)
  rate = lax.convert_element_type(rate, np.float32)
  return _poisson(key, rate, out_shape, out_dtype)
|
|
|
|
|
|
|
|
|
2020-07-26 08:58:37 -07:00
|
|
|
def gumbel(key, shape=(), dtype=dtypes.float_):
  """Sample Gumbel random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    shape: optional, a tuple of nonnegative integers representing the result
      shape. Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  is_float = dtypes.issubdtype(dtype, np.floating)
  if not is_float:
    message = (f"dtype argument to `gumbel` must be a float "
               f"dtype, got {dtype}")
    raise ValueError(message)
  out_dtype = dtypes.canonicalize_dtype(dtype)
  out_shape = abstract_arrays.canonicalize_shape(shape)
  return _gumbel(key, out_shape, out_dtype)
|
|
|
|
|
|
|
|
@partial(jit, static_argnums=(1, 2))
def _gumbel(key, shape, dtype):
  """Jit-compiled Gumbel sampler: ``-log(-log(u))`` for uniform ``u``.

  ``u`` is drawn with ``minval`` set to the dtype's machine epsilon and
  ``maxval`` 1, so the inner ``log(u)`` never sees zero.
  """
  _check_shape("gumbel", shape)
  tiny = jnp.finfo(dtype).eps
  u = uniform(key, shape, dtype, minval=tiny, maxval=1.)
  log_u = jnp.log(u)
  return -jnp.log(-log_u)
|
2019-04-21 16:25:20 -04:00
|
|
|
|
2020-05-02 08:24:59 -07:00
|
|
|
|
2020-01-10 13:28:03 +00:00
|
|
|
def categorical(key, logits, axis=-1, shape=None):
  """Sample random values from categorical distributions.

  Sampling adds Gumbel noise to the logits and takes the argmax along ``axis``.

  Args:
    key: a PRNGKey used as the random key.
    logits: Unnormalized log probabilities of the categorical distribution(s)
      to sample from, so that `softmax(logits, axis)` gives the corresponding
      probabilities.
    axis: Axis along which logits belong to the same categorical distribution.
    shape: Optional, a tuple (or other sequence) of nonnegative integers
      representing the result shape. Must be broadcast-compatible with
      ``np.delete(logits.shape, axis)``. The default (None) produces a result
      shape equal to ``np.delete(logits.shape, axis)``.

  Returns:
    A random array with int dtype and shape given by ``shape`` if ``shape``
    is not None, or else ``np.delete(logits.shape, axis)``.
  """
  # Work with a negative axis so that it stays valid after sample dimensions
  # are prepended to the noise shape.
  if axis >= 0:
    axis -= len(logits.shape)

  batch_shape = tuple(int(d) for d in np.delete(logits.shape, axis))
  if shape is None:
    shape = batch_shape
  else:
    # Accept lists and other integer sequences like the other samplers do;
    # the slicing and tuple concatenation below require a real tuple.
    shape = abstract_arrays.canonicalize_shape(shape)
    _check_shape("categorical", shape, batch_shape)

  n_batch = len(batch_shape)
  sample_shape = shape[:len(shape) - n_batch]
  # Take the batch dimensions from the requested shape rather than from the
  # logits, so that a size-1 batch dimension broadcast by `shape` gets
  # independent noise and the result really has shape `shape`. When no
  # broadcasting is requested this equals `logits.shape`.
  noise_shape = list(shape[len(shape) - n_batch:])
  noise_shape.insert(axis % len(logits.shape), logits.shape[axis])
  noise = gumbel(key, sample_shape + tuple(noise_shape), logits.dtype)
  return jnp.argmax(noise + logits, axis=axis)
|
2019-04-21 16:25:20 -04:00
|
|
|
|
2020-05-02 08:24:59 -07:00
|
|
|
|
2020-07-26 08:58:37 -07:00
|
|
|
def laplace(key, shape=(), dtype=dtypes.float_):
  """Sample Laplace random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    shape: optional, a tuple of nonnegative integers representing the result
      shape. Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  is_float = dtypes.issubdtype(dtype, np.floating)
  if not is_float:
    message = (f"dtype argument to `laplace` must be a float "
               f"dtype, got {dtype}")
    raise ValueError(message)
  out_dtype = dtypes.canonicalize_dtype(dtype)
  out_shape = abstract_arrays.canonicalize_shape(shape)
  return _laplace(key, out_shape, out_dtype)
|
|
|
|
|
|
|
|
@partial(jit, static_argnums=(1, 2))
def _laplace(key, shape, dtype):
  """Jit-compiled Laplace sampler: ``sign(u) * log1p(-|u|)`` for uniform ``u``.

  ``u`` is drawn with ``minval = -1 + epsneg`` (``epsneg`` taken from the
  dtype's ``finfo``) and ``maxval = 1``.
  """
  _check_shape("laplace", shape)
  lower = -1. + jnp.finfo(dtype).epsneg
  u = uniform(key, shape, dtype, minval=lower, maxval=1.)
  sign = lax.sign(u)
  magnitude = lax.log1p(lax.neg(lax.abs(u)))
  return lax.mul(sign, magnitude)
|
2019-03-30 16:34:20 -04:00
|
|
|
|
|
|
|
|
2020-07-26 08:58:37 -07:00
|
|
|
def logistic(key, shape=(), dtype=dtypes.float_):
  """Sample logistic random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    shape: optional, a tuple of nonnegative integers representing the result
      shape. Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  is_float = dtypes.issubdtype(dtype, np.floating)
  if not is_float:
    message = (f"dtype argument to `logistic` must be a float "
               f"dtype, got {dtype}")
    raise ValueError(message)
  out_dtype = dtypes.canonicalize_dtype(dtype)
  out_shape = abstract_arrays.canonicalize_shape(shape)
  return _logistic(key, out_shape, out_dtype)
|
|
|
|
|
|
|
|
@partial(jit, static_argnums=(1, 2))
def _logistic(key, shape, dtype):
  """Jit-compiled logistic sampler: ``log((eps + x) / (1 - x))`` for uniform x."""
  # Mathematically, the distribution can be sampled by drawing uniformly
  # distributed numbers x in the open interval (a, b) and computing:
  #     z = log[(x - a) / (b - x)]
  # It's important to avoid x=a or x=b, which lead to infinite values for z.
  # The uniform() function generates pseudorandom floating point numbers x in
  # the semi-closed interval [0, 1), so if used directly with (a, b) = (0, 1),
  # it will lead to infinite output in a small number of cases (as many as
  # 1 in 2^23 for float32).
  #
  # Instead, we let (a, b) = (-ε, 1) where ε is the smallest step between
  # floating point values: then numbers in the interval (-ε, 1) are
  # approximated by standard uniformly drawn numbers in [0, 1).
  _check_shape("logistic", shape)
  u = uniform(key, shape, dtype)
  step = jnp.finfo(dtype).eps
  numerator = lax.add(lax._const(u, step), u)
  denominator = lax.sub(lax._const(u, 1), u)
  return lax.log(lax.div(numerator, denominator))
|
2019-08-06 12:19:05 +01:00
|
|
|
|
|
|
|
|
2020-07-26 08:58:37 -07:00
|
|
|
def pareto(key, b, shape=None, dtype=dtypes.float_):
  """Sample Pareto random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    b: a float or array of floats broadcast-compatible with ``shape``
      representing the parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``b``. The default (None)
      produces a result shape equal to ``b.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape``
    if ``shape`` is not None, or else by ``b.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  is_float = dtypes.issubdtype(dtype, np.floating)
  if not is_float:
    message = (f"dtype argument to `pareto` must be a float "
               f"dtype, got {dtype}")
    raise ValueError(message)
  out_dtype = dtypes.canonicalize_dtype(dtype)
  # None is forwarded untouched; the jitted helper resolves it from ``b``.
  out_shape = (
      None if shape is None else abstract_arrays.canonicalize_shape(shape))
  return _pareto(key, b, out_shape, out_dtype)
|
|
|
|
|
|
|
|
@partial(jit, static_argnums=(2, 3))
def _pareto(key, b, shape, dtype):
  """Jit-compiled Pareto sampler: ``exp(e / b)`` for an exponential draw ``e``.

  ``shape`` defaults to the shape of ``b`` when None; ``b`` is cast to
  ``dtype`` before the division.
  """
  if shape is not None:
    _check_shape("pareto", shape)
  else:
    shape = np.shape(b)

  b_cast = lax.convert_element_type(b, dtype)
  exp_draws = exponential(key, shape, dtype)
  return lax.exp(exp_draws / b_cast)
|
2019-04-02 10:55:03 +01:00
|
|
|
|
|
|
|
|
2020-07-26 08:58:37 -07:00
|
|
|
def t(key, df, shape=(), dtype=dtypes.float_):
  """Sample Student's t random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    df: a float or array of floats broadcast-compatible with ``shape``
      representing the parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``df``. Default (). Passing
      None produces a result shape equal to ``df.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape``
    if ``shape`` is not None, or else by ``df.shape``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `t` must be a float "
                     f"dtype, got {dtype}")
  dtype = dtypes.canonicalize_dtype(dtype)
  # Forward None untouched (as `pareto` does) so that `_t` can fall back to
  # the shape of `df`; only an explicit shape is canonicalized.
  if shape is not None:
    shape = abstract_arrays.canonicalize_shape(shape)
  return _t(key, df, shape, dtype)
|
2019-04-10 22:09:14 -07:00
|
|
|
|
2019-04-21 16:25:20 -04:00
|
|
|
@partial(jit, static_argnums=(2, 3))
def _t(key, df, shape, dtype):
  """Jit-compiled Student's t sampler: ``n * sqrt((df / 2) / g)``.

  ``n`` is a normal draw and ``g`` a gamma draw with parameter ``df / 2``.
  The two draws use the two different halves of a split of ``key``.

  Args:
    key: random key, split once into a key for ``n`` and a key for ``g``.
    df: parameter value(s); cast to ``dtype``.
    shape: static result shape, or None to use the shape of ``df``.
    dtype: static float dtype of the draws.

  Returns:
    The sampled array.
  """
  if shape is None:
    shape = np.shape(df)
  else:
    _check_shape("t", shape, np.shape(df))

  df = lax.convert_element_type(df, dtype)
  key_n, key_g = split(key)
  n = normal(key_n, shape, dtype)
  two = _constant_like(n, 2)
  half_df = lax.div(df, two)
  # The gamma draw must use its own key: reusing key_n would derive both
  # variates from the same random bits, so they would not be independent.
  g = gamma(key_g, half_df, shape, dtype)
  return n * jnp.sqrt(half_df / g)
|
2020-08-20 16:46:55 +02:00
|
|
|
|
|
|
|
|
|
|
|
def rademacher(key, shape, dtype=dtypes.int_):
  """Sample from a Rademacher distribution.

  Args:
    key: a PRNGKey key.
    shape: The shape of the returned samples.
    dtype: The type used for samples.

  Returns:
    A jnp.array of samples, of shape `shape`. Each element in the output has
    a 50% chance of being 1 or -1.
  """
  out_dtype = dtypes.canonicalize_dtype(dtype)
  out_shape = abstract_arrays.canonicalize_shape(shape)
  return _rademacher(key, out_shape, out_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
@partial(jit, static_argnums=(1, 2))
def _rademacher(key, shape, dtype):
  """Jit-compiled Rademacher sampler: maps Bernoulli(0.5) draws b to 2*b - 1."""
  coin_flips = bernoulli(key=key, p=0.5, shape=shape)
  signs = 2 * coin_flips - 1
  return signs.astype(dtype)
|
|
|
|
|
|
|
|
|
|
|
|
def maxwell(key, shape=(), dtype=dtypes.float_):
  """Sample from a one sided Maxwell distribution.

  The scipy counterpart is `scipy.stats.maxwell`.

  Samples are generated as sqrt(X^2 + Y^2 + Z^2) with X, Y, Z ~ N(0, 1).

  Args:
    key: a PRNGKey key.
    shape: The shape of the returned samples.
    dtype: The type used for samples.

  Returns:
    A jnp.array of samples, of shape `shape`.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  is_float = dtypes.issubdtype(dtype, np.floating)
  if not is_float:
    message = (f"dtype argument to `maxwell` must be a float "
               f"dtype, got {dtype}")
    raise ValueError(message)
  out_dtype = dtypes.canonicalize_dtype(dtype)
  out_shape = abstract_arrays.canonicalize_shape(shape)
  return _maxwell(key, out_shape, out_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
@partial(jit, static_argnums=(1, 2))
def _maxwell(key, shape, dtype):
  """Jit-compiled Maxwell sampler: norm of three normal draws per sample."""
  # A trailing axis of length 3 holds the three normal components; it is
  # reduced away by the norm.
  xyz = normal(key=key, shape=shape + (3,), dtype=dtype)
  return jnp.linalg.norm(xyz, axis=-1)
|
|
|
|
|
|
|
|
|
|
|
|
def double_sided_maxwell(key, loc, scale, shape=(), dtype=dtypes.float_):
  """Sample from a double sided Maxwell distribution.

  Samples using:
     loc + scale * sgn(U - 0.5) * one_sided_maxwell, with U ~ Unif.

  Args:
    key: a PRNGKey key.
    loc: The location parameter of the distribution.
    scale: The scale parameter of the distribution.
    shape: The shape added to the broadcast shape of the parameters ``loc``
      and ``scale``.
    dtype: The type used for samples.

  Returns:
    A jnp.array of samples.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  is_float = dtypes.issubdtype(dtype, np.floating)
  if not is_float:
    message = (f"dtype argument to `double_sided_maxwell` must be a float"
               f" dtype, got {dtype}")
    raise ValueError(message)
  out_dtype = dtypes.canonicalize_dtype(dtype)
  out_shape = abstract_arrays.canonicalize_shape(shape)
  return _double_sided_maxwell(key, loc, scale, out_shape, out_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
# Only `shape` and `dtype` are static. `loc` and `scale` are traced so that
# array-valued parameters (which are not hashable) are accepted and a new
# parameter value does not trigger a recompilation.
@partial(jit, static_argnums=(3, 4))
def _double_sided_maxwell(key, loc, scale, shape, dtype):
  """Jit-compiled double sided Maxwell sampler.

  Computes ``sign * maxwell * scale + loc`` where ``maxwell`` is a one sided
  Maxwell draw and ``sign`` an independent Rademacher draw.

  Args:
    key: random key, split into one key per draw.
    loc: location parameter (scalar or array).
    scale: scale parameter (scalar or array).
    shape: static shape prepended to the broadcast shape of ``loc`` and
      ``scale``.
    dtype: static float dtype of the draws.

  Returns:
    The sampled array.
  """
  params_shapes = lax.broadcast_shapes(np.shape(loc), np.shape(scale))
  # NOTE(review): with an empty `shape` the draws end up with shape
  # params_shapes + params_shapes — confirm that this is intended.
  if not shape:
    shape = params_shapes

  shape = shape + params_shapes
  maxwell_key, rademacher_key = split(key)
  maxwell_rvs = maxwell(maxwell_key, shape=shape, dtype=dtype)
  # Generate random signs for the symmetric variates.
  random_sign = rademacher(rademacher_key, shape=shape, dtype=dtype)
  assert random_sign.shape == maxwell_rvs.shape

  return random_sign * maxwell_rvs * scale + loc
|
|
|
|
|
|
|
|
|
|
|
|
def weibull_min(key, scale, concentration, shape=(), dtype=dtypes.float_):
  """Sample from a Weibull distribution.

  The scipy counterpart is `scipy.stats.weibull_min`.

  Args:
    key: a PRNGKey key.
    scale: The scale parameter of the distribution.
    concentration: The concentration parameter of the distribution.
    shape: The shape of the underlying uniform draws, which are then combined
      with ``scale`` and ``concentration``.
    dtype: The type used for samples.

  Returns:
    A jnp.array of samples.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  is_float = dtypes.issubdtype(dtype, np.floating)
  if not is_float:
    message = (f"dtype argument to `weibull_min` must be a float "
               f"dtype, got {dtype}")
    raise ValueError(message)
  out_dtype = dtypes.canonicalize_dtype(dtype)
  out_shape = abstract_arrays.canonicalize_shape(shape)
  return _weibull_min(key, scale, concentration, out_shape, out_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
# Only `shape` and `dtype` are static. `scale` and `concentration` are traced
# so that array-valued parameters (which are not hashable) are accepted and a
# new parameter value does not trigger a recompilation.
@partial(jit, static_argnums=(3, 4))
def _weibull_min(key, scale, concentration, shape, dtype):
  """Jit-compiled Weibull sampler using the inverse CDF.

  Args:
    key: random key for the uniform draw.
    scale: scale parameter (scalar or array).
    concentration: concentration parameter (scalar or array).
    shape: static shape of the uniform draw.
    dtype: static float dtype of the uniform draw.

  Returns:
    ``(-log1p(-u)) ** (1 / concentration) * scale`` for a uniform draw ``u``.
  """
  random_uniform = uniform(
      key=key, shape=shape, minval=0, maxval=1, dtype=dtype)

  # Inverse weibull CDF.
  return jnp.power(-jnp.log1p(-random_uniform), 1.0/concentration) * scale
|