2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2019-02-13 19:42:47 -08:00
|
|
|
"""JAX pseudo-random number generators (PRNGs).
|
|
|
|
|
|
|
|
The JAX PRNG system is based on "Parallel random numbers: as easy as 1, 2, 3"
|
|
|
|
(Salmon et al. 2011). For details on the design and its motivation, see:
|
|
|
|
|
|
|
|
https://github.com/google/jax/blob/master/design_notes/prng.md
|
|
|
|
"""
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
2018-11-21 18:31:13 -08:00
|
|
|
from functools import partial
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
import numpy as onp
|
|
|
|
|
|
|
|
from . import lax
|
|
|
|
from . import numpy as np
|
|
|
|
from . import tree_util
|
2019-03-30 18:07:34 -04:00
|
|
|
from .api import jit, vmap
|
2019-04-22 11:55:02 -04:00
|
|
|
from .numpy.lax_numpy import _constant_like, asarray
|
2018-12-06 21:35:03 -05:00
|
|
|
from jax.lib import xla_bridge
|
2018-12-15 19:14:05 -08:00
|
|
|
from jax import core
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-30 21:42:55 -08:00
|
|
|
|
|
|
|
def PRNGKey(seed):
  """Build a PRNG key from an integer seed.

  Args:
    seed: a scalar 64- or 32-bit integer whose value becomes the key.

  Returns:
    An array of shape (2,) and dtype uint32 holding the key. A 64-bit seed is
    effectively bit-cast into a (high word, low word) pair of uint32 values; a
    32-bit seed behaves as if it were first zero-padded to 64 bits.

  Raises:
    TypeError: if `seed` is not a scalar.
  """
  if onp.shape(seed):
    raise TypeError("PRNGKey seed must be a scalar.")

  def _as_uint32_vector(word):
    # Cast one 32-bit half of the seed to uint32 as a length-1 vector so the
    # two halves can be concatenated.
    return lax.reshape(lax.convert_element_type(word, onp.uint32), [1])

  concrete_value = isinstance(seed, (int, onp.ndarray))
  if concrete_value:
    # Python ints and NumPy arrays may be 64-bit even when
    # jax_enable_x64=False; extract the high word with NumPy so the top
    # 32 bits are not dropped.
    high = _as_uint32_vector(
        onp.bitwise_and(onp.right_shift(seed, 32), 0xFFFFFFFF))
  else:
    high = _as_uint32_vector(lax.shift_right_logical(seed, 32))
  low = _as_uint32_vector(lax.bitwise_and(seed, 0xFFFFFFFF))
  return lax.concatenate([high, low], 0)
|
|
|
|
|
2019-02-13 19:42:47 -08:00
|
|
|
def _is_prng_key(key):
  """Return True if `key` has shape (2,) and dtype uint32, else False.

  Objects without `shape`/`dtype` attributes are reported as not being keys.
  """
  try:
    has_key_shape = key.shape == (2,)
    return has_key_shape and key.dtype == onp.uint32
  except AttributeError:
    return False
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
### utilities
|
|
|
|
|
|
|
|
|
|
|
|
def _make_rotate_left(dtype):
  """Return a bitwise rotate-left function specialized to integer `dtype`.

  The returned function `_rotate_left(x, d)` rotates the bits of `x` left by
  `d` positions, casting `d` to `x`'s dtype first when the two differ.

  Raises:
    TypeError: if `dtype` is not an integer dtype.
  """
  if not onp.issubdtype(dtype, onp.integer):
    raise TypeError("_rotate_left only accepts integer dtypes.")
  width = onp.array(onp.iinfo(dtype).bits, dtype)

  def _rotate_left(x, d):
    if lax.dtype(d) != lax.dtype(x):
      d = lax.convert_element_type(d, x.dtype)
    shifted_up = x << d
    wrapped_around = lax.shift_right_logical(x, width - d)
    return shifted_up | wrapped_around

  return _rotate_left
|
|
|
|
|
|
|
|
|
|
|
|
def _bit_stats(bits):
  """Debugging helper: mean of each bit position over the 64-bit expansions of `bits`."""
  rows = [[int(ch) for ch in onp.binary_repr(value, 64)] for value in bits]
  return onp.array(rows).mean(0)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
