rocm_jax/tests/random_lax_test.py

1672 lines
61 KiB
Python

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import math
from unittest import SkipTest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import scipy.linalg
import scipy.special
import scipy.stats
import jax
from jax import grad
from jax import lax
from jax import numpy as jnp
from jax import random
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import test_util as jtu
from jax import vmap
from jax._src import prng as prng_internal
config.parse_flags_with_absl()
float_dtypes = jtu.dtypes.all_floating
complex_dtypes = jtu.dtypes.complex
int_dtypes = jtu.dtypes.all_integer
uint_dtypes = jtu.dtypes.all_unsigned
@jtu.with_config(jax_legacy_prng_key='allow')
class LaxRandomTest(jtu.JaxTestCase):
def _CheckCollisions(self, samples, nbits):
fail_prob = 0.01 # conservative bound on statistical fail prob by Chebyshev
nitems = len(samples)
nbins = 2 ** nbits
nexpected = nbins * (1 - ((nbins - 1) / nbins) ** nitems)
ncollisions = len(np.unique(samples))
sq_percent_deviation = ((ncollisions - nexpected) / nexpected) ** 2
self.assertLess(sq_percent_deviation, 1 / np.sqrt(nexpected * fail_prob))
def _CheckKolmogorovSmirnovCDF(self, samples, cdf, pval=None):
# conservative bound on statistical fail prob by Kolmo CDF
# bfloat16 quantization creates much lower p-values in large distributions
fail_prob = pval or (0.003 if samples.dtype == jnp.bfloat16 else 0.01)
# TODO(frostig): This reads enable_custom_prng as a proxy for
# whether RBG keys may be involved, but that's no longer exact.
if config.enable_custom_prng.value and samples.dtype == jnp.bfloat16:
return
# kstest does not understand bfloat16 input, so cast to float32.
if samples.dtype == jnp.bfloat16:
samples = samples.astype('float32')
# kstest fails for infinities starting in scipy 1.12
# (https://github.com/scipy/scipy/issues/20386)
# TODO(jakevdp): remove this logic if/when fixed upstream.
scipy_version = jtu.parse_version(scipy.__version__)
if scipy_version >= (1, 12) and np.issubdtype(samples.dtype, np.floating):
samples = np.array(samples, copy=True)
samples[np.isposinf(samples)] = 0.01 * np.finfo(samples.dtype).max
samples[np.isneginf(samples)] = 0.01 * np.finfo(samples.dtype).min
self.assertGreater(scipy.stats.kstest(samples, cdf).pvalue, fail_prob)
def _CheckChiSquared(self, samples, pmf, *, pval=None):
if samples.dtype == bool:
samples = samples.astype(int)
alpha = pval or 0.01 # significance level, threshold for p-value
# scipy.stats.chisquare requires the sum of expected and actual to
# match; this is only the case if we compute the expected frequency
# at *all* nonzero values of the pmf. We don't know this a priori,
# so we add extra values past the largest observed value. The number
# below is empirically enough to get full coverage for the current set
# of tests. If a new test is added where this is not enough, chisquare()
# below will error due to the sums of the inputs not matching.
extra_values = 100
actual_freq = np.bincount(samples, minlength=samples.max() + extra_values)
values = np.arange(len(actual_freq))
expected_freq = pmf(values) * samples.size
valid = expected_freq > 0
actual_freq = actual_freq[valid]
expected_freq = expected_freq[valid]
_, p_value = scipy.stats.chisquare(actual_freq, expected_freq)
self.assertGreater(
p_value, alpha,
msg=f'Failed chi-squared test with p={p_value}.\n'
'Expected vs. actual frequencies:\n'
f'{expected_freq}\n{actual_freq}')
def make_key(self, seed):
return random.PRNGKey(seed, impl='threefry2x32')
@jtu.sample_product(
num=(None, 6, (6,), (2, 3), (2, 3, 4)),
)
def test_split_size_shape(self, num):
key = self.make_key(0)
if num is None:
key_split = jax.random.split(key)
else:
key_split = jax.random.split(key, num)
if num is None:
self.assertEqual(key_split.shape, (2, *key.shape))
elif type(num) is tuple:
self.assertEqual(key_split.shape, (*num, *key.shape))
else:
self.assertEqual(key_split.shape, (num, *key.shape))
@jtu.sample_product(dtype=jtu.dtypes.floating)
def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype):
bits_dtype = np.uint32 if jnp.finfo(dtype).bits == 32 else np.uint64
numpy_bits = np.array(1., dtype).view(bits_dtype)
xla_bits = jax.jit(
lambda: lax.bitcast_convert_type(np.array(1., dtype), bits_dtype))()
self.assertEqual(numpy_bits, xla_bits)
@jtu.sample_product(dtype=float_dtypes)
def testRngUniform(self, dtype):
key = lambda: self.make_key(0)
rand = lambda key: random.uniform(key, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)
@jtu.sample_product(dtype=int_dtypes + uint_dtypes)
def testRngRandint(self, dtype):
lo = 5
hi = 10
key = lambda: self.make_key(0)
rand = lambda key: random.randint(key, (10000,), lo, hi, dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self.assertTrue(np.all(lo <= samples))
self.assertTrue(np.all(samples < hi))
@jtu.sample_product(dtype=float_dtypes)
def testNormal(self, dtype):
key = lambda: self.make_key(0)
rand = lambda key: random.normal(key, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.norm().cdf)
def testNormalBfloat16(self):
# Passing bfloat16 as dtype string.
# https://github.com/jax-ml/jax/issues/6813
res_bfloat16_str = random.normal(self.make_key(0), dtype='bfloat16')
res_bfloat16 = random.normal(self.make_key(0), dtype=jnp.bfloat16)
self.assertAllClose(res_bfloat16, res_bfloat16_str)
@jtu.sample_product(dtype=complex_dtypes)
def testNormalComplex(self, dtype):
key = lambda: self.make_key(0)
rand = lambda key: random.normal(key, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(jnp.real(samples), scipy.stats.norm(scale=1/np.sqrt(2)).cdf)
self._CheckKolmogorovSmirnovCDF(jnp.imag(samples), scipy.stats.norm(scale=1/np.sqrt(2)).cdf)
self.assertEqual(dtype, samples.dtype)
@jtu.sample_product(dtype=float_dtypes)
def testTruncatedNormal(self, dtype):
key = lambda: self.make_key(0)
rand = lambda key: random.truncated_normal(key, -0.3, 0.3, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
min_val = np.min(uncompiled_samples)
max_val = np.max(uncompiled_samples)
self.assertTrue(min_val > -0.3)
self.assertTrue(max_val < 0.3)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.truncnorm(-0.3, 0.3).cdf)
@jtu.sample_product(
[dict(shape=shape, replace=replace, axis=axis,
input_range_or_shape=input_range_or_shape)
for shape in [(), (5,), (4, 5)]
for replace in [True, False]
for input_range_or_shape in [100, (10, 10), (10, 5, 2), 1, (1, 5)]
for is_range in [type(input_range_or_shape) is int]
for ndim in [1 if is_range else len(input_range_or_shape)]
for axis in range(-ndim, ndim or 1)
for ninputs in [input_range_or_shape if is_range else input_range_or_shape[axis]]
if replace or math.prod(shape) <= ninputs
],
dtype=jtu.dtypes.floating + jtu.dtypes.integer,
weighted=[True, False],
)
def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis):
# This is the function API that we test against (note that self.rng().choice differs)
np_choice = np.random.default_rng(0).choice
p_dtype = dtypes.to_inexact_dtype(dtype)
key = lambda: self.make_key(0)
is_range = type(input_range_or_shape) is int
x = (input_range_or_shape if is_range else
self.rng().permutation(np.arange(math.prod(
input_range_or_shape), dtype=dtype)).reshape(input_range_or_shape))
N = x if is_range else x.shape[axis]
if weighted:
p = np.arange(N, dtype=p_dtype) + 1
p /= p.sum()
else:
p = None
rand = lambda key, x: random.choice(key, x, shape, replace, p, axis)
sample = rand(key(), x)
if not is_range:
self.assertEqual(dtype, sample.dtype)
expected_shape = np.shape(np_choice(x, shape or None, replace, p, axis))
self.assertEqual(expected_shape, sample.shape)
expected_dtype = dtypes.result_type(int if is_range else x)
self.assertEqual(expected_dtype, sample.dtype)
if not replace and shape:
def lsort(x):
if not math.prod(x.shape): return x
ind = np.lexsort(np.swapaxes(x, axis, -1).reshape((-1, x.shape[axis])))
return jnp.take(x, ind, axis)
self.assertArraysEqual(lsort(sample), lsort(np.unique(sample, axis=axis)))
self.assertArraysEqual(sample, rand(key(), np.array(x)))
self.assertArraysEqual(sample, jax.jit(rand, static_argnames=
'x' if is_range else None)(key(), x))
@jtu.sample_product(
[dict(range_or_shape=range_or_shape, axis=axis)
for range_or_shape in [0, 1, 100, (0,), (1,), (100,),
(10, 10), (10, 5, 2), (0, 5), (1, 5)]
for ndim in [1 if type(range_or_shape) is int else len(range_or_shape)]
for axis in range(-ndim, ndim or 1)
],
dtype=jtu.dtypes.floating + jtu.dtypes.integer,
independent=[True, False],
)
def testPermutation(self, dtype, range_or_shape, axis, independent):
key = lambda: self.make_key(0)
is_range = type(range_or_shape) is int
x = (range_or_shape if is_range else
self.rng().permutation(np.arange(
math.prod(range_or_shape), dtype=dtype)).reshape(range_or_shape))
shape = ((range_or_shape,) if is_range else range_or_shape)
x_ = np.copy(x)
rand = lambda key, x: random.permutation(key, x, axis, independent=independent)
perm = rand(key(), x)
if shape[axis] >= 10:
self.assertFalse(np.all(perm == x)) # seems unlikely!
arr = np.arange(x) if is_range else x
def lsort(x):
if not math.prod(x.shape): return x
ind = np.lexsort(np.swapaxes(x, axis, -1).reshape((-1, x.shape[axis])))
return jnp.take(x, ind, axis)
if not independent:
self.assertArraysEqual(lsort(arr), lsort(perm), check_dtypes=not is_range)
if independent and (arr.shape[axis] > 4) and (arr.size // arr.shape[axis] > 4):
# Check for independent shuffling if there are >4 vectors of size >4.
# Chance of false positive is 1 in (5!)^4
with self.assertRaises(AssertionError):
self.assertArraysEqual(lsort(arr), lsort(perm), check_dtypes=not is_range)
self.assertArraysEqual(x_, x)
self.assertArraysEqual(perm, rand(key(), np.array(x)))
self.assertArraysEqual(perm, jax.jit(rand, static_argnames=
'x' if is_range else None)(key(), x))
def testPermutationErrors(self):
key = self.make_key(0)
with self.assertRaises(ValueError):
random.permutation(key, 10, axis=3)
with self.assertRaises(TypeError):
random.permutation(key, 10.)
with self.assertRaises(core.ConcretizationTypeError):
jax.jit(random.permutation)(key, 10)
@jtu.sample_product(
p=[0.1, 0.5, 0.9],
dtype=jtu.dtypes.floating,
)
def testBernoulli(self, p, dtype):
key = lambda: self.make_key(0)
p = np.array(p, dtype=dtype)
rand = lambda key, p: random.bernoulli(key, p, (10000,))
crand = jax.jit(rand)
uncompiled_samples = rand(key(), p)
compiled_samples = crand(key(), p)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckChiSquared(samples, scipy.stats.bernoulli(p).pmf)
@jtu.sample_product(
[dict(p=p, axis=axis)
for (p, axis) in [
([.25] * 4, -1),
([.1, .2, .3, .4], -1),
([[.5, .5], [.1, .9]], 1),
([[.5, .1], [.5, .9]], 0),
]
],
sample_shape=[(10000,), (5000, 2)],
dtype=jtu.dtypes.floating,
)
def testCategorical(self, p, axis, dtype, sample_shape):
key = lambda: self.make_key(0)
p = np.array(p, dtype=dtype)
logits = np.log(p) - 42 # test unnormalized
out_shape = tuple(np.delete(logits.shape, axis))
shape = sample_shape + out_shape
rand = partial(random.categorical, shape=shape, axis=axis)
crand = jax.jit(rand)
uncompiled_samples = rand(key(), logits)
compiled_samples = crand(key(), logits)
if axis < 0:
axis += len(logits.shape)
for samples in [uncompiled_samples, compiled_samples]:
assert samples.shape == shape
samples = jnp.reshape(samples, (10000,) + out_shape)
if len(p.shape[:-1]) > 0:
ps = np.transpose(p, (1, 0)) if axis == 0 else p
for cat_samples, cat_p in zip(samples.transpose(), ps):
pmf = lambda x: np.where(x < len(cat_p), cat_p[np.minimum(len(cat_p) - 1, x)], 0.0)
self._CheckChiSquared(cat_samples, pmf=pmf)
else:
pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0)
self._CheckChiSquared(samples, pmf=pmf)
@jtu.sample_product(
logits_shape=[(7,), (8, 9), (10, 11, 12)],
prefix_shape=[(2,), (3, 4), (5, 6)],
)
def testCategoricalWithoutReplacement(self, logits_shape, prefix_shape):
key = random.key(0)
key, subkey = random.split(key)
logits = random.normal(subkey, logits_shape)
key, subkey = random.split(key)
axis = random.randint(subkey, (), -len(logits_shape), len(logits_shape))
dists_shape = tuple(np.delete(logits_shape, axis))
n_categories = logits_shape[axis]
shape = prefix_shape + dists_shape
prefix_size = math.prod(prefix_shape)
if n_categories < prefix_size:
with self.assertRaisesRegex(ValueError, "Number of samples without replacement"):
random.categorical(key, logits, axis=axis, shape=shape, replace=False)
else:
output = random.categorical(key, logits, axis=axis, shape=shape, replace=False)
self.assertEqual(output.shape, shape)
assert (0 <= output).all()
assert (output < n_categories).all()
flat = output.reshape((prefix_size, math.prod(dists_shape)))
counts = jax.vmap(partial(jnp.bincount, length=n_categories), 1)(flat)
assert (counts <= 1).all()
def testBernoulliShape(self):
key = self.make_key(0)
with jax.numpy_rank_promotion('allow'):
x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)
@jtu.sample_product(
a=[0.2, 5.],
b=[0.2, 5.],
dtype=[np.float64], # NOTE: KS test fails with float32
)
def testBeta(self, a, b, dtype):
if not config.enable_x64.value:
raise SkipTest("skip test except on X64")
key = lambda: self.make_key(0)
rand = lambda key, a, b: random.beta(key, a, b, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key(), a, b)
compiled_samples = crand(key(), a, b)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.beta(a, b).cdf)
@jtu.skip_on_devices("tpu") # TPU precision causes issues.
def testBetaSmallParameters(self, dtype=np.float32):
# Regression test for beta version of https://github.com/jax-ml/jax/issues/9896
key = self.make_key(0)
a, b = 0.0001, 0.0002
samples = random.beta(key, a, b, shape=(100,), dtype=dtype)
# With such small parameters, all samples should be exactly zero or one.
tol = 5E-2 if jtu.test_device_matches(["tpu"]) else 1E-3
zeros = samples[samples < 0.5]
self.assertAllClose(zeros, jnp.zeros_like(zeros), atol=tol)
ones = samples[samples >= 0.5]
self.assertAllClose(ones, jnp.ones_like(ones), atol=tol)
@jtu.sample_product(dtype=float_dtypes)
def testCauchy(self, dtype):
key = lambda: self.make_key(0)
rand = lambda key: random.cauchy(key, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.cauchy().cdf)
@jtu.sample_product(
alpha=[np.array([0.2, 1., 5.]),],
dtype=jtu.dtypes.floating,
)
@jtu.skip_on_devices("tpu") # TODO(mattjj): slow compilation times
def testDirichlet(self, alpha, dtype):
key = lambda: self.make_key(0)
num_samples = 10000
rand = lambda key, alpha: random.dirichlet(key, alpha, (num_samples,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key(), alpha)
compiled_samples = crand(key(), alpha)
for samples in [uncompiled_samples, compiled_samples]:
self.assertAllClose(samples.sum(-1), np.ones(num_samples, dtype=dtype))
alpha_sum = sum(alpha)
for i, a in enumerate(alpha):
self._CheckKolmogorovSmirnovCDF(samples[..., i],
scipy.stats.beta(a, alpha_sum - a).cdf,
pval=0.003)
@jtu.skip_on_devices("tpu") # lower accuracy leads to failures.
def testDirichletSmallAlpha(self, dtype=np.float32):
# Regression test for https://github.com/jax-ml/jax/issues/9896
key = self.make_key(0)
alpha = 0.00001 * jnp.ones(3)
samples = random.dirichlet(key, alpha, shape=(100,), dtype=dtype)
# Check that results lie on the simplex.
self.assertAllClose(samples.sum(1), jnp.ones(samples.shape[0]),
check_dtypes=False, rtol=1E-5)
# Check that results contain 1 in one of the dimensions:
# this is highly likely to be true when alpha is small.
self.assertAllClose(samples.max(1), jnp.ones(samples.shape[0]),
check_dtypes=False, rtol=1E-4)
@jtu.sample_product(dtype=float_dtypes)
def testExponential(self, dtype):
key = lambda: self.make_key(0)
rand = lambda key: random.exponential(key, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.expon().cdf)
@jtu.sample_product(
a=[0.1, 1., 10.],
dtype=jtu.dtypes.floating,
)
@jtu.skip_on_devices("tpu") # low accuracy leads to failures.
def testGammaVsLogGamma(self, a, dtype):
# Test that gamma() and loggamma() produce equivalent samples.
rand_gamma = lambda key, a: random.gamma(key, a, (100,), dtype)
rand_loggamma = lambda key, a: random.loggamma(key, a, (100,), dtype)
crand_loggamma = jax.jit(rand_loggamma)
tol = {np.float32: 1E-6, np.float64: 1E-12}
key = lambda: self.make_key(0)
self.assertAllClose(rand_gamma(key(), a), jnp.exp(rand_loggamma(key(), a)),
atol=tol, rtol=tol)
self.assertAllClose(rand_gamma(key(), a), jnp.exp(crand_loggamma(key(), a)),
atol=tol, rtol=tol)
@jtu.sample_product(
a=[0.1, 1., 10.],
dtype=jtu.dtypes.floating,
)
def testGamma(self, a, dtype):
key = lambda: self.make_key(1)
rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key(), a)
compiled_samples = crand(key(), a)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gamma(a).cdf)
def testGammaShape(self):
key = self.make_key(0)
x = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)
@jtu.sample_product(
log_space=[True, False],
alpha=[1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4],
)
def testGammaGrad(self, log_space, alpha):
rng = lambda: self.make_key(0)
alphas = np.full((100,), alpha)
z = random.gamma(rng(), alphas)
if log_space:
actual_grad = jax.grad(lambda x: lax.exp(random.loggamma(rng(), x)).sum())(alphas)
else:
actual_grad = jax.grad(lambda x: random.gamma(rng(), x).sum())(alphas)
eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
cdf_dot = (scipy.stats.gamma.cdf(z, alpha + eps)
- scipy.stats.gamma.cdf(z, alpha - eps)) / (2 * eps)
with np.errstate(over='ignore'):
pdf = scipy.stats.gamma.pdf(z, alpha)
expected_grad = -cdf_dot / pdf
rtol = 2e-2 if jtu.test_device_matches(["tpu"]) else 7e-4
self.assertAllClose(actual_grad, expected_grad, check_dtypes=True,
rtol=rtol)
def testGammaGradType(self):
# Regression test for https://github.com/jax-ml/jax/issues/2130
key = self.make_key(0)
a = jnp.array(1., dtype=jnp.float32)
b = jnp.array(3., dtype=jnp.float32)
f = lambda x, y: random.gamma(key=key, a=x, dtype=jnp.float32) / y
# Should not crash with a type error.
jax.vjp(f, a, b)
@jtu.sample_product(
lam=[0.5, 3, 9, 11, 50, 500],
dtype=jtu.dtypes.supported([np.int16, np.int32, np.int64]),
)
def testPoisson(self, lam, dtype):
key = lambda: self.make_key(0)
rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key(), lam)
compiled_samples = crand(key(), lam)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckChiSquared(samples, scipy.stats.poisson(lam).pmf)
# TODO(shoyer): determine error bounds for moments more rigorously (e.g.,
# based on the central limit theorem).
self.assertAllClose(samples.mean(), lam, rtol=0.02, check_dtypes=False)
self.assertAllClose(samples.var(), lam, rtol=0.03, check_dtypes=False)
def testPoissonBatched(self):
key = self.make_key(1)
lam = jnp.concatenate([2 * jnp.ones(10000), 20 * jnp.ones(10000)])
samples = random.poisson(key, lam, shape=(20000,))
self._CheckChiSquared(samples[:10000], scipy.stats.poisson(2.0).pmf)
self._CheckChiSquared(samples[10000:], scipy.stats.poisson(20.0).pmf)
def testPoissonWithoutShape(self):
key = self.make_key(1)
lam = 2 * jnp.ones(10000)
samples = random.poisson(key, lam)
self._CheckChiSquared(samples, scipy.stats.poisson(2.0).pmf)
def testPoissonShape(self):
key = self.make_key(0)
x = random.poisson(key, np.array([2.0, 20.0]), shape=(3, 2))
assert x.shape == (3, 2)
def testPoissonZeros(self):
key = self.make_key(0)
lam = jnp.concatenate([jnp.zeros(10), 20 * jnp.ones(10)])
samples = random.poisson(key, lam, shape=(2, 20))
self.assertArraysEqual(samples[:, :10], jnp.zeros_like(samples[:, :10]))
def testPoissonCornerCases(self):
key = self.make_key(0)
lam = jnp.array([-1, 0, jnp.nan])
samples = random.poisson(key, lam, shape=(3,))
self.assertArraysEqual(samples, jnp.array([-1, 0, -1]), check_dtypes=False)
@jtu.sample_product(dtype=jtu.dtypes.floating)
def testGumbel(self, dtype):
key = lambda: self.make_key(0)
rand = lambda key: random.gumbel(key, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gumbel_r().cdf)
def testLowProbabilityGumbel(self):
dtype = jnp.bfloat16
nmant = jnp.finfo(dtype).nmant
probs = [x * 2 ** -nmant for x in [0.125, 0.75, 1.25, 2.125]]
num_samples = 1024 * 128
num_groups = 128
key = jax.random.key(0)
def compute_counts(key):
v = jax.random.gumbel(key, (num_samples, 1), dtype=dtype, mode="high")
thresholds = np.array([[-np.log(-np.log(1 - x)) for x in probs]],
dtype=dtype)
return (v > thresholds).sum(axis=0)
pts = [float(x) for x in jax.lax.map(
compute_counts, jax.random.split(key, num_groups)).sum(axis=0)]
cdf_probs = [x / (num_samples * num_groups) for x in pts]
np.testing.assert_allclose(cdf_probs, probs, rtol=0.25, atol=0)
@jtu.sample_product(dtype=float_dtypes)
def testLaplace(self, dtype):
key = lambda: self.make_key(0)
rand = lambda key: random.laplace(key, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.laplace().cdf)
@jtu.sample_product(dtype=float_dtypes)
def testLogistic(self, dtype):
key = lambda: self.make_key(0)
rand = lambda key: random.logistic(key, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf)
@jtu.sample_product(
n=range(5),
shape=[(), (5,), (10, 5)],
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
m=list(range(5)) + [None],
)
@jax.default_matmul_precision("float32")
def testOrthogonal(self, n, shape, dtype, m):
if m is None:
m = n
key = self.make_key(0)
q = random.orthogonal(key, n, shape, dtype, m)
self.assertEqual(q.shape, (*shape, n, m))
self.assertEqual(q.dtype, dtype)
qT = jnp.conj(q).mT
if n <= m:
I_n = jnp.broadcast_to(jnp.eye(n, dtype=dtype), (*shape, n, n))
self.assertAllClose(jnp.linalg.matmul(q, qT), I_n, atol={jnp.complex128: 1e-14})
if n >= m:
I_m = jnp.broadcast_to(jnp.eye(m, dtype=dtype), (*shape, m, m))
self.assertAllClose(jnp.linalg.matmul(qT, q), I_m, atol={jnp.complex128: 1e-14})
@jtu.sample_product(
p=[.5, 1., 1.5, 2., 2.5],
shape=[(), (5,), (10, 5)],
dtype=jtu.dtypes.floating,
)
def testGeneralizedNormal(self, p, shape, dtype):
key = lambda: self.make_key(2)
rand = lambda key, p: random.generalized_normal(key, p, shape, dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key(), p)
compiled_samples = crand(key(), p)
for samples in [uncompiled_samples, compiled_samples]:
self.assertEqual(samples.shape, shape)
self.assertEqual(samples.dtype, dtype)
@jtu.sample_product(
p=[.5, 1., 1.5, 2., 2.5],
shape=[(), (5,), (10, 5)],
dtype=jtu.dtypes.floating,
)
def testGeneralizedNormalKS(self, p, shape, dtype):
self.skipTest( # test is also sometimes slow, with (300, ...)-shape draws
"sensitive to random key - https://github.com/jax-ml/jax/issues/18941")
key = lambda: self.make_key(2)
rand = lambda key, p: random.generalized_normal(key, p, (300, *shape), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key(), p)
compiled_samples = crand(key(), p)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples.ravel(), scipy.stats.gennorm(p).cdf)
@jtu.sample_product(
d=range(1, 5),
p=[.5, 1., 1.5, 2., 2.5],
shape=[(), (5,), (10, 5)],
dtype=jtu.dtypes.floating,
)
@jtu.skip_on_devices("tpu") # TPU precision causes issues.
def testBall(self, d, p, shape, dtype):
key = lambda: self.make_key(123)
rand = lambda key, p: random.ball(key, d, p, shape, dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key(), p)
compiled_samples = crand(key(), p)
for samples in [uncompiled_samples, compiled_samples]:
self.assertEqual(samples.shape, (*shape, d))
self.assertEqual(samples.dtype, dtype)
self.assertTrue(((jnp.abs(samples) ** p).sum(-1) <= 1).all())
@jtu.sample_product(
d=range(1, 5),
p=[.5, 1., 1.5, 2., 2.5],
shape=[(), (5,), (10, 5)],
dtype=jtu.dtypes.floating,
)
@jtu.skip_on_devices("tpu") # TPU precision causes issues.
def testBallKS(self, d, p, shape, dtype):
self.skipTest(
"sensitive to random key - https://github.com/jax-ml/jax/issues/18932")
key = lambda: self.make_key(123)
rand = lambda key, p: random.ball(key, d, p, (100, *shape), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key(), p)
compiled_samples = crand(key(), p)
for samples in [uncompiled_samples, compiled_samples]:
norms = (jnp.abs(samples) ** p).sum(-1) ** (d / p)
self._CheckKolmogorovSmirnovCDF(norms.ravel(), scipy.stats.uniform().cdf)
@jtu.sample_product(
b=[0.1, 1., 10.],
dtype=jtu.dtypes.floating,
)
def testPareto(self, b, dtype):
key = lambda: self.make_key(0)
rand = lambda key, b: random.pareto(key, b, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key(), b)
compiled_samples = crand(key(), b)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.pareto(b).cdf)
def testParetoShape(self):
key = self.make_key(0)
with jax.numpy_rank_promotion('allow'):
x = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)
@jtu.sample_product(
df=[0.1, 1., 10.],
dtype=jtu.dtypes.floating,
)
@jtu.skip_on_devices("cpu", "tpu") # TODO(phawkins): slow compilation times
def testT(self, df, dtype):
key = lambda: self.make_key(1)
rand = lambda key, df: random.t(key, df, (10000,), dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key(), df)
compiled_samples = crand(key(), df)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.t(df).cdf)
@jtu.sample_product(
dim=[1, 3, 5],
dtype=float_dtypes,
method=['svd', 'eigh', 'cholesky'],
)
def testMultivariateNormal(self, dim, dtype, method):
r = self.rng()
mean = r.randn(dim)
cov_factor = r.randn(dim, dim)
cov = np.dot(cov_factor, cov_factor.T) + dim * np.eye(dim)
key = lambda: self.make_key(0)
rand = partial(random.multivariate_normal, mean=mean, cov=cov,
shape=(10000,), method=method)
crand = jax.jit(rand)
with jax.numpy_rank_promotion('allow'):
uncompiled_samples = np.asarray(rand(key()), np.float64)
compiled_samples = np.asarray(crand(key()), np.float64)
inv_scale = scipy.linalg.lapack.dtrtri(np.linalg.cholesky(cov), lower=True)[0]
for samples in [uncompiled_samples, compiled_samples]:
centered = samples - mean
whitened = np.einsum('nj,ij->ni', centered, inv_scale)
# This is a quick-and-dirty multivariate normality check that tests that a
# uniform mixture of the marginals along the covariance matrix's
# eigenvectors follow a standard normal distribution.
self._CheckKolmogorovSmirnovCDF(whitened.ravel(), scipy.stats.norm().cdf)
@jtu.sample_product(
dim=[1, 2, 4],
mean_batch_size=[(), (3,), (2, 3)],
cov_batch_size=[(), (3,), (2, 3)],
shape=[(), (1,), (5,)],
method=['cholesky', 'svd', 'eigh'],
)
def testMultivariateNormalShapes(self, dim, mean_batch_size, cov_batch_size,
shape, method):
r = self.rng()
key = self.make_key(0)
eff_batch_size = mean_batch_size \
if len(mean_batch_size) > len(cov_batch_size) else cov_batch_size
mean = r.randn(*(mean_batch_size + (dim,)))
cov_factor = r.randn(*(cov_batch_size + (dim, dim)))
cov = np.einsum('...ij,...kj->...ik', cov_factor, cov_factor)
cov += 1e-3 * np.eye(dim)
shape = shape + eff_batch_size
with jax.numpy_rank_promotion('allow'):
samples = random.multivariate_normal(key, mean, cov, shape=shape, method=method)
assert samples.shape == shape + (dim,)
def testMultivariateNormalCovariance(self):
# test code based on https://github.com/jax-ml/jax/issues/1869
N = 100000
mean = jnp.zeros(4)
cov = jnp.array([[ 0.19, 0.00, -0.13, 0.00],
[ 0.00, 0.29, 0.00, -0.23],
[ -0.13, 0.00, 0.39, 0.00],
[ 0.00, -0.23, 0.00, 0.49]], dtype=mean.dtype)
out_np = self.rng().multivariate_normal(mean, cov, N)
key = self.make_key(0)
with jax.numpy_rank_promotion('allow'):
out_jnp = random.multivariate_normal(key, mean=mean, cov=cov, shape=(N,))
var_np = out_np.var(axis=0)
var_jnp = out_jnp.var(axis=0)
self.assertAllClose(var_np, var_jnp, rtol=1e-2, atol=1e-2,
check_dtypes=False)
var_np = np.cov(out_np, rowvar=False)
var_jnp = np.cov(out_jnp, rowvar=False)
self.assertAllClose(var_np, var_jnp, rtol=1e-2, atol=1e-2,
check_dtypes=False)
@jtu.sample_product(method=['cholesky', 'eigh', 'svd'])
@jtu.skip_on_devices('gpu', 'tpu') # Some NaNs on accelerators.
def testMultivariateNormalSingularCovariance(self, method):
# Singular covariance matrix https://github.com/jax-ml/jax/discussions/13293
mu = jnp.zeros((2,))
sigma = jnp.ones((2, 2))
key = self.make_key(0)
result = random.multivariate_normal(key, mean=mu, cov=sigma, shape=(10,), method=method)
self.assertAllClose(result[:, 0], result[:, 1], atol=1e-3, rtol=1e-3)
# Cholesky fails for singular inputs.
if method == 'cholesky':
self.assertTrue(np.all(np.isnan(result)))
else:
self.assertFalse(np.any(np.isnan(result)))
def testIssue222(self):
x = random.randint(self.make_key(10003), (), 0, 0)
assert x == 0
def testFoldIn(self):
key = self.make_key(0)
keys = [random.key_data(random.fold_in(key, i)) for i in range(10)]
assert np.unique(keys, axis=0).shape[0] == 10
def testFoldInBig(self):
key = self.make_key(0)
seeds = [2 ** 32 - 2, 2 ** 32 - 1]
keys = [random.key_data(random.fold_in(key, seed)) for seed in seeds]
assert np.unique(keys, axis=0).shape[0] == 2
def testStaticShapeErrors(self):
if config.disable_jit.value:
raise SkipTest("test only relevant when jit enabled")
@jax.jit
def feature_map(n, d, sigma=1.0, seed=123):
key = self.make_key(seed)
W = random.normal(key, (d, n)) / sigma
w = random.normal(key, (d, )) / sigma
b = 2 * jnp.pi * random.uniform(key, (d, ))
phi = lambda x, t: jnp.sqrt(2.0 / d) * jnp.cos(jnp.matmul(W, x) + w*t + b)
return phi
self.assertRaisesRegex(TypeError, 'Shapes must be 1D.*',
lambda: feature_map(5, 3))
def testIssue756(self):
key = self.make_key(0)
w = random.normal(key, ())
self.assertEqual(w.dtype, dtypes.canonicalize_dtype(jnp.float_))
def testIssue1789(self):
def f(x):
return random.gamma(self.make_key(0), x)
grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2))
def testDtypeErrorMessage(self):
with self.assertRaisesRegex(ValueError, r"dtype argument to.*"):
random.normal(self.make_key(0), (), dtype=jnp.int32)
def testRandomBroadcast(self):
"""Issue 4033"""
# test for broadcast issue in https://github.com/jax-ml/jax/issues/4033
key = lambda: self.make_key(0)
shape = (10, 2)
with jax.numpy_rank_promotion('allow'):
x1 = random.uniform(key(), shape, minval=jnp.zeros(2), maxval=jnp.ones(2))
x2 = random.randint(key(), shape, jnp.array([0, 1]), jnp.array([1, 2]))
assert x1.shape == shape
assert x2.shape == shape
def testMaxwellSample(self):
num_samples = 10**5
rng = lambda: self.make_key(0)
rand = lambda x: random.maxwell(x, (num_samples, ))
crand = jax.jit(rand)
loc = jtu.to_default_dtype(scipy.stats.maxwell.mean())
std = jtu.to_default_dtype(scipy.stats.maxwell.std())
uncompiled_samples = rand(rng())
compiled_samples = crand(rng())
for samples in [uncompiled_samples, compiled_samples]:
# Check first and second moments.
self.assertEqual((num_samples,), samples.shape)
self.assertAllClose(np.mean(samples), loc, atol=0., rtol=0.1)
self.assertAllClose(np.std(samples), std, atol=0., rtol=0.1)
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.maxwell().cdf)
@parameterized.named_parameters(
('test1', 4.0, 1.0),
('test2', 2.0, 3.0))
def testWeibullSample(self, concentration, scale):
num_samples = 10**5
rng = lambda: self.make_key(0)
rand = lambda x: random.weibull_min(x, scale, concentration, (num_samples,))
crand = jax.jit(rand)
loc = jtu.to_default_dtype(scipy.stats.weibull_min.mean(c=concentration, scale=scale))
std = jtu.to_default_dtype(scipy.stats.weibull_min.std(c=concentration, scale=scale))
uncompiled_samples = rand(rng())
compiled_samples = crand(rng())
for samples in [uncompiled_samples, compiled_samples]:
# Check first and second moments.
self.assertEqual((num_samples,), samples.shape)
self.assertAllClose(np.mean(samples), loc, atol=0., rtol=0.1)
self.assertAllClose(np.std(samples), std, atol=0., rtol=0.1)
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.weibull_min(
c=concentration, scale=scale).cdf)
@parameterized.named_parameters(
('test1', 4.0, 1.0),
('test2', 2.0, 3.0))
def testDoublesidedMaxwellSample(self, loc, scale):
num_samples = 10**4
rng = lambda: self.make_key(0)
rand = lambda key: random.double_sided_maxwell(
rng(), loc, scale, (num_samples,))
crand = jax.jit(rand)
mean = loc
std = np.sqrt(3.) * scale
uncompiled_samples = rand(rng())
compiled_samples = crand(rng())
# Compute the double sided maxwell CDF through the one sided maxwell cdf.
# This is done as follows:
# P(DSM <= x) = P (loc + scale * radamacher_sample * one_sided_sample <=x) =
# P (radamacher_sample * one_sided_sample <= (x - loc) / scale) =
# 1/2 P(one_sided_sample <= (x - loc) / scale)
# + 1/2 P( - one_sided_sample <= (x - loc) / scale) =
# 1/2 P(one_sided_sample <= (x - loc) / scale)
# + 1/2 P(one_sided_sample >= - (x - loc) / scale) =
# 1/2 CDF_one_maxwell((x - loc) / scale))
# + 1/2 (1 - CDF_one_maxwell(- (x - loc) / scale)))
def double_sided_maxwell_cdf(x, loc, scale):
pos = scipy.stats.maxwell().cdf((x - loc) / scale)
neg = (1 - scipy.stats.maxwell().cdf((-x + loc) / scale))
return (pos + neg) / 2
for samples in [uncompiled_samples, compiled_samples]:
# Check first and second moments.
self.assertEqual((num_samples,), samples.shape)
self.assertAllClose(samples.mean(), jtu.to_default_dtype(mean), atol=0., rtol=0.1)
self.assertAllClose(samples.std(), jtu.to_default_dtype(std), atol=0., rtol=0.1)
self._CheckKolmogorovSmirnovCDF(
samples, lambda x: double_sided_maxwell_cdf(x, loc, scale))
def testRadamacher(self):
rng = lambda: self.make_key(0)
num_samples = 10**5
rand = lambda x: random.rademacher(x, (num_samples,))
crand = jax.jit(rand)
uncompiled_samples = rand(rng())
compiled_samples = crand(rng())
for samples in [uncompiled_samples, compiled_samples]:
unique_values, counts = np.unique(samples, return_counts=True)
assert len(unique_values) == 2
assert len(counts) == 2
self.assertAllClose(
counts[0] / num_samples, 0.5, rtol=1e-02, atol=1e-02)
self.assertAllClose(
counts[1] / num_samples, 0.5, rtol=1e-02, atol=1e-02)
def testChoiceShapeIsNotSequenceError(self):
key = self.make_key(0)
with self.assertRaises(TypeError):
random.choice(key, 5, 2, replace=False)
with self.assertRaises(TypeError):
random.choice(key, 5, 2, replace=True)
def test_eval_shape_big_random_array(self):
def f(x):
return random.normal(self.make_key(x), (int(1e12),))
with jax.enable_checks(False): # check_jaxpr will materialize array
jax.eval_shape(f, 0) # doesn't error
@jtu.sample_product(
type_=["int", "np.array", "jnp.array"],
seed=[-1, 0, 1, (1 << 32) - 1, (1 << 63) - 1, np.uint64((1 << 64) - 1)],
)
def test_prng_jit_invariance(self, seed, type_):
if type_ == "int" and seed == (1 << 64) - 1:
self.skipTest("Expected failure: Python int too large.")
if not config.enable_x64.value and seed > np.iinfo(np.int32).max:
self.skipTest("Expected failure: Python int too large.")
type_ = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type_]
args_maker = lambda: [type_(seed)]
f = lambda s: random.key_data(self.make_key(s))
self._CompileAndCheck(f, args_maker)
def test_prng_errors(self):
seed = np.iinfo(np.int64).max + 1
with self.assertRaises(OverflowError):
self.make_key(seed)
with self.assertRaises(OverflowError):
jax.jit(self.make_key)(seed)
def test_random_split_doesnt_device_put_during_tracing(self):
key = self.make_key(1).block_until_ready()
with jtu.count_device_put() as count:
jax.jit(random.split)(key)
self.assertLessEqual(count(), 1) # 1 for the argument device_put
@jtu.sample_product(dtype=int_dtypes + uint_dtypes)
def test_randint_bounds(self, dtype):
min = np.iinfo(dtype).min
max = np.iinfo(dtype).max
key = lambda: self.make_key(1701)
shape = (10,)
if np.iinfo(dtype).bits < np.iinfo(dtypes.canonicalize_dtype(int)).bits:
expected = random.randint(key(), shape, min, max + 1, dtype)
self.assertArraysEqual(expected, random.randint(key(), shape, min - 12345, max + 12345, dtype))
else:
self.assertRaises(OverflowError, random.randint, key(), shape, min - 12345, max + 12345, dtype)
def test_randint_out_of_range(self):
key = self.make_key(0)
r = random.randint(key, (10,), 255, 256, np.uint8)
self.assertAllClose(r, jnp.full_like(r, 255))
key = self.make_key(0)
r = random.randint(key, (1000,), -128, 128, np.int8)
self.assertGreater((r == -128).sum(), 0)
self.assertGreater((r == 127).sum(), 0)
key = self.make_key(0)
r = random.randint(key, (1000,), -1000, 1000, np.uint8)
self.assertGreater((r == 0).sum(), 0)
self.assertGreater((r == 255).sum(), 0)
def test_large_prng(self):
# https://github.com/jax-ml/jax/issues/11010
def f():
return random.uniform(
self.make_key(3), (308000000, 128), dtype=jnp.bfloat16)
# TODO(jakevdp): key reuse checks for this OOM because of slice masking.
# Can we fix this?
with jax.debug_key_reuse(False):
# just lower, don't run, takes too long
jax.jit(f).lower()
@jtu.sample_product(shape=[(3, 4)],
logits_shape_base=[(3, 4), (3, 1), (1, 4)],
axis=[-3, -2, -1, 0, 1, 2])
def test_categorical_shape_argument(self, shape, logits_shape_base, axis):
# https://github.com/jax-ml/jax/issues/13124
logits_shape = list(logits_shape_base)
logits_shape.insert(axis % (len(logits_shape_base) + 1), 10)
assert logits_shape[axis] == 10
logits = jnp.ones(logits_shape)
samples = random.categorical(self.make_key(0), logits=logits,
axis=axis, shape=shape)
self.assertEqual(samples.shape, shape)
@jtu.sample_product(
df = [0.2, 1., 10., 100.],
dtype=jtu.dtypes.floating)
def testChisquare(self, df, dtype):
key = lambda: self.make_key(1)
def rand(key, df):
return random.chisquare(key, df, shape=(10000,), dtype=dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key(), df)
compiled_samples = crand(key(), df)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.chi2(df).cdf)
@jtu.sample_product(
dfnum = [1., 2., 10. ,100.],
dfden = [1. ,2., 10., 100.],
dtype=jtu.dtypes.floating)
def testF(self, dfnum, dfden, dtype):
key = lambda: self.make_key(9)
rand = lambda key: random.f(key, dfnum, dfden, shape = (10000, ), dtype = dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.f(dfnum, dfden).cdf)
@jtu.sample_product(
scale= [0.2, 1., 2., 10. ,100.],
dtype=jtu.dtypes.floating)
def testRayleigh(self, scale, dtype):
key = lambda: self.make_key(0)
rand = lambda key: random.rayleigh(key, scale, shape = (10000, ), dtype = dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.rayleigh(scale=scale).cdf)
@jtu.sample_product(
mean= [0.2, 1., 2., 10. ,100.],
dtype=jtu.dtypes.floating)
def testWald(self, mean, dtype):
key = lambda: self.make_key(0)
rand = lambda key: random.wald(key, mean, shape=(10000, ), dtype=dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.invgauss(mu=mean).cdf)
@jtu.sample_product(
p=[0.2, 0.3, 0.4, 0.5 ,0.6],
dtype=jtu.dtypes.supported([np.int16, np.int32, np.int64]))
def testGeometric(self, p, dtype):
key = lambda: self.make_key(1)
rand = lambda key: random.geometric(key, p, shape=(10000, ), dtype=dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckChiSquared(samples, scipy.stats.geom(p).pmf)
self.assertAllClose(samples.mean(), 1 / p, rtol=0.02, check_dtypes=False)
self.assertAllClose(samples.var(), (1 - p) / (p * p) , rtol=0.05,
check_dtypes=False)
@jtu.sample_product(
left = [0.2, 0.5, 1., 2.],
mode = [3., 5., 8., 9.],
right= [10., 20., 30., 40.],
dtype= jtu.dtypes.floating)
def testTriangular(self, left, mode, right, dtype):
key = lambda: self.make_key(1)
rand = lambda key: random.triangular(key, left, mode, right, shape=(10000,),
dtype=dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.triang(
(mode - left) / (right - left), loc=left, scale=right - left).cdf)
@jtu.sample_product(
sigma = [0.2, 0.5, 1., 2.],
dtype=jtu.dtypes.floating)
def testLogNormal(self, sigma, dtype):
key = lambda: self.make_key(0)
rand = lambda key: random.lognormal(key, sigma, shape=(10000,), dtype=dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.lognorm(s=sigma).cdf)
@jtu.sample_product(
n= [5, 13, 21, 53, 500],
p= [0.1, 0.3, 0.5, 0.7, 0.9],
dtype= jtu.dtypes.floating)
def testBinomialSample(self, n, p, dtype):
key = lambda: self.make_key(12)
rand = lambda key: random.binomial(key, n, p, shape=(12000,), dtype=dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key())
compiled_samples = crand(key())
pmf = lambda x: scipy.stats.binom(n, p).pmf(x)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckChiSquared(samples.astype(int), pmf, pval=1e-3)
self.assertAllClose(samples.mean(), n * p, rtol=0.025, check_dtypes=False)
self.assertAllClose(samples.var(), n * p * (1 - p) , rtol=0.036,
check_dtypes=False)
def testBinomialCornerCases(self):
key = lambda: self.make_key(0)
# corner case n
n = jnp.array([-1, 0, jnp.nan, jnp.inf])
samples1 = random.binomial(key(), n, 0.5, shape=(4,))
# corner case p
p = jnp.array([jnp.nan, 0, -0.1, 1.1])
samples2 = random.binomial(key(), 5, p, shape=(4,))
# corner case n and p
# expect nan or illegal will lead to nan
n_cc = jnp.array([jnp.inf, -1, jnp.inf])
p_cc = jnp.array([jnp.nan, jnp.nan, -0.1])
samples3 = random.binomial(key(), n_cc, p_cc, shape=(3,))
self.assertArraysAllClose(samples1, jnp.array([jnp.nan, 0., jnp.nan, jnp.inf]), check_dtypes=False)
self.assertArraysAllClose(samples2, jnp.array([jnp.nan, 0., jnp.nan, jnp.nan]), check_dtypes=False)
self.assertArraysAllClose(samples3, jnp.array([jnp.nan, jnp.nan, jnp.nan]), check_dtypes=False)
def test_binomial_dtypes(self):
# Regression test for https://github.com/jax-ml/jax/pull/25688#discussion_r1938010569
key = jax.random.key(0)
n = jax.numpy.float16(100)
p = jax.numpy.float16(0.5)
jax.random.binomial(key, n, p) # doesn't error
def testMultinomialExample(self):
key = random.key(0)
probs = jnp.array([
[0.5, 0.2, 0.3],
[0.1, 0.2, 0.7],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
[0.5, 0.0, 0.5],
])
trials = 1e5
counts = random.multinomial(key, trials, probs)
freqs = counts / trials
self.assertAllClose(freqs, probs, atol=1e-2)
@jtu.sample_product(
categories=[1, 2, 3, 5, 7, 11],
trials=[1, 2, 3, 5, 7, 11],
dtype=[jnp.float32],
)
def testMultinomialNumpy(
self,
categories,
trials,
dtype,
test_samples=10**6,
tolerance=1e-1,
):
probs = jnp.linspace(-1, 2, categories)[::-1] ** 2
probs /= probs.sum(-1, keepdims=True)
rng = np.random.default_rng(0)
counts_numpy = jnp.array(rng.multinomial(trials, probs, size=test_samples), dtype)
shape = (test_samples,) + probs.shape
key = random.key(0)
counts_jax = random.multinomial(key, trials, probs, shape=shape, dtype=dtype)
assert counts_jax.shape == shape
energy_distance = get_energy_distance(counts_numpy, counts_jax)
assert energy_distance < tolerance
@jtu.sample_product([
dict(shape=shape, outcomes=outcomes)
for shape in [(5,), (2, 3), (2, 3, 5)]
for outcomes in [2, 3, 4]
])
def testMultinomialShape(self, shape, outcomes):
key = random.key(0)
key, subkey = random.split(key)
probs = random.dirichlet(subkey, jnp.ones(outcomes))
trials = 1e5
counts = random.multinomial(key, trials, probs, shape=(*shape, *probs.shape))
freqs = counts / trials
self.assertAllClose(freqs, jnp.broadcast_to(probs, freqs.shape), atol=1e-2)
@jtu.sample_product([
dict(n_dtype=n_dtype, p_dtype=p_dtype, dtype=dtype)
for n_dtype in jtu.dtypes.all_floating
for p_dtype in jtu.dtypes.all_floating
for dtype in jtu.dtypes.all_floating
])
@jax.numpy_dtype_promotion('standard')
def testMultinomialDtype(self, n_dtype, p_dtype, dtype):
key = random.key(0)
n = jnp.astype(10, n_dtype)
p = jnp.astype(jnp.ones(3) / 3, p_dtype)
random.multinomial(key, n, p)
def test_batched_key_errors(self):
keys = lambda: jax.random.split(self.make_key(0))
msg = "{} accepts a single key, but was given a key array of shape.*"
# Check a handful of functions that are expected to error.
with self.assertRaisesRegex(ValueError, msg.format('bits')):
jax.random.bits(keys(), shape=(2,))
with self.assertRaisesRegex(ValueError, msg.format('chisquare')):
jax.random.chisquare(keys(), 1.0, shape=(2,))
with self.assertRaisesRegex(ValueError, msg.format('dirichlet')):
jax.random.dirichlet(keys(), jnp.arange(2.0), shape=(2,))
with self.assertRaisesRegex(ValueError, msg.format('gamma')):
jax.random.gamma(keys(), 1.0, shape=(2,))
with self.assertRaisesRegex(ValueError, msg.format('loggamma')):
jax.random.loggamma(keys(), 1.0, shape=(2,))
with self.assertRaisesRegex(ValueError, msg.format('fold_in')):
jax.random.fold_in(keys(), 0)
with self.assertRaisesRegex(ValueError, msg.format('split')):
jax.random.split(keys())
# Shouldn't error or warn:
with self.assertNoWarnings():
jax.random.key_data(keys())
jax.random.key_impl(keys())
def get_energy_distance(samples_1, samples_2):
"""
Estimates the energy distance between two distributions, given
batches of independent samples from each.
For more information, see https://en.wikipedia.org/wiki/Energy_distance.
"""
x, xp = jnp.split(samples_1, 2)
y, yp = jnp.split(samples_2, 2)
return (
2 * jnp.linalg.norm(x - y, axis=-1)
- jnp.linalg.norm(x - xp, axis=-1)
- jnp.linalg.norm(y - yp, axis=-1)
).mean(0)
threefry_seed = prng_internal.threefry_seed
threefry_split = prng_internal.threefry_split
threefry_random_bits = prng_internal.threefry_random_bits
threefry_fold_in = prng_internal.threefry_fold_in
def _double_threefry_seed(seed):
int_t = seed.dtype.type if hasattr(seed, 'dtype') else type(seed)
s1, s2 = seed, seed ^ int_t(3)
return jnp.vstack([threefry_seed(s1),
threefry_seed(s2)])
def _double_threefry_split(key, shape):
return vmap(
threefry_split, (0, None), len(shape))(key, shape)
def _double_threefry_random_bits(key, bit_width, shape):
bits0 = threefry_random_bits(key[0], bit_width, shape)
bits1 = threefry_random_bits(key[1], bit_width, shape)
del bits1
# TODO(frostig): Currently this behaves like normal threefry, to
# avoid a few probabilistic test failures. Ideally we might want to
# test different generation behavior here (e.g. `bits0 ^ bits1`).
return bits0
def _double_threefry_fold_in(key, data):
return jnp.vstack([threefry_fold_in(key[0], data),
threefry_fold_in(key[1], data)])
double_threefry_prng_impl = prng_internal.PRNGImpl(
key_shape=(2, 2),
seed=_double_threefry_seed,
split=_double_threefry_split,
random_bits=_double_threefry_random_bits,
fold_in=_double_threefry_fold_in,
tag='fry2')
@jtu.with_config(jax_default_prng_impl='threefry2x32')
class LaxRandomWithCustomPRNGTest(LaxRandomTest):
def make_key(self, seed):
return prng_internal.random_seed(seed, impl=double_threefry_prng_impl)
def test_split_shape(self):
key = self.make_key(73)
keys = random.split(key, 10)
self.assertEqual(keys.shape, (10,))
def test_vmap_fold_in_shape(self):
# broadcast with scalar
keys = lambda: random.split(self.make_key(73), 2)
msgs = jnp.arange(3)
out = vmap(lambda i: random.fold_in(keys()[0], i))(msgs)
self.assertEqual(out.shape, (3,))
out = vmap(lambda k: random.fold_in(k, msgs[0]))(keys())
self.assertEqual(out.shape, (2,))
out = vmap(random.fold_in, in_axes=(None, 0))(keys()[0], msgs)
self.assertEqual(out.shape, (3,))
out = vmap(random.fold_in, in_axes=(0, None))(keys(), msgs[0])
self.assertEqual(out.shape, (2,))
# vmap all
msgs = jnp.arange(2)
out = vmap(random.fold_in)(keys(), msgs)
self.assertEqual(out.shape, (2,))
# nested vmap
keys = lambda: random.split(self.make_key(73), 2 * 3).reshape((2, 3))
msgs = jnp.arange(2 * 3).reshape((2, 3))
out = vmap(vmap(random.fold_in), in_axes=(0, 1))(keys(), msgs.T)
self.assertEqual(out.shape, (2, 3))
out = vmap(vmap(random.fold_in), in_axes=(1, 0))(keys(), msgs.T)
self.assertEqual(out.shape, (3, 2))
@jax.debug_key_reuse(False)
def test_vmap_split_mapped_key(self):
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
forloop_keys = [random.split(k) for k in mapped_keys]
vmapped_keys = vmap(random.split)(mapped_keys)
self.assertEqual(vmapped_keys.shape, (3, 2))
for fk, vk in zip(forloop_keys, vmapped_keys):
self.assertArraysEqual(random.key_data(fk),
random.key_data(vk))
def test_cannot_add(self):
key = self.make_key(73)
self.assertRaisesRegex(
TypeError, r'add does not accept dtypes key<.*>, int.*',
lambda: key + 47)
def test_grad_of_prng_key(self):
key = self.make_key(73)
with self.assertRaisesRegex(TypeError, 'grad requires real- or complex-valued inputs'):
jax.grad(lambda x: 1.)(key)
out = jax.grad(lambda x: 1., allow_int=True)(key)
self.assertArraysEqual(out, np.zeros(key.shape, jax.dtypes.float0))
@jtu.with_config(jax_default_prng_impl='rbg')
class LaxRandomWithRBGPRNGTest(LaxRandomTest):
def make_key(self, seed):
return random.PRNGKey(seed, impl='rbg')
def test_split_shape(self):
key = self.make_key(73)
keys = random.split(key, 10)
self.assertEqual(keys.shape, (10, *key.shape))
@jax.debug_key_reuse(False)
def test_vmap_fold_in_shape(self):
# broadcast with scalar
keys = random.split(self.make_key(73), 2)
msgs = jnp.arange(3)
out = vmap(lambda i: random.fold_in(keys[0], i))(msgs)
self.assertEqual(out.shape, (3, *keys[0].shape))
out = vmap(random.fold_in, in_axes=(None, 0))(keys[0], msgs)
self.assertEqual(out.shape, (3, *keys[0].shape))
out = vmap(lambda k: random.fold_in(k, msgs[0]))(keys)
self.assertEqual(out.shape, keys.shape)
out = vmap(random.fold_in, in_axes=(0, None))(keys, msgs[0])
self.assertEqual(out.shape, keys.shape)
@jax.debug_key_reuse(False)
def test_vmap_split_not_mapped_key(self):
key = self.make_key(73)
single_split_key = random.split(key)
vmapped_keys = vmap(lambda _: random.split(key))(jnp.zeros(3,))
self.assertEqual(vmapped_keys.shape, (3, 2, *key.shape))
for vk in vmapped_keys:
self.assertArraysEqual(random.key_data(vk),
random.key_data(single_split_key))
@jax.debug_key_reuse(False)
def test_vmap_split_mapped_key_shape(self):
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
vmapped_keys = vmap(random.split)(mapped_keys)
self.assertEqual(vmapped_keys.shape, (3, 2, *key.shape))
@jax.debug_key_reuse(False)
def test_vmap_split_mapped_key_values(self):
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
vmapped_keys = vmap(random.split)(mapped_keys)
ref_keys = [random.split(k) for k in mapped_keys]
for rk, vk in zip(ref_keys, vmapped_keys):
self.assertArraysEqual(random.key_data(rk),
random.key_data(vk))
@jax.debug_key_reuse(False)
def test_vmap_random_bits_shape(self):
rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100)
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
rand_nums = vmap(rand_fun)(mapped_keys)
self.assertEqual(rand_nums.shape, (3,))
@jtu.skip_on_devices("tpu")
@jax.debug_key_reuse(False)
def test_vmap_random_bits_value(self):
rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100)
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
rand_nums = vmap(rand_fun)(mapped_keys)
ref_nums = rand_fun(mapped_keys[0], shape=(3,))
self.assertArraysEqual(rand_nums, ref_nums)
def test_vmap_random_bits_distribution(self):
dtype = jnp.float32
keys = lambda: jax.random.split(self.make_key(0), 10)
def rand(key):
nums = jax.vmap(lambda key: random.uniform(key, (1000,), dtype))(key)
return nums.flatten()
crand = jax.jit(rand)
uncompiled_samples = rand(keys())
compiled_samples = crand(keys())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf,
pval=0.005)
def test_cannot_add(self):
key = self.make_key(73)
if not jnp.issubdtype(key.dtype, dtypes.prng_key):
raise SkipTest('relies on typed key arrays')
self.assertRaisesRegex(
TypeError, r'add does not accept dtypes key<.*>, int.*',
lambda: key + 47)
def test_grad_of_prng_key(self):
key = self.make_key(73)
with self.assertRaisesRegex(TypeError, 'grad requires real- or complex-valued inputs'):
jax.grad(lambda x: 1.)(key)
out = jax.grad(lambda x: 1., allow_int=True)(key)
self.assertArraysEqual(out, np.zeros(key.shape, jax.dtypes.float0))
def test_random_split_doesnt_device_put_during_tracing(self):
return # this test doesn't apply to the RBG PRNG
def test_randint_out_of_range(self):
# TODO(mattjj): enable this test if/when RngBitGenerator supports it
raise SkipTest('8-bit types not supported with RBG PRNG')
@jtu.with_config(jax_default_prng_impl='unsafe_rbg')
class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
def make_key(self, seed):
return random.PRNGKey(seed, impl="unsafe_rbg")
@jtu.skip_on_devices("tpu")
@jax.debug_key_reuse(False)
def test_vmap_split_mapped_key_values(self):
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
vmapped_keys = vmap(random.split)(mapped_keys)
ref_keys = random.split(mapped_keys[0], (3, 2))
self.assertArraysEqual(random.key_data(vmapped_keys),
random.key_data(ref_keys))
def _sampler_unimplemented_with_custom_prng(*args, **kwargs):
raise SkipTest('sampler only implemented for default RNG')
for test_prefix in [
'testPoisson',
'testPoissonBatched',
'testPoissonShape',
'testPoissonZeros',
]:
for attr in dir(LaxRandomTest):
if attr.startswith(test_prefix):
setattr(LaxRandomWithCustomPRNGTest, attr,
_sampler_unimplemented_with_custom_prng)
setattr(LaxRandomWithRBGPRNGTest, attr,
_sampler_unimplemented_with_custom_prng)
setattr(LaxRandomWithUnsafeRBGPRNGTest, attr,
_sampler_unimplemented_with_custom_prng)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())