# rocm_jax/jax/random.py


# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX pseudo-random number generators (PRNGs).
The JAX PRNG system is based on "Parallel random numbers: as easy as 1, 2, 3"
(Salmon et al. 2011). For details on the design and its motivation, see:
https://github.com/google/jax/blob/master/design_notes/prng.md
"""
from functools import partial
from typing import Optional, Sequence, Union
import numpy as onp
from . import lax
from . import numpy as np
from . import dtypes
from .api import jit, vmap
from .numpy.lax_numpy import _constant_like, asarray
from jax.lib import xla_bridge
from jax.lib import cuda_prng
from jax import core
from jax import abstract_arrays
from jax.numpy.linalg import cholesky
from jax.scipy.special import logit
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import xla
from jax.util import prod
def PRNGKey(seed: int) -> np.ndarray:
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
Args:
seed: a 64- or 32-bit integer used as the value of the key.
Returns:
A PRNG key, which is modeled as an array of shape (2,) and dtype uint32. The
key is constructed from a 64-bit seed by effectively bit-casting to a pair
of uint32 values (or from a 32-bit seed by first padding out with zeros).
"""
if onp.shape(seed):
raise TypeError("PRNGKey seed must be a scalar.")
convert = lambda k: lax.reshape(lax.convert_element_type(k, onp.uint32), [1])
if isinstance(seed, (int, onp.ndarray)):
    # Special handling of raw integer values, which may be 64-bit even when
    # jax_enable_x64=False; we don't want to drop the top 32 bits
k1 = convert(onp.bitwise_and(onp.right_shift(seed, 32), 0xFFFFFFFF))
else:
k1 = convert(lax.shift_right_logical(seed, lax._const(seed, 32)))
k2 = convert(np.bitwise_and(seed, 0xFFFFFFFF))
return lax.concatenate([k1, k2], 0)
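
# Example usage (illustrative sketch, not part of the original source):
#
#   key = PRNGKey(42)                    # uint32 array of shape (2,)
#   assert (PRNGKey(42) == key).all()    # same seed -> same key, same stream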
def _is_prng_key(key: np.ndarray) -> bool:
try:
return key.shape == (2,) and key.dtype == onp.uint32
except AttributeError:
return False
### utilities
def _make_rotate_left(dtype):
if not np.issubdtype(dtype, onp.integer):
raise TypeError("_rotate_left only accepts integer dtypes.")
nbits = onp.array(np.iinfo(dtype).bits, dtype)
def _rotate_left(x, d):
if lax.dtype(d) != lax.dtype(x):
d = lax.convert_element_type(d, x.dtype)
return lax.shift_left(x, d) | lax.shift_right_logical(x, nbits - d)
return _rotate_left
def _bit_stats(bits):
"""This is a debugging function to compute the statistics of bit fields."""
return onp.array([list(map(int, onp.binary_repr(x, 64))) for x in bits]).mean(0)
### hash function and split
def _threefry2x32_abstract_eval(*args):
if any(a.dtype != np.uint32 for a in args):
raise TypeError("Arguments to threefry2x32 must have uint32 type, got {}"
.format(args))
if all(isinstance(arg, abstract_arrays.ShapedArray) for arg in args):
shape = lax._broadcasting_shape_rule(*args)
aval = abstract_arrays.ShapedArray(shape, np.dtype(np.uint32))
else:
aval = abstract_arrays.UnshapedArray(np.dtype(np.uint32))
return (aval,) * 2
def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
"""Apply the Threefry 2x32 hash.
Args:
keypair: a pair of 32bit unsigned integers used for the key.
count: an array of dtype uint32 used for the counts.
Returns:
An array of dtype uint32 with the same shape as `count`.
"""
x = [x1, x2]
rotate_left = _make_rotate_left(onp.uint32)
def apply_round(v, rot):
v = v[:]
v[0] = v[0] + v[1]
v[1] = rotate_left(v[1], rot)
v[1] = v[0] ^ v[1]
return v
rotations = [onp.array([13, 15, 26, 6], dtype=onp.uint32),
onp.array([17, 29, 16, 24], dtype=onp.uint32)]
ks = [key1, key2, key1 ^ key2 ^ onp.uint32(0x1BD11BDA)]
x[0] = x[0] + ks[0]
x[1] = x[1] + ks[1]
if use_rolled_loops:
def rotate_list(xs): return xs[1:] + xs[:1]
def step(i, state):
x, ks, rotations = state
for r in rotations[0]:
x = apply_round(x, r)
new_x = [x[0] + ks[0], x[1] + ks[1] + asarray(i + 1, dtype=onp.uint32)]
return new_x, rotate_list(ks), rotate_list(rotations)
x, _, _ = lax.fori_loop(0, 5, step, (x, rotate_list(ks), rotations))
else:
for r in rotations[0]:
x = apply_round(x, r)
x[0] = x[0] + ks[1]
x[1] = x[1] + ks[2] + onp.uint32(1)
for r in rotations[1]:
x = apply_round(x, r)
x[0] = x[0] + ks[2]
x[1] = x[1] + ks[0] + onp.uint32(2)
for r in rotations[0]:
x = apply_round(x, r)
x[0] = x[0] + ks[0]
x[1] = x[1] + ks[1] + onp.uint32(3)
for r in rotations[1]:
x = apply_round(x, r)
x[0] = x[0] + ks[1]
x[1] = x[1] + ks[2] + onp.uint32(4)
for r in rotations[0]:
x = apply_round(x, r)
x[0] = x[0] + ks[2]
x[1] = x[1] + ks[0] + onp.uint32(5)
return tuple(x)
def _threefry2x32_gpu_translation_rule(c, k1, k2, x1, x2):
shape = lax.broadcast_shapes(
c.GetShape(k1).dimensions(), c.GetShape(k2).dimensions(),
c.GetShape(x1).dimensions(), c.GetShape(x2).dimensions())
rank = len(shape)
def _broadcast(x):
ndims = c.GetShape(x).rank()
return c.BroadcastInDim(x, shape, tuple(range(rank - ndims, rank)))
return cuda_prng.threefry2x32(
c, (_broadcast(k1), _broadcast(k2)), (_broadcast(x1), _broadcast(x2)))
threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
xla.translations[threefry2x32_p] = xla.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=False))
xla.backend_specific_translations['cpu'][threefry2x32_p] = xla.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=True))
if cuda_prng:
xla.backend_specific_translations['gpu'][threefry2x32_p] = \
_threefry2x32_gpu_translation_rule
@jit
def threefry_2x32(keypair, count):
"""Apply the Threefry 2x32 hash.
Args:
    keypair: a pair of 32-bit unsigned integers used for the key.
count: an array of dtype uint32 used for the counts.
Returns:
An array of dtype uint32 with the same shape as `count`.
"""
key1, key2 = keypair
if not lax.dtype(key1) == lax.dtype(key2) == lax.dtype(count) == onp.uint32:
msg = "threefry_2x32 requires uint32 arguments, got {}"
raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]]))
odd_size = count.size % 2
if odd_size:
x = list(np.split(np.concatenate([count.ravel(), onp.uint32([0])]), 2))
else:
x = list(np.split(count.ravel(), 2))
x = threefry2x32_p.bind(key1, key2, x[0], x[1])
out = np.concatenate(x)
assert out.dtype == onp.uint32
return lax.reshape(out[:-1] if odd_size else out, count.shape)
def split(key: np.ndarray, num: int = 2) -> np.ndarray:
"""Splits a PRNG key into `num` new keys by adding a leading axis.
Args:
key: a PRNGKey (an array with shape (2,) and dtype uint32).
num: optional, a positive integer indicating the number of keys to produce
(default 2).
Returns:
An array with shape (num, 2) and dtype uint32 representing `num` new keys.
"""
return _split(key, num)
@partial(jit, static_argnums=(1,))
def _split(key, num):
counts = lax.tie_in(key, lax.iota(onp.uint32, num * 2))
return lax.reshape(threefry_2x32(key, counts), (num, 2))
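
# Example usage (illustrative sketch): never reuse a key; split it to derive
# independent keys:
#
#   key = PRNGKey(0)
#   key, subkey = split(key)         # default num=2; consume subkey, keep key
#   subkeys = split(subkey, num=4)   # shape (4, 2), one key per row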
def fold_in(key, data):
"""Folds in data to a PRNG key to form a new PRNG key.
Args:
key: a PRNGKey (an array with shape (2,) and dtype uint32).
    data: a 32-bit integer representing data to be folded into the key.
Returns:
A new PRNGKey that is a deterministic function of the inputs and is
statistically safe for producing a stream of new pseudo-random values.
"""
return _fold_in(key, data)
@jit
def _fold_in(key, data):
key2 = lax.tie_in(key, PRNGKey(data))
return threefry_2x32(key, key2)
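
# Example usage (illustrative sketch): folding a loop index into a key gives
# a distinct, deterministic key per iteration without carrying `split` state:
#
#   key = PRNGKey(0)
#   for i in range(3):
#     step_key = fold_in(key, i)   # independent stream for step i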
def _random_bits(key, bit_width, shape):
"""Sample uniform random bits of given width and shape using PRNG key."""
if not _is_prng_key(key):
raise TypeError("_random_bits got invalid prng key.")
if bit_width not in (32, 64):
raise TypeError("requires 32- or 64-bit field width.")
max_count = (bit_width // 32) * onp.prod(shape)
if max_count >= np.iinfo(onp.uint32).max:
# TODO(mattjj): just split the key here
raise TypeError("requesting more random bits than a single call provides.")
counts = lax.tie_in(key, lax.iota(onp.uint32, max_count))
bits = threefry_2x32(key, counts)
if bit_width == 64:
bits = [lax.convert_element_type(x, onp.uint64) for x in np.split(bits, 2)]
bits = lax.shift_left(bits[0], onp.uint64(32)) | bits[1]
return lax.reshape(bits, shape)
### random samplers
def _check_shape(name, shape, *param_shapes):
try:
shape = tuple(map(int, shape))
except TypeError as err:
msg = "{} requires a concrete tuple of integers as shape argument, got {}."
raise ValueError(msg.format(name, shape)) from err
if param_shapes:
shape_ = lax.broadcast_shapes(shape, *param_shapes)
if shape != shape_:
msg = ("{} parameter shapes must be broadcast-compatible with shape "
"argument, and the result of broadcasting the shapes must equal "
"the shape argument, but got result {} for shape argument {}.")
raise ValueError(msg.format(name, shape_, shape))
def uniform(key: np.ndarray,
shape: Sequence[int] = (),
dtype: onp.dtype = onp.float64,
minval: Union[float, np.ndarray] = 0.,
maxval: Union[float, np.ndarray] = 1.) -> np.ndarray:
"""Sample uniform random values in [minval, maxval) with given shape/dtype.
Args:
key: a PRNGKey used as the random key.
shape: optional, a tuple of nonnegative integers representing the result
shape. Default ().
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
minval: optional, a minimum (inclusive) value for the range (default 0).
maxval: optional, a maximum (exclusive) value for the range (default 1).
Returns:
A random array with the specified shape and dtype.
"""
dtype = dtypes.canonicalize_dtype(dtype)
shape = abstract_arrays.canonicalize_shape(shape)
return _uniform(key, shape, dtype, minval, maxval)
@partial(jit, static_argnums=(1, 2))
def _uniform(key, shape, dtype, minval, maxval):
_check_shape("uniform", shape)
if not np.issubdtype(dtype, onp.floating):
raise TypeError("uniform only accepts floating point dtypes.")
minval = lax.convert_element_type(minval, dtype)
maxval = lax.convert_element_type(maxval, dtype)
finfo = np.finfo(dtype)
nbits, nmant = finfo.bits, finfo.nmant
if nbits not in (32, 64):
raise TypeError("uniform only accepts 32- or 64-bit dtypes.")
bits = _random_bits(key, nbits, shape)
# The strategy here is to randomize only the mantissa bits with an exponent of
# 1 (after applying the bias), then shift and scale to the desired range. The
# bit-level transformation we use relies on Numpy and XLA having bit-for-bit
# equivalent float representations, which might not be true on all platforms.
float_bits = lax.bitwise_or(
lax.shift_right_logical(bits, onp.array(nbits - nmant, lax.dtype(bits))),
onp.array(1., dtype).view(onp.uint32 if nbits == 32 else onp.uint64))
floats = lax.bitcast_convert_type(float_bits, dtype) - onp.array(1., dtype)
return lax.max(
minval,
lax.reshape(floats * (maxval - minval) + minval, shape))
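
# Example usage (illustrative sketch):
#
#   key = PRNGKey(0)
#   u = uniform(key, (2, 3), minval=-1., maxval=1.)   # float32 unless x64 on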
def randint(key: np.ndarray,
shape: Sequence[int],
minval: Union[int, np.ndarray],
maxval: Union[int, np.ndarray],
dtype: onp.dtype = onp.int64):
"""Sample uniform random values in [minval, maxval) with given shape/dtype.
Args:
key: a PRNGKey used as the random key.
shape: a tuple of nonnegative integers representing the shape.
minval: int or array of ints broadcast-compatible with ``shape``, a minimum
(inclusive) value for the range.
maxval: int or array of ints broadcast-compatible with ``shape``, a maximum
(exclusive) value for the range.
dtype: optional, an int dtype for the returned values (default int64 if
jax_enable_x64 is true, otherwise int32).
Returns:
A random array with the specified shape and dtype.
"""
dtype = dtypes.canonicalize_dtype(dtype)
shape = abstract_arrays.canonicalize_shape(shape)
return _randint(key, shape, minval, maxval, dtype)
@partial(jit, static_argnums=(1, 4))
def _randint(key, shape, minval, maxval, dtype):
_check_shape("randint", shape, onp.shape(minval), onp.shape(maxval))
if not np.issubdtype(dtype, onp.integer):
raise TypeError("randint only accepts integer dtypes.")
minval = lax.convert_element_type(minval, dtype)
maxval = lax.convert_element_type(maxval, dtype)
nbits = np.iinfo(dtype).bits
if nbits not in (32, 64):
raise TypeError("randint only accepts 32- or 64-bit dtypes.")
# if we don't have minval < maxval, just always return minval
# https://github.com/google/jax/issues/222
maxval = lax.max(lax.add(minval, onp.array(1, dtype)), maxval)
# This algorithm is biased whenever (maxval - minval) is not a power of 2.
# We generate double the number of random bits required by the dtype so as to
# reduce that bias.
k1, k2 = split(key)
rbits = lambda key: _random_bits(key, nbits, shape)
higher_bits, lower_bits = rbits(k1), rbits(k2)
unsigned_dtype = onp.uint32 if nbits == 32 else onp.uint64
span = lax.convert_element_type(maxval - minval, unsigned_dtype)
# To compute a remainder operation on an integer that might have twice as many
# bits as we can represent in the native unsigned dtype, we compute a
# multiplier equal to 2**nbits % span (using that nbits is 32 or 64).
multiplier = lax.rem(onp.array(2**16, unsigned_dtype), span)
multiplier = lax.rem(lax.mul(multiplier, multiplier), span)
if nbits == 64:
multiplier = lax.rem(lax.mul(multiplier, multiplier), span)
random_offset = lax.add(lax.mul(lax.rem(higher_bits, span), multiplier),
lax.rem(lower_bits, span))
random_offset = lax.rem(random_offset, span)
return lax.add(minval, lax.convert_element_type(random_offset, dtype))
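
# Example usage (illustrative sketch): minval/maxval may be scalars or arrays
# broadcast-compatible with `shape`:
#
#   key = PRNGKey(0)
#   dice = randint(key, (10,), 1, 7)   # values in {1, ..., 6}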
def shuffle(key: np.ndarray, x: np.ndarray, axis: int = 0) -> np.ndarray:
"""Shuffle the elements of an array uniformly at random along an axis.
Args:
key: a PRNGKey used as the random key.
x: the array to be shuffled.
axis: optional, an int axis along which to shuffle (default 0).
Returns:
A shuffled version of x.
"""
return _shuffle(key, x, axis)
@partial(jit, static_argnums=(2,))
def _shuffle(key, x, axis):
# On parallel architectures, Fisher-Yates is more expensive than doing
# multiple sorts. This algorithm is based on one developed and analyzed by
  # tjablin@. We sort according to randomly-generated 32-bit keys, but those keys
  # may have collisions. If we repeat the process, using fresh 32-bit keys for
# each sort, then whenever all pairs of elements have been assigned distinct
# keys at some iteration (or equivalently when the strings formed by
# concatenating the successive keys for each element are all distinct) then we
# are guaranteed to have a perfect sample (assuming that either the sort is
# stable or that any bias is not value-dependent). Since checking uniqueness
# at runtime may be expensive, we use a heuristic static stop criterion
# developed by tjablin@. See tensorflow/compiler/tf2xla/random_ops.cc for more
# info, and for the original implementation of this algorithm. See also
# Section 2 of http://people.csail.mit.edu/costis/6896sp11/lec5s.pdf for
# another analysis (where the keys are generated one bit at a time).
exponent = 3 # see tjablin@'s analysis for explanation of this parameter
uint32max = np.iinfo(onp.uint32).max
num_rounds = int(onp.ceil(exponent * onp.log(x.size) / onp.log(uint32max)))
for _ in range(num_rounds):
key, subkey = split(key)
sort_keys = _random_bits(subkey, 32, x.shape)
_, x = lax.sort_key_val(sort_keys, x, axis)
return x
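
# Example usage (illustrative sketch):
#
#   key = PRNGKey(0)
#   perm = shuffle(key, np.arange(10))   # a random permutation of 0..9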
def normal(key: np.ndarray,
shape: Sequence[int] = (),
dtype: onp.dtype = onp.float64) -> np.ndarray:
"""Sample standard normal random values with given shape and float dtype.
Args:
key: a PRNGKey used as the random key.
shape: optional, a tuple of nonnegative integers representing the result
shape. Default ().
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified shape and dtype.
"""
dtype = dtypes.canonicalize_dtype(dtype)
shape = abstract_arrays.canonicalize_shape(shape)
return _normal(key, shape, dtype)
@partial(jit, static_argnums=(1, 2))
def _normal(key, shape, dtype):
_check_shape("normal", shape)
lo = onp.nextafter(onp.array(-1., dtype), 0., dtype=dtype)
hi = onp.array(1., dtype)
u = uniform(key, shape, dtype, lo, hi)
return onp.array(onp.sqrt(2), dtype) * lax.erf_inv(u)
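
# Example usage (illustrative sketch). `_normal` maps u ~ Uniform(-1, 1)
# through sqrt(2) * erf_inv(u), i.e. the standard normal inverse CDF:
#
#   key = PRNGKey(0)
#   x = normal(key, (5,))   # standard normal draws
#   y = 3. + 2. * x         # affine transform for N(3, 2**2)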
def multivariate_normal(key: np.ndarray,
mean: np.ndarray,
cov: np.ndarray,
shape: Optional[Sequence[int]] = None,
dtype: onp.dtype = onp.float64) -> np.ndarray:
"""Sample multivariate normal random values with given mean and covariance.
Args:
key: a PRNGKey used as the random key.
mean: a mean vector of shape ``(..., n)``.
cov: a positive definite covariance matrix of shape ``(..., n, n)``. The
batch shape ``...`` must be broadcast-compatible with that of ``mean``.
shape: optional, a tuple of nonnegative integers specifying the result
batch shape; that is, the prefix of the result shape excluding the last
axis. Must be broadcast-compatible with ``mean.shape[:-1]`` and
``cov.shape[:-2]``. The default (None) produces a result batch shape by
broadcasting together the batch shapes of ``mean`` and ``cov``.
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified dtype and shape given by
``shape + mean.shape[-1:]`` if ``shape`` is not None, or else
``broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]``.
"""
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = abstract_arrays.canonicalize_shape(shape)
return _multivariate_normal(key, mean, cov, shape, dtype)
@partial(jit, static_argnums=(3, 4))
def _multivariate_normal(key, mean, cov, shape, dtype):
if not onp.ndim(mean) >= 1:
msg = "multivariate_normal requires mean.ndim >= 1, got mean.ndim == {}"
raise ValueError(msg.format(onp.ndim(mean)))
if not onp.ndim(cov) >= 2:
msg = "multivariate_normal requires cov.ndim >= 2, got cov.ndim == {}"
raise ValueError(msg.format(onp.ndim(cov)))
n = mean.shape[-1]
if onp.shape(cov)[-2:] != (n, n):
msg = ("multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
"but got cov.shape == {shape}.")
raise ValueError(msg.format(n=n, shape=onp.shape(cov)))
if shape is None:
shape = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
else:
_check_shape("normal", shape, mean.shape[:-1], mean.shape[:-2])
chol_factor = cholesky(cov)
normal_samples = normal(key, shape + mean.shape[-1:], dtype)
return mean + np.tensordot(normal_samples, chol_factor, [-1, 1])
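
# Example usage (illustrative sketch): samples are mean + z @ L.T with
# L = cholesky(cov) and z standard normal:
#
#   key = PRNGKey(0)
#   mean = np.zeros(2)
#   cov = np.array([[1., 0.5], [0.5, 1.]])
#   x = multivariate_normal(key, mean, cov, shape=(100,))   # shape (100, 2)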
def truncated_normal(key: np.ndarray,
lower: Union[float, np.ndarray],
upper: Union[float, np.ndarray],
shape: Optional[Sequence[int]] = None,
dtype: onp.dtype = onp.float64) -> np.ndarray:
"""Sample truncated standard normal random values with given shape and dtype.
Args:
key: a PRNGKey used as the random key.
lower: a float or array of floats representing the lower bound for
truncation. Must be broadcast-compatible with ``upper``.
upper: a float or array of floats representing the upper bound for
truncation. Must be broadcast-compatible with ``lower``.
shape: optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with ``lower`` and ``upper``. The
default (None) produces a result shape by broadcasting ``lower`` and
``upper``.
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified dtype and shape given by ``shape`` if
``shape`` is not None, or else by broadcasting ``lower`` and ``upper``.
"""
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = abstract_arrays.canonicalize_shape(shape)
return _truncated_normal(key, lower, upper, shape, dtype)
@partial(jit, static_argnums=(3, 4))
def _truncated_normal(key, lower, upper, shape, dtype):
if shape is None:
shape = lax.broadcast_shapes(onp.shape(lower), onp.shape(upper))
else:
_check_shape("truncated_normal", shape, onp.shape(lower), onp.shape(upper))
  if not np.issubdtype(dtype, onp.floating):
    raise TypeError("truncated_normal only accepts floating point dtypes.")
  sqrt2 = onp.array(onp.sqrt(2), dtype)
  a = lax.erf(lax.convert_element_type(lower, dtype) / sqrt2)
  b = lax.erf(lax.convert_element_type(upper, dtype) / sqrt2)
  u = uniform(key, shape, dtype, minval=np.finfo(dtype).tiny)
return sqrt2 * lax.erf_inv(a + u * (b - a))
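
# Example usage (illustrative sketch): draws from N(0, 1) conditioned on
# lower <= x <= upper; shift and scale for other means and variances:
#
#   key = PRNGKey(0)
#   x = truncated_normal(key, -2., 2., shape=(1000,))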
def bernoulli(key: np.ndarray,
p: np.ndarray = onp.float32(0.5),
shape: Optional[Sequence[int]] = None) -> np.ndarray:
"""Sample Bernoulli random values with given shape and mean.
Args:
key: a PRNGKey used as the random key.
p: optional, a float or array of floats for the mean of the random
variables. Must be broadcast-compatible with ``shape``. Default 0.5.
shape: optional, a tuple of nonnegative integers representing the result
shape. Must be broadcast-compatible with ``p.shape``. The default (None)
produces a result shape equal to ``p.shape``.
Returns:
A random array with boolean dtype and shape given by ``shape`` if ``shape``
is not None, or else ``p.shape``.
"""
dtype = dtypes.canonicalize_dtype(lax.dtype(p))
if shape is not None:
shape = abstract_arrays.canonicalize_shape(shape)
if not np.issubdtype(dtype, onp.floating):
msg = "bernoulli probability `p` must have a floating dtype, got {}."
raise TypeError(msg.format(dtype))
p = lax.convert_element_type(p, dtype)
return _bernoulli(key, p, shape)
@partial(jit, static_argnums=(2,))
def _bernoulli(key, p, shape):
if shape is None:
shape = onp.shape(p)
else:
_check_shape("bernoulli", shape, onp.shape(p))
return uniform(key, shape, lax.dtype(p)) < p
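
# Example usage (illustrative sketch):
#
#   key = PRNGKey(0)
#   coin_flips = bernoulli(key, 0.5, shape=(8,))   # boolean array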
def beta(key: np.ndarray,
a: Union[float, np.ndarray],
b: Union[float, np.ndarray],
shape: Optional[Sequence[int]] = None,
dtype: onp.dtype = onp.float64) -> np.ndarray:
"""Sample Bernoulli random values with given shape and mean.
Args:
key: a PRNGKey used as the random key.
a: a float or array of floats broadcast-compatible with ``shape``
representing the first parameter "alpha".
b: a float or array of floats broadcast-compatible with ``shape``
representing the second parameter "beta".
shape: optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with ``a`` and ``b``. The default
(None) produces a result shape by broadcasting ``a`` and ``b``.
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified dtype and shape given by ``shape`` if
``shape`` is not None, or else by broadcasting ``a`` and ``b``.
"""
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = abstract_arrays.canonicalize_shape(shape)
return _beta(key, a, b, shape, dtype)
def _beta(key, a, b, shape, dtype):
if shape is None:
shape = lax.broadcast_shapes(onp.shape(a), onp.shape(b))
else:
_check_shape("beta", shape, onp.shape(a), onp.shape(b))
a = lax.convert_element_type(a, dtype)
b = lax.convert_element_type(b, dtype)
key_a, key_b = split(key)
a = np.broadcast_to(a, shape)
b = np.broadcast_to(b, shape)
gamma_a = gamma(key_a, a, shape, dtype)
gamma_b = gamma(key_b, b, shape, dtype)
return gamma_a / (gamma_a + gamma_b)
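
# Example usage (illustrative sketch): Beta(a, b) is sampled as
# g_a / (g_a + g_b) with independent g_a ~ Gamma(a), g_b ~ Gamma(b):
#
#   key = PRNGKey(0)
#   x = beta(key, 2., 5., shape=(3,))   # values in (0, 1)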
def cauchy(key, shape=(), dtype=onp.float64):
"""Sample Cauchy random values with given shape and float dtype.
Args:
key: a PRNGKey used as the random key.
shape: optional, a tuple of nonnegative integers representing the result
shape. Default ().
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified shape and dtype.
"""
dtype = dtypes.canonicalize_dtype(dtype)
shape = abstract_arrays.canonicalize_shape(shape)
return _cauchy(key, shape, dtype)
@partial(jit, static_argnums=(1, 2))
def _cauchy(key, shape, dtype):
_check_shape("cauchy", shape)
u = uniform(key, shape, dtype, minval=np.finfo(dtype).eps, maxval=1.)
pi = _constant_like(u, onp.pi)
return lax.tan(lax.mul(pi, lax.sub(u, _constant_like(u, 0.5))))
def dirichlet(key, alpha, shape=None, dtype=onp.float64):
"""Sample Dirichlet random values with given shape and float dtype.
Args:
key: a PRNGKey used as the random key.
alpha: an array of shape ``(..., n)`` used as the concentration
parameter of the random variables.
    shape: optional, a tuple of nonnegative integers specifying the result
      batch shape; that is, the prefix of the result shape excluding the
      trailing axis of length ``n``. Must be broadcast-compatible with
``alpha.shape[:-1]``. The default (None) produces a result shape equal to
``alpha.shape``.
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified dtype and shape given by
``shape + (alpha.shape[-1],)`` if ``shape`` is not None, or else
``alpha.shape``.
2019-04-22 11:55:02 -04:00
"""
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = abstract_arrays.canonicalize_shape(shape)
return _dirichlet(key, alpha, shape, dtype)
@partial(jit, static_argnums=(2, 3))
def _dirichlet(key, alpha, shape, dtype):
if not onp.ndim(alpha) >= 1:
msg = "dirichlet requires alpha.ndim >= 1, got alpha.ndim == {}"
raise ValueError(msg.format(onp.ndim(alpha)))
if shape is None:
shape = onp.shape(alpha)[:-1]
else:
_check_shape("dirichlet", shape, onp.shape(alpha)[:-1])
alpha = lax.convert_element_type(alpha, dtype)
gamma_samples = gamma(key, alpha, shape + onp.shape(alpha)[-1:], dtype)
return gamma_samples / np.sum(gamma_samples, axis=-1, keepdims=True)
def exponential(key, shape=(), dtype=onp.float64):
"""Sample Exponential random values with given shape and float dtype.
Args:
key: a PRNGKey used as the random key.
shape: optional, a tuple of nonnegative integers representing the result
shape. Default ().
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified shape and dtype.
"""
dtype = dtypes.canonicalize_dtype(dtype)
shape = abstract_arrays.canonicalize_shape(shape)
return _exponential(key, shape, dtype)
@partial(jit, static_argnums=(1, 2))
def _exponential(key, shape, dtype):
_check_shape("exponential", shape)
u = uniform(key, shape, dtype)
# taking 1 - u to move the domain of log to (0, 1] instead of [0, 1)
return lax.neg(lax.log1p(lax.neg(u)))
def _gamma_one(key, alpha):
  # Ref: "A simple method for generating gamma variables", George Marsaglia
  # and Wai Wan Tsang. The algorithm can also be found in:
# https://en.wikipedia.org/wiki/Gamma_distribution#Generating_gamma-distributed_random_variables
zero = _constant_like(alpha, 0)
one = _constant_like(alpha, 1)
minus_one = _constant_like(alpha, -1)
one_over_two = _constant_like(alpha, 0.5)
one_over_three = _constant_like(alpha, 1. / 3.)
squeeze_const = _constant_like(alpha, 0.0331)
dtype = lax.dtype(alpha)
key, subkey = split(key)
# for alpha < 1, we boost alpha to alpha + 1 and get a sample according to
# Gamma(alpha) ~ Gamma(alpha+1) * Uniform()^(1 / alpha)
boost = lax.select(lax.ge(alpha, one),
one,
lax.pow(uniform(subkey, (), dtype=dtype), lax.div(one, alpha)))
alpha = lax.select(lax.ge(alpha, one), alpha, lax.add(alpha, one))
d = lax.sub(alpha, one_over_three)
c = lax.div(one_over_three, lax.pow(d, one_over_two))
def _cond_fn(kXVU):
_, X, V, U = kXVU
    # TODO: use lax.cond when its batching rule is supported. The reason is
    # to avoid evaluating the second condition, which involves log + log, if
    # the first condition is satisfied.
cond = lax.bitwise_and(lax.ge(U, lax.sub(one, lax.mul(squeeze_const, lax.mul(X, X)))),
lax.ge(lax.log(U), lax.add(lax.mul(X, one_over_two),
lax.mul(d, lax.add(lax.sub(one, V),
lax.log(V))))))
return cond
def _body_fn(kXVU):
def _next_kxv(kxv):
key = kxv[0]
key, subkey = split(key)
x = normal(subkey, (), dtype=dtype)
v = lax.add(one, lax.mul(x, c))
return key, x, v
key = kXVU[0]
key, x_key, U_key = split(key, 3)
_, x, v = lax.while_loop(lambda kxv: lax.le(kxv[2], zero), _next_kxv, (x_key, zero, minus_one))
X = lax.mul(x, x)
V = lax.mul(lax.mul(v, v), v)
U = uniform(U_key, (), dtype=dtype)
return key, X, V, U
# initial state is chosen such that _cond_fn will return True
_, _, V, _ = lax.while_loop(_cond_fn, _body_fn, (key, zero, one, _constant_like(alpha, 2)))
z = lax.mul(lax.mul(d, V), boost)
return lax.select(lax.eq(z, zero), np.finfo(z.dtype).tiny, z)
_bivariate_coef = [[0.16009398, -0.094634816, 0.025146379, -0.0030648348,
1, 0.3266811, 0.10406087, 0.0014179033],
[0.53487893, 0.12980707, 0.06573594, -0.0015649787,
0.16639465, 0.020070098, -0.0035938937, -0.00058392601],
[0.040121005, -0.0065914079, -0.002628604, -0.0013441777,
0.017050642, -0.0021309345, 0.00085092385, -1.5248239e-07]]
def _gamma_grad_one(z, alpha):
  # Ref 1: "Pathwise Derivatives Beyond the Reparameterization Trick", Jankowiak & Obermeyer
# Ref 2: Case 4 follows https://github.com/fritzo/notebooks/blob/master/gamma-reparameterized.ipynb
# TODO: use lax.cond instead of lax.while_loop when its batching rule is available
# See https://github.com/google/jax/issues/490
def _case1(zagf):
z, alpha, _, flag = zagf
# dz = - dCDF(z; a) / pdf(z; a)
# pdf = z^(a-1) * e^(-z) / Gamma(a)
# CDF(z; a) = IncompleteGamma(a, z) / Gamma(a)
# dCDF(z; a) = (dIncompleteGamma - IncompleteGamma * Digamma(a)) / Gamma(a)
# =: unnormalized_dCDF / Gamma(a)
# IncompleteGamma ~ z^a [ 1/a - z/(a+1) + z^2/2!(a+2) - z^3/3!(a+3) + z^4/4!(a+4) - z^5/5!(a+5) ]
# =: z^a * term1
# dIncompleteGamma ~ z^a * log(z) * term1 - z^a [1/a^2 - z/(a+1)^2 + z^2/2!(a+2)^2
# - z^3/3!(a+3)^2 + z^4/4!(a+4)^2 - z^5/5!(a+5)^2 ]
# =: z^a * log(z) * term1 - z^a * term2
# unnormalized_dCDF = z^a { [log(z) - Digamma(a)] * term1 - term2 }
zi = 1.0
update = zi / alpha
term1 = update
term2 = update / alpha
for i in range(1, 6):
zi = -zi * z / i
update = zi / (alpha + i)
term1 = term1 + update
term2 = term2 + update / (alpha + i)
unnormalized_cdf_dot = np.power(z, alpha) * ((np.log(z) - lax.digamma(alpha)) * term1 - term2)
unnormalized_pdf = np.power(z, alpha - 1) * np.exp(-z)
grad = -unnormalized_cdf_dot / unnormalized_pdf
return z, alpha, grad, ~flag
def _cond2(zagf):
z, alpha, _, flag = zagf
return (~flag) & (alpha > 8.0) & ((z < 0.9 * alpha) | (z > 1.1 * alpha))
def _case2(zagf):
z, alpha, _, flag = zagf
# Formula 58 of [1]
sqrt_8a = np.sqrt(8 * alpha)
z_minus_a = z - alpha
log_z_div_a = np.log(z / alpha)
sign = np.where(z < alpha, lax._const(z, 1.0), lax._const(z, -1.0))
term1 = 4 * (z + alpha) / (sqrt_8a * z_minus_a * z_minus_a)
term2 = log_z_div_a * (sqrt_8a / z_minus_a + sign * np.power(z_minus_a - alpha * log_z_div_a, -1.5))
term3 = z * (1.0 + 1.0 / (12 * alpha) + 1.0 / (288 * alpha * alpha)) / sqrt_8a
grad = (term1 + term2) * term3
return z, alpha, grad, ~flag
def _cond3(zagf):
z, alpha, _, flag = zagf
return (~flag) & (alpha > 8.0) & (z >= 0.9 * alpha) & (z <= 1.1 * alpha)
def _case3(zagf):
z, alpha, _, flag = zagf
# Formula 59 of [1]
z_div_a = np.divide(z, alpha)
aa = alpha * alpha
term1 = 1440 * alpha + 6 * z_div_a * (53 - 120 * z) - 65 * z_div_a * z_div_a + 3600 * z + 107
term2 = 1244160 * alpha * aa
term3 = 1 + 24 * alpha + 288 * aa
grad = term1 * term3 / term2
return z, alpha, grad, ~flag
def _case4(zagf):
z, alpha, _, flag = zagf
# Ref [2]
u = np.log(z / alpha)
v = np.log(alpha)
c = []
for i in range(8):
c.append(_bivariate_coef[0][i] + u * (_bivariate_coef[1][i] + u * _bivariate_coef[2][i]))
p = c[0] + v * (c[1] + v * (c[2] + v * c[3]))
q = c[4] + v * (c[5] + v * (c[6] + v * c[7]))
grad = np.exp(p / np.maximum(q, 0.01))
return z, alpha, grad, ~flag
_, _, grad, flag = lax.while_loop(lambda zagf: (~zagf[3]) & (zagf[0] < 0.8),
_case1,
(z, alpha, lax._const(alpha, 0.0), False))
_, _, grad, flag = lax.while_loop(_cond2, _case2, (z, alpha, grad, flag))
_, _, grad, flag = lax.while_loop(_cond3, _case3, (z, alpha, grad, flag))
_, _, grad, flag = lax.while_loop(lambda zagf: ~zagf[3], _case4, (z, alpha, grad, flag))
return grad
def _gamma_grad(sample, a):
samples = np.reshape(sample, -1)
alphas = np.reshape(a, -1)
if xla_bridge.get_backend().platform == 'cpu':
grads = lax.map(lambda args: _gamma_grad_one(*args), (samples, alphas))
else:
grads = vmap(_gamma_grad_one)(samples, alphas)
return grads.reshape(onp.shape(a))
def _gamma_impl(key, a):
a_shape = np.shape(a)
# split key to match the shape of a
key_ndim = np.ndim(key) - 1
key = np.reshape(key, (-1, 2))
key = vmap(split, in_axes=(0, None))(key, prod(a_shape[key_ndim:]))
keys = np.reshape(key, (-1, 2))
alphas = np.reshape(a, -1)
if xla_bridge.get_backend().platform == 'cpu':
samples = lax.map(lambda args: _gamma_one(*args), (keys, alphas))
else:
samples = vmap(_gamma_one)(keys, alphas)
return np.reshape(samples, a_shape),
def _gamma_batching_rule(batched_args, batch_dims):
k, a = batched_args
bk, ba = batch_dims
size = next(t.shape[i] for t, i in zip(batched_args, batch_dims) if i is not None)
k = batching.bdim_at_front(k, bk, size)
a = batching.bdim_at_front(a, ba, size)
return random_gamma_p.bind(k, a), (0,)
random_gamma_p = core.Primitive('random_gamma')
random_gamma_p.multiple_results = True
random_gamma_p.def_impl(_gamma_impl)
random_gamma_p.def_abstract_eval(lambda key, a: (abstract_arrays.raise_to_shaped(a),))
ad.defjvp2(random_gamma_p, None, lambda tangent, ans, key, a: (tangent * _gamma_grad(ans[0], a),))
xla.translations[random_gamma_p] = xla.lower_fun(_gamma_impl)
batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule
def gamma(key, a, shape=None, dtype=onp.float64):
"""Sample Gamma random values with given shape and float dtype.
Args:
key: a PRNGKey used as the random key.
a: a float or array of floats broadcast-compatible with ``shape``
representing the parameter of the distribution.
shape: optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with ``a``. The default (None)
produces a result shape equal to ``a.shape``.
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``a.shape``.
"""
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = abstract_arrays.canonicalize_shape(shape)
return _gamma(key, a, shape, dtype)
@partial(jit, static_argnums=(2, 3))
def _gamma(key, a, shape, dtype):
if shape is None:
shape = onp.shape(a)
else:
_check_shape("gamma", shape, onp.shape(a))
a = lax.convert_element_type(a, dtype)
if onp.shape(a) != shape:
a = np.broadcast_to(a, shape)
return random_gamma_p.bind(key, a)[0]
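
# Example usage (illustrative sketch): `a` may be an array, in which case one
# sample is drawn per element:
#
#   key = PRNGKey(0)
#   x = gamma(key, np.array([1., 2., 10.]))   # shape (3,)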
def gumbel(key, shape=(), dtype=onp.float64):
"""Sample Gumbel random values with given shape and float dtype.
Args:
key: a PRNGKey used as the random key.
shape: optional, a tuple of nonnegative integers representing the result
shape. Default ().
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified shape and dtype.
"""
dtype = dtypes.canonicalize_dtype(dtype)
shape = abstract_arrays.canonicalize_shape(shape)
return _gumbel(key, shape, dtype)
@partial(jit, static_argnums=(1, 2))
def _gumbel(key, shape, dtype):
_check_shape("gumbel", shape)
return -np.log(-np.log(
uniform(key, shape, dtype, minval=np.finfo(dtype).eps, maxval=1.)))
def categorical(key, logits, axis=-1, shape=None):
"""Sample random values from categorical distributions.
Args:
key: a PRNGKey used as the random key.
logits: Unnormalized log probabilities of the categorical distribution(s) to sample from,
so that `softmax(logits, axis)` gives the corresponding probabilities.
axis: Axis along which logits belong to the same categorical distribution.
shape: Optional, a tuple of nonnegative integers representing the result shape.
Must be broadcast-compatible with ``onp.delete(logits.shape, axis)``.
The default (None) produces a result shape equal to ``onp.delete(logits.shape, axis)``.
Returns:
A random array with int dtype and shape given by ``shape`` if ``shape``
is not None, or else ``onp.delete(logits.shape, axis)``.
"""
if axis >= 0:
axis -= len(logits.shape)
batch_shape = tuple(onp.delete(logits.shape, axis))
if shape is None:
shape = batch_shape
else:
_check_shape("categorical", shape, batch_shape)
sample_shape = shape[:len(shape)-len(batch_shape)]
return np.argmax(gumbel(key, sample_shape + logits.shape, logits.dtype) + logits, axis=axis)
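
# Example usage (illustrative sketch of the Gumbel-max trick used above:
# argmax over logits plus Gumbel noise yields a categorical sample):
#
#   key = PRNGKey(0)
#   logits = np.log(np.array([0.1, 0.3, 0.6]))
#   idx = categorical(key, logits, shape=(100,))   # 100 draws in {0, 1, 2}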
def laplace(key, shape=(), dtype=onp.float64):
"""Sample Laplace random values with given shape and float dtype.
Args:
key: a PRNGKey used as the random key.
shape: optional, a tuple of nonnegative integers representing the result
shape. Default ().
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified shape and dtype.
"""
dtype = dtypes.canonicalize_dtype(dtype)
shape = abstract_arrays.canonicalize_shape(shape)
return _laplace(key, shape, dtype)
@partial(jit, static_argnums=(1, 2))
def _laplace(key, shape, dtype):
_check_shape("laplace", shape)
u = uniform(
key, shape, dtype, minval=-1. + np.finfo(dtype).epsneg, maxval=1.)
return lax.mul(lax.sign(u), lax.log1p(lax.neg(lax.abs(u))))
def logistic(key, shape=(), dtype=onp.float64):
"""Sample logistic random values with given shape and float dtype.
Args:
key: a PRNGKey used as the random key.
shape: optional, a tuple of nonnegative integers representing the result
shape. Default ().
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified shape and dtype.
"""
dtype = dtypes.canonicalize_dtype(dtype)
shape = abstract_arrays.canonicalize_shape(shape)
return _logistic(key, shape, dtype)
@partial(jit, static_argnums=(1, 2))
def _logistic(key, shape, dtype):
_check_shape("logistic", shape)
return logit(uniform(key, shape, dtype))
def pareto(key, b, shape=None, dtype=onp.float64):
"""Sample Pareto random values with given shape and float dtype.
Args:
key: a PRNGKey used as the random key.
    b: a float or array of floats broadcast-compatible with ``shape``
representing the parameter of the distribution.
shape: optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with ``b``. The default (None)
produces a result shape equal to ``b.shape``.
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``b.shape``.
"""
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = abstract_arrays.canonicalize_shape(shape)
return _pareto(key, b, shape, dtype)
@partial(jit, static_argnums=(2, 3))
def _pareto(key, b, shape, dtype):
if shape is None:
shape = onp.shape(b)
else:
_check_shape("pareto", shape)
b = lax.convert_element_type(b, dtype)
e = exponential(key, shape, dtype)
return lax.exp(e / b)
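
# Example usage (illustrative sketch): since e ~ Exponential(1), exp(e / b)
# satisfies P(X > x) = x**-b for x >= 1, i.e. Pareto with tail index b:
#
#   key = PRNGKey(0)
#   x = pareto(key, 3., shape=(1000,))   # values in [1, inf)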
def t(key, df, shape=None, dtype=onp.float64):
"""Sample Student's t random values with given shape and float dtype.
Args:
key: a PRNGKey used as the random key.
df: a float or array of floats broadcast-compatible with ``shape``
representing the parameter of the distribution.
shape: optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with ``df``. The default (None)
produces a result shape equal to ``df.shape``.
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``df.shape``.
"""
  dtype = dtypes.canonicalize_dtype(dtype)
  if shape is not None:
    shape = abstract_arrays.canonicalize_shape(shape)
return _t(key, df, shape, dtype)
@partial(jit, static_argnums=(2, 3))
def _t(key, df, shape, dtype):
if shape is None:
shape = onp.shape(df)
else:
_check_shape("t", shape, onp.shape(df))
df = lax.convert_element_type(df, dtype)
key_n, key_g = split(key)
n = normal(key_n, shape, dtype)
two = _constant_like(n, 2)
half_df = lax.div(df, two)
  g = gamma(key_g, half_df, shape, dtype)
return n * np.sqrt(half_df / g)