### hash function and split
|
|
|
|
|
|
|
|
|
2018-11-26 09:58:05 -08:00
|
|
|
@jit
def threefry_2x32(keypair, count):
  """Hash `count` with the Threefry-2x32 block cipher keyed by `keypair`.

  Args:
    keypair: a pair of 32bit unsigned integers used for the key.
    count: an array of dtype uint32 used for the counts.

  Returns:
    An array of dtype uint32 with the same shape as `count`.

  Raises:
    TypeError: if the key words or `count` are not uint32.
  """
  # Based on ThreeFry2x32 by phawkins@ in //.../xla/client/lib/prng.cc
  key1, key2 = keypair
  if not lax.dtype(key1) == lax.dtype(key2) == lax.dtype(count) == onp.uint32:
    msg = "threefry_2x32 requires uint32 arguments, got {}"
    raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]]))

  rotate_left = _make_rotate_left(lax.dtype(count))

  def apply_round(v, rot):
    # One Threefry mix step (add, rotate, xor) on a two-lane state; returns a
    # new list and leaves `v` untouched.
    a, b = v[0], v[1]
    a = a + b
    b = rotate_left(b, rot)
    b = a ^ b
    return [a, b]

  # The flattened counts are split into two equal lanes; an odd-sized input is
  # padded with one zero first and the extra output is dropped at the end.
  flat = count.ravel()
  odd_size = count.size % 2
  if odd_size:
    flat = np.concatenate([flat, onp.uint32([0])])
  x = list(np.split(flat, 2))

  rotations = [13, 15, 26, 6, 17, 29, 16, 24]
  ks = [key1, key2, key1 ^ key2 ^ onp.uint32(0x1BD11BDA)]

  x[0] = x[0] + ks[0]
  x[1] = x[1] + ks[1]

  # Five groups of four rounds. Group i (1-based) uses the first half of the
  # rotation schedule when i is odd and the second half when i is even, then
  # injects key words ks[i % 3] and ks[(i + 1) % 3] plus the constant i.
  for i in range(1, 6):
    group = rotations[:4] if i % 2 else rotations[4:]
    for r in group:
      x = apply_round(x, r)
    x[0] = x[0] + ks[i % 3]
    x[1] = x[1] + ks[(i + 1) % 3] + onp.uint32(i)

  out = np.concatenate(x)
  assert out.dtype == onp.uint32
  return lax.reshape(out[:-1] if odd_size else out, count.shape)
|
|
|
|
|
|
|
|
|
|
|
|
def split(key, num=2):
  """Derive `num` fresh PRNG keys from `key`, stacked along a new leading axis.

  Args:
    key: a PRNGKey (an array with shape (2,) and dtype uint32).
    num: optional, a positive integer indicating the number of keys to produce
      (default 2).

  Returns:
    An array with shape (num, 2) and dtype uint32 representing `num` new keys.
  """
  return _split(key, num)


@partial(jit, static_argnums=(1,))
def _split(key, num):
  # Hash the counters 0 .. 2*num - 1 under `key`; after reshaping, row i holds
  # output words 2i and 2i+1 and forms the i-th new key. `num` is static so
  # the output shape is known at trace time.
  n_words = num * 2
  counts = lax.tie_in(key, lax.iota(onp.uint32, n_words))
  new_words = threefry_2x32(key, counts)
  return lax.reshape(new_words, (num, 2))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-02-13 09:55:36 -08:00
|
|
|
def fold_in(key, data):
  """Mix an integer `data` into a PRNG key, producing a new PRNG key.

  Args:
    key: a PRNGKey (an array with shape (2,) and dtype uint32).
    data: an integer representing data to be folded in to the key.

  Returns:
    A new PRNGKey that is a deterministic function of the inputs and is
    statistically safe for producing a stream of new pseudo-random values.
  """
  return _fold_in(key, data)


@partial(jit, static_argnums=(1,))
def _fold_in(key, data):
  # Turn `data` into key-shaped counters and hash them under `key`; the output
  # has the same shape and dtype as the counters, i.e. those of a PRNG key.
  data_key = PRNGKey(data)
  return threefry_2x32(key, lax.tie_in(key, data_key))
|
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def _random_bits(key, bit_width, shape):
  """Draw unsigned random integers of `bit_width` bits (32 or 64) with `shape`.

  Raises:
    TypeError: for an invalid key, an unsupported width, or a request larger
      than one Threefry call can serve.
  """
  if not _is_prng_key(key):
    raise TypeError("_random_bits got invalid prng key.")
  if bit_width not in (32, 64):
    raise TypeError("requires 32- or 64-bit field width.")

  # A 64-bit value consumes two 32-bit Threefry output words.
  words_per_value = bit_width // 32
  max_count = words_per_value * onp.prod(shape)
  if max_count >= onp.iinfo(onp.uint32).max:
    # TODO(mattjj): just split the key here
    raise TypeError("requesting more random bits than a single call provides.")

  bits = threefry_2x32(key, lax.tie_in(key, lax.iota(onp.uint32, max_count)))
  if bit_width == 64:
    # Widen both halves to uint64 and join them: first half high, second low.
    high, low = [lax.convert_element_type(half, onp.uint64)
                 for half in np.split(bits, 2)]
    bits = (high << onp.uint64(32)) | low
  return lax.reshape(bits, shape)
|
|
|
|
|
|
|
|
|
|
|
|
### random samplers
|
|
|
|
|
|
|
|
|
|
|
|
def uniform(key, shape, dtype=onp.float32, minval=0., maxval=1.):
  """Draw floats uniformly from the half-open interval [minval, maxval).

  Args:
    key: a PRNGKey used as the random key.
    shape: a tuple of nonnegative integers representing the shape.
    dtype: optional, a float dtype for the returned values (default float32).
    minval: optional, a minimum (inclusive) value for the range (default 0).
    maxval: optional, a maximum (exclusive) value for the range (default 1).

  Returns:
    A random array with the specified shape and dtype.
  """
  return _uniform(key, shape, dtype, minval, maxval)


@partial(jit, static_argnums=(1, 2))
def _uniform(key, shape, dtype, minval, maxval):
  if not onp.issubdtype(dtype, onp.floating):
    raise TypeError("uniform only accepts floating point dtypes.")

  dtype = xla_bridge.canonicalize_dtype(dtype)
  lo = lax.convert_element_type(minval, dtype)
  hi = lax.convert_element_type(maxval, dtype)
  info = onp.finfo(dtype)
  nbits, nmant = info.bits, info.nmant
  if nbits not in (32, 64):
    raise TypeError("uniform only accepts 32- or 64-bit dtypes.")

  bits = _random_bits(key, nbits, shape)
  uint_dtype = onp.uint32 if nbits == 32 else onp.uint64

  # Move the top `nmant` random bits into the mantissa field and OR in the bit
  # pattern of 1.0, giving floats in [1, 2); subtracting 1 maps them to
  # [0, 1). This relies on NumPy and XLA having bit-for-bit equivalent float
  # representations, which might not be true on all platforms.
  mantissa = lax.shift_right_logical(
      bits, onp.array(nbits - nmant, lax.dtype(bits)))
  one_bits = onp.array(1., dtype).view(uint_dtype)
  floats = (lax.bitcast_convert_type(lax.bitwise_or(mantissa, one_bits), dtype)
            - onp.array(1., dtype))
  scaled = lax.reshape(floats * (hi - lo) + lo, shape)
  # Clamp from below so rounding in the affine map cannot go under minval.
  return lax.max(lo, scaled)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
def randint(key, shape, minval, maxval, dtype=onp.int32):
  """Sample uniform random values in [minval, maxval) with given shape/dtype.

  Args:
    key: a PRNGKey used as the random key.
    shape: a tuple of nonnegative integers representing the shape.
    minval: the minimum (inclusive) value for the range; converted to `dtype`.
    maxval: the maximum (exclusive) value for the range; converted to `dtype`.
      If `maxval <= minval`, every sample equals `minval` (assuming
      `minval + 1` does not overflow `dtype`).
    dtype: optional, an int dtype for the returned values (default int32).

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    TypeError: if `dtype` is not an integer dtype, or is not 32- or 64-bit
      after canonicalization.
  """
  return _randint(key, shape, minval, maxval, dtype)


@partial(jit, static_argnums=(1, 4))
def _randint(key, shape, minval, maxval, dtype=onp.int32):
  if not onp.issubdtype(dtype, onp.integer):
    raise TypeError("randint only accepts integer dtypes.")

  dtype = xla_bridge.canonicalize_dtype(dtype)
  minval = lax.convert_element_type(minval, dtype)
  maxval = lax.convert_element_type(maxval, dtype)
  nbits = onp.iinfo(dtype).bits

  if nbits not in (32, 64):
    raise TypeError("randint only accepts 32- or 64-bit dtypes.")

  # if we don't have minval < maxval, just always return minval
  # https://github.com/google/jax/issues/222
  maxval = lax.max(lax.add(minval, onp.array(1, dtype)), maxval)

  # This algorithm is biased whenever (maxval - minval) is not a power of 2.
  # We generate double the number of random bits required by the dtype so as to
  # reduce that bias.
  k1, k2 = split(key)
  rbits = lambda key: _random_bits(key, nbits, shape)
  higher_bits, lower_bits = rbits(k1), rbits(k2)

  unsigned_dtype = onp.uint32 if nbits == 32 else onp.uint64
  span = lax.convert_element_type(maxval - minval, unsigned_dtype)

  # To compute a remainder operation on an integer that might have twice as many
  # bits as we can represent in the native unsigned dtype, we compute a
  # multiplier equal to 2**nbits % span (using that nbits is 32 or 64).
  # NOTE(review): unsigned lax.mul presumably wraps on overflow, so for large
  # spans (above 2**(nbits // 2)) these products can wrap and the bias
  # reduction is weaker than intended. The final lax.rem still keeps the
  # offset in [0, span).
  multiplier = lax.rem(onp.array(2**16, unsigned_dtype), span)
  multiplier = lax.rem(lax.mul(multiplier, multiplier), span)
  if nbits == 64:
    multiplier = lax.rem(lax.mul(multiplier, multiplier), span)

  random_offset = lax.add(lax.mul(lax.rem(higher_bits, span), multiplier),
                          lax.rem(lower_bits, span))
  random_offset = lax.rem(random_offset, span)
  return lax.add(minval, lax.convert_element_type(random_offset, dtype))
|
|
|
|
|
|
|
|
|
|
|
|
def shuffle(key, x, axis=0):
  """Shuffle the elements of an array uniformly at random along an axis.

  Args:
    key: a PRNGKey used as the random key.
    x: the array to be shuffled.
    axis: optional, an int axis along which to shuffle (default 0).

  Returns:
    A shuffled version of x. Empty and single-element arrays are returned
    unchanged.

  Note: the sort keys have the full shape of `x`. For multi-dimensional inputs,
  each one-dimensional slice along `axis` is therefore permuted independently;
  slices are not moved as whole units.
  """
  return _shuffle(key, x, axis)


@partial(jit, static_argnums=(2,))
def _shuffle(key, x, axis):
  # On parallel architectures, Fisher-Yates is more expensive than doing
  # multiple sorts. This algorithm is based on one developed and analyzed by
  # tjablin@. We sort according to randomly-generated 32bit keys, but those keys
  # may have collisions. If we repeat the process, using fresh 32bit keys for
  # each sort, then whenever all pairs of elements have been assigned distinct
  # keys at some iteration (or equivalently when the strings formed by
  # concatenating the successive keys for each element are all distinct) then we
  # are guaranteed to have a perfect sample (assuming that either the sort is
  # stable or that any bias is not value-dependent). Since checking uniqueness
  # at runtime may be expensive, we use a heuristic static stop criterion
  # developed by tjablin@. See tensorflow/compiler/tf2xla/random_ops.cc for more
  # info, and for the original implementation of this algorithm. See also
  # Section 2 of http://people.csail.mit.edu/costis/6896sp11/lec5s.pdf for
  # another analysis (where the keys are generated one bit at a time).
  exponent = 3  # see tjablin@'s analysis for explanation of this parameter
  uint32max = onp.iinfo(onp.uint32).max
  # Clamp the size to at least 1. For an empty array onp.log(0) is -inf, and
  # int(onp.ceil(-inf)) raises OverflowError. Sizes 0 and 1 need zero rounds.
  num_rounds = int(onp.ceil(
      exponent * onp.log(max(1, x.size)) / onp.log(uint32max)))

  for _ in range(num_rounds):
    key, subkey = split(key)
    sort_keys = _random_bits(subkey, 32, x.shape)
    _, x = lax.sort_key_val(sort_keys, x, axis)

  return x
|
|
|
|
|
|
|
|
|
|
|
|
def normal(key, shape, dtype=onp.float32):
  """Draw standard normal (mean 0, variance 1) samples of a float dtype.

  Args:
    key: a PRNGKey used as the random key.
    shape: a tuple of nonnegative integers representing the shape.
    dtype: optional, a float dtype for the returned values (default float32).

  Returns:
    A random array with the specified shape and dtype.
  """
  return _normal(key, shape, dtype)


@partial(jit, static_argnums=(1, 2))
def _normal(key, shape, dtype):
  # Inverse-CDF sampling: for u uniform on (-1, 1), sqrt(2) * erfinv(u) is
  # standard normal. The lower bound is nudged just above -1 so erf_inv is
  # never evaluated at -1 itself.
  lower = onp.nextafter(onp.array(-1., dtype), 0., dtype=dtype)
  upper = onp.array(1., dtype)
  u = uniform(key, shape, dtype, lower, upper)
  sqrt2 = onp.array(onp.sqrt(2), dtype)
  return sqrt2 * lax.erf_inv(u)
|
|
|
|
|
|
|
|
|
2019-04-21 21:22:50 -04:00
|
|
|
def bernoulli(key, p=onp.float32(0.5), shape=()):
  """Draw boolean Bernoulli samples that are True with probability `p`.

  Args:
    key: a PRNGKey used as the random key.
    p: optional, an array-like broadcastable to `shape` for the mean of the
      random variables (default 0.5).
    shape: optional, a tuple of nonnegative integers representing the shape
      (when empty, the shape of `p` is used).

  Returns:
    A random array with the specified shape and boolean dtype.
  """
  return _bernoulli(key, p, shape)


@partial(jit, static_argnums=(2,))
def _bernoulli(key, p, shape):
  if not shape:
    shape = onp.shape(p)
  # This asks whether float32 is a sub-dtype of p's dtype, so in practice p is
  # cast to float32 unless its dtype is already float32.
  if not onp.issubdtype(onp.float32, lax.dtype(p)):
    p = lax.convert_element_type(p, onp.float32)
  if onp.shape(p) != shape:
    p = np.broadcast_to(p, shape)
  # For u uniform on [0, 1), P(u < p) = p.
  draws = uniform(key, shape, lax.dtype(p))
  return lax.lt(draws, p)
|
2019-03-28 17:59:42 -04:00
|
|
|
|
|
|
|
|
2019-04-21 16:25:20 -04:00
|
|
|
def beta(key, a, b, shape=(), dtype=onp.float32):
  """Sample Beta random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    a: an array-like broadcastable to `shape` and used as the shape parameter
      alpha of the random variables.
    b: an array-like broadcastable to `shape` and used as the shape parameter
      beta of the random variables.
    shape: optional, a tuple of nonnegative integers representing the shape
      (when empty, the broadcast shape of `a` and `b` is used).
    dtype: optional, a float dtype for the returned values (default float32).

  Returns:
    A random array with the specified shape and dtype.
  """
  return _beta(key, a, b, shape, dtype)


@partial(jit, static_argnums=(3, 4))
def _beta(key, a, b, shape, dtype):
  a = lax.convert_element_type(a, dtype)
  b = lax.convert_element_type(b, dtype)
  shape = shape or lax.broadcast_shapes(np.shape(a), np.shape(b))
  key_a, key_b = split(key)
  # If X ~ Gamma(a) and Y ~ Gamma(b) are independent, X / (X + Y) ~ Beta(a, b).
  gamma_a = gamma(key_a, a, shape, dtype)
  gamma_b = gamma(key_b, b, shape, dtype)
  return gamma_a / (gamma_a + gamma_b)
|
2019-03-28 17:59:42 -04:00
|
|
|
|
|
|
|
|
|
|
|
def cauchy(key, shape=(), dtype=onp.float32):
  """Draw standard Cauchy samples with the given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    shape: optional, a tuple of nonnegative integers representing the shape
      (default scalar).
    dtype: optional, a float dtype for the returned values (default float32).

  Returns:
    A random array with the specified shape and dtype.
  """
  return _cauchy(key, shape, dtype)


@partial(jit, static_argnums=(1, 2))
def _cauchy(key, shape, dtype):
  # Inverse-CDF sampling: tan(pi * (u - 1/2)) is standard Cauchy for u
  # uniform on [0, 1).
  u = uniform(key, shape, dtype)
  centered = lax.sub(u, _constant_like(u, 0.5))
  return lax.tan(lax.mul(_constant_like(u, onp.pi), centered))
|
|
|
|
|
|
|
|
|
2019-04-22 11:55:02 -04:00
|
|
|
def dirichlet(key, alpha, shape=(), dtype=onp.float32):
  """Sample Dirichlet random values with given batch shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    alpha: an array-like with `alpha.shape[:-1]` broadcastable to `shape` and
      used as the concentration parameter of the random variables; the last
      axis indexes the categories.
    shape: optional, a tuple of nonnegative integers representing the batch
      shape (defaults to `alpha.shape[:-1]`).
    dtype: optional, a float dtype for the returned values (default float32).

  Returns:
    A random array of the given dtype with shape `shape + alpha.shape[-1:]`,
    whose entries along the last axis sum to 1 (up to rounding).
  """
  return _dirichlet(key, alpha, shape, dtype)


@partial(jit, static_argnums=(2, 3))
def _dirichlet(key, alpha, shape, dtype):
  alpha = asarray(alpha, dtype)
  # Default batch shape: alpha's shape without its last (category) axis.
  shape = shape or alpha.shape[:-1]
  # If G_i ~ Gamma(alpha_i) independently, G / sum(G) ~ Dirichlet(alpha).
  gamma_samples = gamma(key, alpha, shape + alpha.shape[-1:], dtype)
  return gamma_samples / np.sum(gamma_samples, axis=-1, keepdims=True)
|
|
|
|
|
|
|
|
|
2019-03-28 23:57:00 -04:00
|
|
|
def exponential(key, shape=(), dtype=onp.float32):
  """Draw unit-rate Exponential samples with the given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    shape: optional, a tuple of nonnegative integers representing the shape
      (default scalar).
    dtype: optional, a float dtype for the returned values (default float32).

  Returns:
    A random array with the specified shape and dtype.
  """
  return _exponential(key, shape, dtype)


@partial(jit, static_argnums=(1, 2))
def _exponential(key, shape, dtype):
  # Inverse-CDF sampling: -log(1 - u) for u uniform on [0, 1). Using 1 - u
  # keeps the argument of log in (0, 1], so log is never evaluated at zero.
  u = uniform(key, shape, dtype)
  one_minus_u = lax.sub(_constant_like(u, 1), u)
  return lax.neg(lax.log(one_minus_u))
|
|
|
|
|
|
|
|
|
2019-03-30 18:07:34 -04:00
|
|
|
def _gamma_one(key, alpha):
  """Draw a single unit-scale Gamma(alpha) sample for a scalar `alpha`.

  Uses the rejection sampler of Marsaglia and Tsang, "A simple method for
  generating gamma variables". The algorithm is also described at
  https://en.wikipedia.org/wiki/Gamma_distribution#Generating_gamma-distributed_random_variables
  """
  dtype = lax.dtype(alpha)
  zero = _constant_like(alpha, 0)
  one = _constant_like(alpha, 1)
  one_over_two = _constant_like(alpha, 0.5)
  one_over_three = _constant_like(alpha, 1. / 3.)
  squeeze_const = _constant_like(alpha, 0.0331)

  key, subkey = split(key)
  # For alpha < 1, sample Gamma(alpha + 1) and rescale, using
  # Gamma(alpha) ~ Gamma(alpha + 1) * Uniform()^(1 / alpha).
  alpha_ge_one = lax.ge(alpha, one)
  boost = lax.select(
      alpha_ge_one,
      one,
      lax.pow(uniform(subkey, (), dtype=dtype), lax.div(one, alpha)))
  alpha = lax.select(alpha_ge_one, alpha, lax.add(alpha, one))

  d = lax.sub(alpha, one_over_three)
  c = lax.div(one_over_three, lax.pow(d, one_over_two))

  def _cond_fn(state):
    # Keep looping while the proposal is rejected: V <= 0, or both the cheap
    # squeeze test and the exact log test fail to accept it.
    _, X, V, U = state
    # TODO: use lax.cond when its batching rule is supported
    # The reason is to avoid evaluating second condition which involves log+log
    # if the first condition is satisfied
    squeeze_rejects = lax.ge(
        U, lax.sub(one, lax.mul(squeeze_const, lax.mul(X, X))))
    log_rejects = lax.ge(
        lax.log(U),
        lax.add(lax.mul(X, one_over_two),
                lax.mul(d, lax.add(lax.sub(one, V), lax.log(V)))))
    return lax.bitwise_or(lax.le(V, zero),
                          lax.bitwise_and(squeeze_rejects, log_rejects))

  def _body_fn(state):
    # Propose a new candidate: X = x**2 and V = (1 + c*x)**3 for standard
    # normal x, plus a fresh uniform U for the acceptance tests.
    next_key, x_key, u_key = split(state[0], 3)
    x = normal(x_key, (), dtype=dtype)
    v = lax.add(one, lax.mul(x, c))
    X = lax.mul(x, x)
    V = lax.mul(lax.mul(v, v), v)
    U = uniform(u_key, (), dtype=dtype)
    return next_key, X, V, U

  # The initial V = -1 makes _cond_fn True, so the body runs at least once.
  init_state = (key, zero, _constant_like(alpha, -1), zero)
  _, _, V, _ = lax.while_loop(_cond_fn, _body_fn, init_state)
  z = lax.mul(lax.mul(d, V), boost)
  # Replace an exact zero (for instance when boost is 0) with the smallest
  # positive normal float of the dtype.
  return lax.select(lax.eq(z, zero), onp.finfo(z.dtype).tiny, z)
|
2019-03-30 18:07:34 -04:00
|
|
|
|
|
|
|
|
|
|
|
def gamma(key, a, shape=(), dtype=onp.float32):
  """Draw unit-scale Gamma samples with shape parameter `a`.

  Args:
    key: a PRNGKey used as the random key.
    a: an array-like broadcastable to `shape` and used as the shape parameter
      of the random variables.
    shape: optional, a tuple of nonnegative integers representing the shape
      (when empty, the shape of `a` is used).
    dtype: optional, a float dtype for the returned values (default float32).

  Returns:
    A random array with the specified shape and dtype.
  """
  return _gamma(key, a, shape, dtype)


@partial(jit, static_argnums=(2, 3))
def _gamma(key, a, shape=(), dtype=onp.float32):
  a = lax.convert_element_type(a, dtype)
  if not shape:
    shape = onp.shape(a)
  if onp.shape(a) != shape:
    a = np.broadcast_to(a, shape)
  # Each element gets its own subkey; the scalar sampler is vmapped over the
  # flattened parameters and the result is reshaped back.
  flat_alphas = np.reshape(a, -1)
  subkeys = split(key, onp.size(flat_alphas))
  flat_samples = vmap(_gamma_one)(subkeys, flat_alphas)
  return np.reshape(flat_samples, shape)
|
|
|
|
|
|
|
|
|
2019-04-21 16:25:20 -04:00
|
|
|
def gumbel(key, shape=(), dtype=onp.float32):
  """Sample Gumbel random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    shape: optional, a tuple of nonnegative integers representing the shape
      (default scalar).
    dtype: optional, a float dtype for the returned values (default float32).

  Returns:
    A random array with the specified shape and dtype. All values are finite.

  Raises:
    TypeError: if `dtype` is not a floating point dtype.
  """
  return _gumbel(key, shape, dtype)


@partial(jit, static_argnums=(1, 2))
def _gumbel(key, shape, dtype):
  if not onp.issubdtype(dtype, onp.floating):
    raise TypeError("gumbel only accepts floating point dtypes.")
  # Draw u from [tiny, 1) instead of [0, 1): u == 0 would give
  # -log(-log(0)) = -inf. Canonicalize first so `tiny` comes from the dtype
  # uniform() actually uses (e.g. float64's tiny rounds to 0 in float32).
  dtype = xla_bridge.canonicalize_dtype(dtype)
  u = uniform(key, shape, dtype, minval=onp.finfo(dtype).tiny, maxval=1.)
  return -np.log(-np.log(u))
|
|
|
|
|
|
|
|
|
2019-03-28 23:57:00 -04:00
|
|
|
def laplace(key, shape=(), dtype=onp.float32):
  """Sample Laplace random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    shape: optional, a tuple of nonnegative integers representing the shape
      (default scalar).
    dtype: optional, a float dtype for the returned values (default float32).

  Returns:
    A random array with the specified shape and dtype. All values are finite.

  Raises:
    TypeError: if `dtype` is not a floating point dtype.
  """
  return _laplace(key, shape, dtype)


@partial(jit, static_argnums=(1, 2))
def _laplace(key, shape, dtype):
  if not onp.issubdtype(dtype, onp.floating):
    raise TypeError("laplace only accepts floating point dtypes.")
  # Draw u from (-1, 1) instead of [-1, 1): u == -1 would give
  # sign(-1) * log1p(-1) = +inf. Canonicalize first so `epsneg` comes from
  # the dtype uniform() actually uses (-1 + float64 epsneg rounds to -1 in
  # float32).
  dtype = xla_bridge.canonicalize_dtype(dtype)
  u = uniform(key, shape, dtype,
              minval=-1. + onp.finfo(dtype).epsneg, maxval=1.)
  # Inverse-CDF sampling: sign(u) * log(1 - |u|) is a standard Laplace draw
  # (the distribution is symmetric, so the sign convention does not matter).
  return lax.mul(lax.sign(u), lax.log1p(lax.neg(lax.abs(u))))
|
2019-03-30 16:34:20 -04:00
|
|
|
|
|
|
|
|
|
|
|
def pareto(key, b, shape=(), dtype=onp.float32):
  """Draw unit-scale Pareto samples with shape parameter `b`.

  Args:
    key: a PRNGKey used as the random key.
    b: an array-like broadcastable to `shape` and used as the shape parameter
      of the random variables.
    shape: optional, a tuple of nonnegative integers representing the shape
      (when empty, the shape of `b` is used).
    dtype: optional, a float dtype for the returned values (default float32).

  Returns:
    A random array with the specified shape and dtype.
  """
  return _pareto(key, b, shape, dtype)


@partial(jit, static_argnums=(2, 3))
def _pareto(key, b, shape, dtype):
  b = lax.convert_element_type(b, dtype)
  if not shape:
    shape = onp.shape(b)
  if onp.shape(b) != shape:
    b = np.broadcast_to(b, shape)
  # If E ~ Exponential(1), then exp(E / b) is Pareto with shape b, scale 1.
  return lax.exp(lax.div(exponential(key, shape, dtype), b))
|
2019-04-02 10:55:03 +01:00
|
|
|
|
|
|
|
|
2019-04-21 16:25:20 -04:00
|
|
|
def t(key, df, shape=(), dtype=onp.float32):
  """Sample Student's t random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    df: an array-like broadcastable to `shape` and used as the shape parameter
      (degrees of freedom) of the random variables.
    shape: optional, a tuple of nonnegative integers representing the shape
      (when empty, the shape of `df` is used).
    dtype: optional, a float dtype for the returned values (default float32).

  Returns:
    A random array with the specified shape and dtype.
  """
  return _t(key, df, shape, dtype)


@partial(jit, static_argnums=(2, 3))
def _t(key, df, shape, dtype):
  df = lax.convert_element_type(df, dtype)
  shape = shape or onp.shape(df)
  key_n, key_g = split(key)
  n = normal(key_n, shape, dtype)
  two = _constant_like(n, 2)
  half_df = lax.div(df, two)
  # The gamma draw must use its own subkey. Reusing key_n makes gamma's
  # internal split(key_n) reproduce the random bits normal() consumed, so the
  # numerator and denominator would not be independent.
  g = gamma(key_g, half_df, shape, dtype)
  # With N ~ Normal(0, 1) and G ~ Gamma(df / 2), 2G ~ Chi2(df), hence
  # N * sqrt((df / 2) / G) ~ Student's t with df degrees of freedom.
  return n * np.sqrt(half_df / g)
|