rocm_jax/tests/random_test.py
2020-09-23 20:15:32 -07:00

869 lines
32 KiB
Python

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from unittest import SkipTest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import scipy.linalg
import scipy.special
import scipy.stats
from jax import api
from jax import core
from jax import grad
from jax import lax
from jax import numpy as jnp
from jax import random
from jax import test_util as jtu
from jax import vmap
from jax.interpreters import xla
from jax.config import config
# Hook the JAX config into absl flag handling and keep a handle on the flag
# values; tests in this file read FLAGS.jax_enable_x64 to pick expectations.
config.parse_flags_with_absl()
FLAGS = config.FLAGS
# Dtype lists taken from the JAX test utilities; used to parameterize the
# per-dtype tests in this file.
float_dtypes = jtu.dtypes.all_floating
int_dtypes = jtu.dtypes.all_integer
uint_dtypes = jtu.dtypes.all_unsigned
class LaxRandomTest(jtu.JaxTestCase):
def _CheckCollisions(self, samples, nbits):
  """Assert that `samples` has about as many distinct values as expected.

  The expectation is the number of occupied bins when `len(samples)` items
  are thrown uniformly into `2 ** nbits` bins.

  Args:
    samples: sequence of drawn values.
    nbits: number of bits of entropy each sample is assumed to carry.
  """
  failure_probability = 0.01  # conservative bound on statistical fail prob by Chebyshev
  num_samples = len(samples)
  num_bins = 2 ** nbits
  # Expected count of distinct bins hit after num_samples uniform draws.
  expected_distinct = num_bins * (1 - ((num_bins - 1) / num_bins) ** num_samples)
  observed_distinct = len(np.unique(samples))
  relative_error = (observed_distinct - expected_distinct) / expected_distinct
  tolerance = 1 / np.sqrt(expected_distinct * failure_probability)
  self.assertLess(relative_error ** 2, tolerance)
def _CheckKolmogorovSmirnovCDF(self, samples, cdf):
  """Assert that a Kolmogorov-Smirnov test does not reject `cdf` for `samples`.

  Args:
    samples: one-dimensional collection of drawn values.
    cdf: callable (or anything `scipy.stats.kstest` accepts) describing the
      reference distribution.
  """
  significance = 0.01  # conservative bound on statistical fail prob by Kolmo CDF
  ks_result = scipy.stats.kstest(samples, cdf)
  self.assertGreater(ks_result.pvalue, significance)
def _CheckChiSquared(self, samples, pmf):
  """Assert that a chi-squared test does not reject `pmf` for `samples`.

  Args:
    samples: array of discrete draws (its `.size` is used as the total count).
    pmf: callable mapping an array of observed values to their probabilities.
  """
  significance = 0.01  # significance level, threshold for p-value
  support, observed = np.unique(samples, return_counts=True)
  expected = pmf(support) * samples.size
  # per scipy: "A typical rule is that all of the observed and expected
  # frequencies should be at least 5."
  usable = (observed > 5) & (expected > 5)
  self.assertGreater(usable.sum(), 1,
                     msg='not enough valid frequencies for chi-squared test')
  observed_kept = observed[usable]
  expected_kept = expected[usable]
  p_value = scipy.stats.chisquare(observed_kept, expected_kept).pvalue
  self.assertGreater(
      p_value, significance,
      msg=f'Failed chi-squared test with p={p_value}.\n'
          'Expected vs. actual frequencies:\n'
          f'{expected_kept}\n{observed_kept}')
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
    for dtype in [np.float32, np.float64]))
def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype):
  """The bit pattern of 1.0 must be the same via NumPy `view` and XLA bitcast."""
  wants_float64 = jnp.issubdtype(dtype, np.float64)
  if wants_float64 and not FLAGS.jax_enable_x64:
    raise SkipTest("can't test float64 agreement")

  if jnp.finfo(dtype).bits == 32:
    bits_dtype = np.uint32
  else:
    bits_dtype = np.uint64

  host_bits = np.array(1., dtype).view(bits_dtype)

  def bitcast_one():
    return lax.bitcast_convert_type(np.array(1., dtype), bits_dtype)

  device_bits = api.jit(bitcast_one)()
  self.assertEqual(host_bits, device_bits)
def testThreefry2x32(self):
  """Check `random.threefry_2x32` against the reference known-answer vectors."""
  # We test the hash by comparing to known values provided in the test code of
  # the original reference implementation of Threefry. For the values, see
  # https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32
  def result_to_hex(result):
    # int() yields a plain Python integer, so hex() needs no Python 2 style
    # "L"-suffix stripping.
    return tuple(hex(int(x)) for x in result)

  # All 32 bits set. Spelled out explicitly instead of np.uint32(-1), which
  # depends on out-of-range integer wraparound that newer NumPy deprecates.
  all_ones = 0xffffffff

  # (key, counter, expected hex digest)
  known_answers = [
      (np.uint32([0, 0]), np.uint32([0, 0]),
       ("0x6b200159", "0x99ba4efe")),
      (np.uint32([all_ones, all_ones]), np.uint32([all_ones, all_ones]),
       ("0x1cb996fc", "0xbb002be7")),
      (np.uint32([0x13198a2e, 0x03707344]),
       np.uint32([0x243f6a88, 0x85a308d3]),
       ("0xc4923a9c", "0x483df7a0")),
  ]
  for key, counter, expected in known_answers:
    result = random.threefry_2x32(key, counter)
    self.assertEqual(expected, result_to_hex(result))
def testThreefry2x32Large(self):
  """Hash a very long counter array and check both halves of the output."""
  n = 10000000
  key_pair = (np.uint32(0x13198a2e), np.uint32(0x03707344))
  low_words = jnp.full((n,), 0x243f6a88, jnp.uint32)
  high_words = jnp.full((n,), 0x85a308d3, jnp.uint32)
  result = random.threefry_2x32(key_pair,
                                jnp.concatenate([low_words, high_words]))
  # Every element of each half must equal the single-pair reference digest.
  first_half_expected = np.full((n,), 0xc4923a9c, dtype=np.uint32)
  np.testing.assert_equal(result[:n], first_half_expected)
  second_half_expected = np.full((n,), 0x483df7a0, dtype=np.uint32)
  np.testing.assert_equal(result[n:], second_half_expected)
def testRngRandomBitsViewProperty(self):
  """8-, 16- and 32-bit draws from one key must share the same raw bits."""
  # TODO: add 64-bit if it ever supports this property.
  # TODO: will this property hold across endian-ness?
  N = 10
  key = random.PRNGKey(1701)
  nbits = [8, 16, 32]
  # Each request covers N * 64 bits in total, so the uint32 views line up.
  rand_bits = [random._random_bits(key, n, (N * 64 // n,)) for n in nbits]
  rand_bits_32 = np.array([np.array(r).view(np.uint32) for r in rand_bits])
  # A TestCase assertion (rather than a bare `assert`) is never stripped
  # under `python -O` and matches the rest of this file.
  self.assertTrue(np.all(rand_bits_32 == rand_bits_32[0]))
def testRngRandomBits(self):
  """Pin the raw bits produced for a fixed key at every supported width."""
  # Test specific outputs to ensure consistent random values between JAX versions.
  key = random.PRNGKey(1701)

  pinned = (
      (8, np.array([216, 115, 43], dtype=np.uint8)),
      (16, np.array([41682, 1300, 55017], dtype=np.uint16)),
      (32, np.array([56197195, 4200222568, 961309823], dtype=np.uint32)),
  )
  for width, expected in pinned:
    self.assertArraysEqual(random._random_bits(key, width, (3,)), expected)

  bits64 = random._random_bits(key, 64, (3,))
  # The expected values (and dtype) of a 64-bit request depend on the x64 flag.
  if FLAGS.jax_enable_x64:
    expected64 = np.array([3982329540505020460, 16822122385914693683,
                           7882654074788531506], dtype=np.uint64)
  else:
    expected64 = np.array([676898860, 3164047411, 4010691890], dtype=np.uint32)
  self.assertArraysEqual(bits64, expected64)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
    for dtype in float_dtypes))
def testRngUniform(self, dtype):
  """Uniform draws must pass the collision and KS checks, jitted or not."""
  key = random.PRNGKey(0)

  def draw(k):
    return random.uniform(k, (10000,), dtype)

  jitted_draw = api.jit(draw)
  eager_samples = draw(key)
  jitted_samples = jitted_draw(key)

  mantissa_bits = jnp.finfo(dtype).nmant
  for samples in (eager_samples, jitted_samples):
    self._CheckCollisions(samples, mantissa_bits)
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
    for dtype in int_dtypes + uint_dtypes))
def testRngRandint(self, dtype):
  """`randint` samples must fall in the half-open interval [lo, hi)."""
  lo, hi = 5, 10
  key = random.PRNGKey(0)

  def draw(k):
    return random.randint(k, (10000,), lo, hi, dtype)

  jitted_draw = api.jit(draw)
  eager_samples = draw(key)
  jitted_samples = jitted_draw(key)

  for samples in (eager_samples, jitted_samples):
    self.assertTrue(np.all(lo <= samples))
    self.assertTrue(np.all(samples < hi))
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
    for dtype in [np.float16, np.float32, np.float64]))
def testNormal(self, dtype):
  """Standard normal draws must pass a KS test, jitted or not."""
  key = random.PRNGKey(0)

  def draw(k):
    return random.normal(k, (10000,), dtype)

  jitted_draw = api.jit(draw)
  eager_samples = draw(key)
  jitted_samples = jitted_draw(key)

  for samples in (eager_samples, jitted_samples):
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.norm().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
    for dtype in [np.float32, np.float64, np.int32, np.int64]))
def testShuffle(self, dtype):
  """`random.shuffle` warns, permutes its input, and agrees with its jit."""
  key = random.PRNGKey(0)
  x = np.arange(100).astype(dtype)

  def shuffled(k):
    return random.shuffle(k, x)

  jitted_shuffled = api.jit(shuffled)

  # Both the eager and the jitted call are expected to emit a FutureWarning.
  with self.assertWarns(FutureWarning):
    eager_perm = shuffled(key)
  with self.assertWarns(FutureWarning):
    jitted_perm = jitted_shuffled(key)

  self.assertAllClose(eager_perm, jitted_perm)
  self.assertFalse(np.all(eager_perm == x))  # seems unlikely!
  self.assertAllClose(np.sort(eager_perm), x, check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_shape={}_replace={}_weighted={}_array_input={}".format(
        np.dtype(dtype).name, shape, replace, weighted, array_input),
     "dtype": dtype, "shape": shape, "replace": replace,
     "weighted": weighted, "array_input": array_input}
    for dtype in [np.float32, np.float64, np.int32, np.int64]
    for shape in [(), (5,), (4, 5)]
    for replace in [True, False]
    for weighted in [True, False]
    for array_input in [False, 'jnp', 'np']))
def testChoice(self, dtype, shape, replace, weighted, array_input):
  """Check shape, dtype, uniqueness and jit agreement of `random.choice`.

  `array_input` selects whether the population is the integer N, a jnp
  array or a NumPy array; `weighted` selects whether `p` is supplied.
  """
  N = 100
  key = random.PRNGKey(0)
  if not array_input:
    x = N
  elif array_input == 'jnp':
    x = jnp.arange(N, dtype=dtype)
  else:
    x = np.arange(N, dtype=dtype)
  p = jnp.arange(N) if weighted else None
  rand = lambda key: random.choice(key, x, shape, p=p, replace=replace)
  crand = api.jit(rand)

  sample1 = rand(key)
  sample2 = crand(key)

  self.assertEqual(shape, sample1.shape)
  if array_input == 'jnp':
    self.assertEqual(x.dtype, sample1.dtype)
  if not replace:
    # Sampling without replacement must not repeat any element. Uses a
    # TestCase assertion so the check survives `python -O`.
    self.assertEqual(len(np.unique(sample1)), len(np.ravel(sample1)))
  self.assertAllClose(sample1, sample2)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
     "dtype": dtype, "shape": shape}
    for dtype in [np.float32, np.float64, np.int32, np.int64]
    for shape in [100, (10, 10), (10, 5, 2)]))
def testPermutationArray(self, dtype, shape):
  """`random.permutation` of an array reorders it and leaves the input intact."""
  key = random.PRNGKey(0)

  def make_input():
    return jnp.arange(np.prod(shape)).reshape(shape).astype(dtype)

  x = make_input()

  def permuted(k):
    return random.permutation(k, x)

  jitted_permuted = api.jit(permuted)
  eager_perm = permuted(key)
  jitted_perm = jitted_permuted(key)

  self.assertAllClose(eager_perm, jitted_perm)
  self.assertFalse(np.all(eager_perm == x))  # seems unlikely!
  self.assertAllClose(np.sort(eager_perm.ravel()), x.ravel(), check_dtypes=False)
  # The input must still hold its original contents afterwards.
  self.assertArraysAllClose(x, make_input())
def testPermutationInteger(self):
  """`random.permutation(key, 100)` yields a shuffled `arange(100)`."""
  key = random.PRNGKey(0)
  x = 100

  def permuted(k):
    return random.permutation(k, x)

  jitted_permuted = api.jit(permuted)
  eager_perm = permuted(key)
  jitted_perm = jitted_permuted(key)

  identity = np.arange(100)
  self.assertAllClose(eager_perm, jitted_perm)
  self.assertEqual(eager_perm.dtype, jitted_perm.dtype)
  self.assertFalse(np.all(eager_perm == identity))  # seems unlikely!
  self.assertAllClose(np.sort(eager_perm), identity, check_dtypes=False)
def testPermutationErrors(self):
  """`random.permutation` rejects a float size and a jit-traced integer size."""
  key = random.PRNGKey(0)
  # A float is not a valid population.
  self.assertRaises(TypeError, random.permutation, key, 10.)
  # Under jit the integer argument is traced rather than concrete.
  jitted_permutation = api.jit(random.permutation)
  self.assertRaises(core.ConcretizationTypeError, jitted_permutation, key, 10)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_p={}_dtype={}".format(p, np.dtype(dtype).name),
     "p": p, "dtype": dtype}
    for p in [0.1, 0.5, 0.9]
    for dtype in [np.float32, np.float64]))
def testBernoulli(self, p, dtype):
  """Bernoulli draws must pass a chi-squared test, jitted or not."""
  key = random.PRNGKey(0)
  p = np.array(p, dtype=dtype)

  def draw(k, prob):
    return random.bernoulli(k, prob, (10000,))

  jitted_draw = api.jit(draw)
  eager_samples = draw(key, p)
  jitted_samples = jitted_draw(key, p)

  for samples in (eager_samples, jitted_samples):
    self._CheckChiSquared(samples, scipy.stats.bernoulli(p).pmf)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_p={}_{}_{}".format(p, np.dtype(dtype).name, sample_shape),
     "p": p, "axis": axis, "dtype": dtype, 'sample_shape': sample_shape}
    for (p, axis) in [
        ([.25] * 4, -1),
        ([.1, .2, .3, .4], -1),
        ([[.5, .5], [.1, .9]], 1),
        ([[.5, .1], [.5, .9]], 0),
    ]
    for sample_shape in [(10000,), (5000, 2)]
    for dtype in [np.float32, np.float64]))
def testCategorical(self, p, axis, dtype, sample_shape):
  """Categorical draws must match `p` along `axis` under a chi-squared test.

  Args:
    p: probabilities, either one distribution or a 2-D batch of them.
    axis: axis of `p` that indexes the categories.
    dtype: floating dtype used for the logits.
    sample_shape: leading shape of the requested samples.
  """
  key = random.PRNGKey(0)
  p = np.array(p, dtype=dtype)
  logits = np.log(p) - 42  # test unnormalized
  out_shape = tuple(np.delete(logits.shape, axis))
  shape = sample_shape + out_shape
  rand = partial(random.categorical, shape=shape, axis=axis)
  crand = api.jit(rand)

  uncompiled_samples = rand(key, logits)
  compiled_samples = crand(key, logits)

  # Normalize a negative axis only after `rand` has captured the original.
  if axis < 0:
    axis += len(logits.shape)

  for samples in [uncompiled_samples, compiled_samples]:
    # TestCase assertion instead of a bare `assert`, so the shape check is
    # never stripped under `python -O`.
    self.assertEqual(samples.shape, shape)
    samples = jnp.reshape(samples, (10000,) + out_shape)
    if len(p.shape[:-1]) > 0:
      # Batched case: test each distribution against its own column of draws.
      ps = np.transpose(p, (1, 0)) if axis == 0 else p
      for cat_samples, cat_p in zip(samples.transpose(), ps):
        self._CheckChiSquared(cat_samples, pmf=lambda x: cat_p[x])
    else:
      self._CheckChiSquared(samples, pmf=lambda x: p[x])
def testBernoulliShape(self):
  """A length-2 `p` combined with `shape=(3, 2)` gives a (3, 2) result."""
  key = random.PRNGKey(0)
  x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2))
  # TestCase assertion: reports both shapes on failure and survives -O.
  self.assertEqual(x.shape, (3, 2))
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_a={}_b={}_dtype={}".format(a, b, np.dtype(dtype).name),
     "a": a, "b": b, "dtype": dtype}
    for a in [0.2, 5.]
    for b in [0.2, 5.]
    for dtype in [np.float64]))  # NOTE: KS test fails with float32
def testBeta(self, a, b, dtype):
  """Beta(a, b) draws must pass a KS test; only run when x64 is enabled."""
  if not FLAGS.jax_enable_x64:
    raise SkipTest("skip test except on X64")
  key = random.PRNGKey(0)

  def draw(k, alpha, beta):
    return random.beta(k, alpha, beta, (10000,), dtype)

  jitted_draw = api.jit(draw)
  eager_samples = draw(key, a, b)
  jitted_samples = jitted_draw(key, a, b)

  for samples in (eager_samples, jitted_samples):
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.beta(a, b).cdf)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
    for dtype in [np.float16, np.float32, np.float64]))
def testCauchy(self, dtype):
  """Standard Cauchy draws must pass a KS test, jitted or not."""
  key = random.PRNGKey(0)

  def draw(k):
    return random.cauchy(k, (10000,), dtype)

  jitted_draw = api.jit(draw)
  eager_samples = draw(key)
  jitted_samples = jitted_draw(key)

  for samples in (eager_samples, jitted_samples):
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.cauchy().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_alpha={}_dtype={}".format(alpha, np.dtype(dtype).name),
     "alpha": alpha, "dtype": dtype}
    for alpha in [
        np.array([0.2, 1., 5.]),
    ]
    for dtype in [np.float32, np.float64]))
@jtu.skip_on_devices("tpu")  # TODO(mattjj): slow compilation times
def testDirichlet(self, alpha, dtype):
  """Dirichlet draws sum to one and each coordinate is KS-tested against a Beta."""
  key = random.PRNGKey(0)

  def draw(k, concentration):
    return random.dirichlet(k, concentration, (10000,), dtype)

  jitted_draw = api.jit(draw)
  eager_samples = draw(key, alpha)
  jitted_samples = jitted_draw(key, alpha)

  alpha_total = sum(alpha)
  for samples in (eager_samples, jitted_samples):
    self.assertAllClose(samples.sum(-1), np.ones(10000, dtype=dtype))
    # Coordinate i is compared against Beta(alpha_i, alpha_total - alpha_i).
    for index, a in enumerate(alpha):
      marginal_cdf = scipy.stats.beta(a, alpha_total - a).cdf
      self._CheckKolmogorovSmirnovCDF(samples[..., index], marginal_cdf)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
    for dtype in float_dtypes))
def testExponential(self, dtype):
  """Unit exponential draws must pass a KS test, jitted or not."""
  key = random.PRNGKey(0)

  def draw(k):
    return random.exponential(k, (10000,), dtype)

  jitted_draw = api.jit(draw)
  eager_samples = draw(key)
  jitted_samples = jitted_draw(key)

  for samples in (eager_samples, jitted_samples):
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.expon().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_a={}_dtype={}".format(a, np.dtype(dtype).name),
     "a": a, "dtype": dtype}
    for a in [0.1, 1., 10.]
    for dtype in [np.float32, np.float64]))
def testGamma(self, a, dtype):
  """Gamma(a) draws must pass a KS test, jitted or not."""
  key = random.PRNGKey(0)

  def draw(k, shape_param):
    return random.gamma(k, shape_param, (10000,), dtype)

  jitted_draw = api.jit(draw)
  eager_samples = draw(key, a)
  jitted_samples = jitted_draw(key, a)

  for samples in (eager_samples, jitted_samples):
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gamma(a).cdf)
def testGammaShape(self):
  """A length-2 `a` combined with `shape=(3, 2)` gives a (3, 2) result."""
  key = random.PRNGKey(0)
  x = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2))
  # TestCase assertion: reports both shapes on failure and survives -O.
  self.assertEqual(x.shape, (3, 2))
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_a={}".format(alpha), "alpha": alpha}
    for alpha in [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]))
def testGammaGrad(self, alpha):
  """Compare d(sample)/d(alpha) with a finite-difference estimate from the CDF."""
  rng = random.PRNGKey(0)
  alphas = np.full((100,), alpha)
  z = random.gamma(rng, alphas)

  def summed_samples(x):
    return random.gamma(rng, x).sum()

  actual_grad = api.grad(summed_samples)(alphas)

  # Central difference of the CDF with respect to alpha, evaluated at z.
  eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
  cdf_above = scipy.stats.gamma.cdf(z, alpha + eps)
  cdf_below = scipy.stats.gamma.cdf(z, alpha - eps)
  cdf_dot = (cdf_above - cdf_below) / (2 * eps)
  pdf = scipy.stats.gamma.pdf(z, alpha)
  expected_grad = -cdf_dot / pdf

  rtol = 2e-2 if jtu.device_under_test() == "tpu" else 7e-4
  self.assertAllClose(actual_grad, expected_grad, check_dtypes=True,
                      rtol=rtol)
def testGammaGradType(self):
  """Taking a VJP through a float32 gamma sample must not raise."""
  # Regression test for https://github.com/google/jax/issues/2130
  key = random.PRNGKey(0)
  shape_param = jnp.array(1., dtype=jnp.float32)
  divisor = jnp.array(3., dtype=jnp.float32)

  def scaled_gamma(x, y):
    return random.gamma(key=key, a=x, dtype=jnp.float32) / y

  # Should not crash with a type error.
  api.vjp(scaled_gamma, shape_param, divisor)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_lam={}_dtype={}".format(lam, np.dtype(dtype).name),
     "lam": lam, "dtype": np.dtype(dtype)}
    for lam in [0.5, 3, 9, 11, 50, 500]
    for dtype in [np.int16, np.int32, np.int64]))
def testPoisson(self, lam, dtype):
  """Poisson(lam) draws must pass a chi-squared test and match mean/variance."""
  key = random.PRNGKey(0)

  def draw(k, rate):
    return random.poisson(k, rate, (10000,), dtype)

  jitted_draw = api.jit(draw)
  eager_samples = draw(key, lam)
  jitted_samples = jitted_draw(key, lam)

  for samples in (eager_samples, jitted_samples):
    self._CheckChiSquared(samples, scipy.stats.poisson(lam).pmf)
    # TODO(shoyer): determine error bounds for moments more rigorously (e.g.,
    # based on the central limit theorem).
    self.assertAllClose(samples.mean(), lam, rtol=0.01, check_dtypes=False)
    self.assertAllClose(samples.var(), lam, rtol=0.03, check_dtypes=False)
def testPoissonBatched(self):
  """A per-element rate array gives each half of the output its own rate."""
  key = random.PRNGKey(0)
  low_rates = 2 * jnp.ones(10000)
  high_rates = 20 * jnp.ones(10000)
  lam = jnp.concatenate([low_rates, high_rates])
  samples = random.poisson(key, lam, shape=(20000,))
  low_half, high_half = samples[:10000], samples[10000:]
  self._CheckChiSquared(low_half, scipy.stats.poisson(2.0).pmf)
  self._CheckChiSquared(high_half, scipy.stats.poisson(20.0).pmf)
def testPoissonShape(self):
  """A length-2 `lam` combined with `shape=(3, 2)` gives a (3, 2) result."""
  key = random.PRNGKey(0)
  x = random.poisson(key, np.array([2.0, 20.0]), shape=(3, 2))
  # TestCase assertion: reports both shapes on failure and survives -O.
  self.assertEqual(x.shape, (3, 2))
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
    for dtype in [np.float32, np.float64]))
def testGumbel(self, dtype):
  """Standard Gumbel draws must pass a KS test, jitted or not."""
  key = random.PRNGKey(0)

  def draw(k):
    return random.gumbel(k, (10000,), dtype)

  jitted_draw = api.jit(draw)
  eager_samples = draw(key)
  jitted_samples = jitted_draw(key)

  for samples in (eager_samples, jitted_samples):
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gumbel_r().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
    for dtype in float_dtypes))
def testLaplace(self, dtype):
  """Standard Laplace draws must pass a KS test, jitted or not."""
  key = random.PRNGKey(0)

  def draw(k):
    return random.laplace(k, (10000,), dtype)

  jitted_draw = api.jit(draw)
  eager_samples = draw(key)
  jitted_samples = jitted_draw(key)

  for samples in (eager_samples, jitted_samples):
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.laplace().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
    for dtype in float_dtypes))
def testLogistic(self, dtype):
  """Standard logistic draws must pass a KS test, jitted or not."""
  key = random.PRNGKey(0)

  def draw(k):
    return random.logistic(k, (10000,), dtype)

  jitted_draw = api.jit(draw)
  eager_samples = draw(key)
  jitted_samples = jitted_draw(key)

  for samples in (eager_samples, jitted_samples):
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_b={}_dtype={}".format(b, np.dtype(dtype).name),
     "b": b, "dtype": dtype}
    for b in [0.1, 1., 10.]
    for dtype in [np.float32, np.float64]))
def testPareto(self, b, dtype):
  """Pareto(b) draws must pass a KS test, jitted or not."""
  key = random.PRNGKey(0)

  def draw(k, shape_param):
    return random.pareto(k, shape_param, (10000,), dtype)

  jitted_draw = api.jit(draw)
  eager_samples = draw(key, b)
  jitted_samples = jitted_draw(key, b)

  for samples in (eager_samples, jitted_samples):
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.pareto(b).cdf)
def testParetoShape(self):
  """A length-2 `b` combined with `shape=(3, 2)` gives a (3, 2) result."""
  key = random.PRNGKey(0)
  x = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2))
  # TestCase assertion: reports both shapes on failure and survives -O.
  self.assertEqual(x.shape, (3, 2))
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_df={}_dtype={}".format(df, np.dtype(dtype).name),
     "df": df, "dtype": dtype}
    for df in [0.1, 1., 10.]
    for dtype in [np.float32, np.float64]))
@jtu.skip_on_devices("cpu", "tpu")  # TODO(phawkins): slow compilation times
def testT(self, df, dtype):
  """Student-t(df) draws must pass a KS test, jitted or not."""
  key = random.PRNGKey(0)

  def draw(k, dof):
    return random.t(k, dof, (10000,), dtype)

  jitted_draw = api.jit(draw)
  eager_samples = draw(key, df)
  jitted_samples = jitted_draw(key, df)

  for samples in (eager_samples, jitted_samples):
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.t(df).cdf)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dim={}_dtype={}".format(dim, np.dtype(dtype)),
     "dim": dim, "dtype": dtype}
    for dim in [1, 3, 5]
    for dtype in float_dtypes))
def testMultivariateNormal(self, dim, dtype):
  """Whitened multivariate-normal draws must look standard normal."""
  host_rng = np.random.RandomState(dim)
  mean = host_rng.randn(dim)
  cov_factor = host_rng.randn(dim, dim)
  cov = np.dot(cov_factor, cov_factor.T) + dim * np.eye(dim)

  key = random.PRNGKey(0)
  draw = partial(random.multivariate_normal, mean=mean, cov=cov,
                 shape=(10000,))
  jitted_draw = api.jit(draw)
  eager_samples = np.asarray(draw(key), np.float64)
  jitted_samples = np.asarray(jitted_draw(key), np.float64)

  # Inverse of the lower Cholesky factor of cov, used to whiten the samples.
  cholesky_factor = np.linalg.cholesky(cov)
  inv_scale = scipy.linalg.lapack.dtrtri(cholesky_factor, lower=True)[0]
  for samples in (eager_samples, jitted_samples):
    centered = samples - mean
    whitened = np.einsum('nj,ij->ni', centered, inv_scale)

    # This is a quick-and-dirty multivariate normality check that tests that a
    # uniform mixture of the marginals along the covariance matrix's
    # eigenvectors follow a standard normal distribution.
    self._CheckKolmogorovSmirnovCDF(whitened.ravel(), scipy.stats.norm().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dim={}_mean_batch_size={}_cov_batch_size={}_shape={}"
                      .format(dim, mean_batch_size, cov_batch_size, shape),
     "dim": dim,
     "mean_batch_size": mean_batch_size,
     "cov_batch_size": cov_batch_size,
     "shape": shape}
    for dim in [1, 2, 4]
    for mean_batch_size in [(), (3,), (2, 3)]
    for cov_batch_size in [(), (3,), (2, 3)]
    for shape in [(), (1,), (5,)]))
def testMultivariateNormalShapes(self, dim, mean_batch_size, cov_batch_size,
                                 shape):
  """Batched `mean`/`cov` inputs give samples of shape `shape + (dim,)`.

  Args:
    dim: dimensionality of each sample.
    mean_batch_size: leading batch shape of the mean.
    cov_batch_size: leading batch shape of the covariance.
    shape: extra leading sample shape, prepended to the effective batch shape.
  """
  r = np.random.RandomState(0)
  key = random.PRNGKey(0)
  # The effective batch shape is whichever of the two has more dimensions.
  eff_batch_size = mean_batch_size \
    if len(mean_batch_size) > len(cov_batch_size) else cov_batch_size
  mean = r.randn(*(mean_batch_size + (dim,)))
  cov_factor = r.randn(*(cov_batch_size + (dim, dim)))
  cov = np.einsum('...ij,...kj->...ik', cov_factor, cov_factor)
  cov += 1e-3 * np.eye(dim)
  shape = shape + eff_batch_size
  samples = random.multivariate_normal(key, mean, cov, shape=shape)
  # TestCase assertion: reports both shapes on failure and survives -O.
  self.assertEqual(samples.shape, shape + (dim,))
def testMultivariateNormalCovariance(self):
  """Sample variances and covariances must agree with NumPy's sampler."""
  # test code based on https://github.com/google/jax/issues/1869
  N = 100000
  cov = jnp.array([[ 0.19,  0.00, -0.13,  0.00],
                   [ 0.00,  0.29,  0.00, -0.23],
                   [ -0.13,  0.00,  0.39,  0.00],
                   [ 0.00, -0.23,  0.00,  0.49]])
  mean = jnp.zeros(4)

  numpy_draws = np.random.RandomState(0).multivariate_normal(mean, cov, N)

  key = random.PRNGKey(0)
  jax_draws = random.multivariate_normal(key, mean=mean, cov=cov, shape=(N,))

  # Per-coordinate variances.
  numpy_variance = numpy_draws.var(axis=0)
  jax_variance = jax_draws.var(axis=0)
  self.assertAllClose(numpy_variance, jax_variance, rtol=1e-2, atol=1e-2,
                      check_dtypes=False)

  # Full covariance matrices.
  numpy_covariance = np.cov(numpy_draws, rowvar=False)
  jax_covariance = np.cov(jax_draws, rowvar=False)
  self.assertAllClose(numpy_covariance, jax_covariance, rtol=1e-2, atol=1e-2,
                      check_dtypes=False)
def testIssue222(self):
  """`randint` with minval == maxval == 0 returns 0."""
  x = random.randint(random.PRNGKey(10003), (), 0, 0)
  # TestCase assertion: reports the value on failure and survives -O.
  self.assertEqual(x, 0)
def testFoldIn(self):
  """Folding ten different integers into a key yields twenty distinct words."""
  key = random.PRNGKey(0)
  keys = [random.fold_in(key, i) for i in range(10)]
  # TestCase assertion: reports both shapes on failure and survives -O.
  self.assertEqual(np.unique(np.ravel(keys)).shape, (20,))
def testStaticShapeErrors(self):
  """Using jit-traced values as sample shapes must raise a clear TypeError."""
  if config.read("jax_disable_jit"):
    raise SkipTest("test only relevant when jit enabled")

  @api.jit
  def feature_map(n, d, sigma=1.0, seed=123):
    key = random.PRNGKey(seed)
    W = random.normal(key, (d, n)) / sigma
    w = random.normal(key, (d, )) / sigma
    b = 2 * jnp.pi * random.uniform(key, (d, ))

    def phi(x, t):
      return jnp.sqrt(2.0 / d) * jnp.cos(jnp.matmul(W, x) + w*t + b)

    return phi

  # n and d are traced under jit, so they cannot be used as static shapes.
  with self.assertRaisesRegex(TypeError, 'Shapes must be 1D.*'):
    feature_map(5, 3)
def testIssue756(self):
  """A scalar normal draw has the default float dtype for the x64 setting."""
  key = random.PRNGKey(0)
  w = random.normal(key, ())
  expected_dtype = np.float64 if FLAGS.jax_enable_x64 else np.float32
  self.assertEqual(np.result_type(w), expected_dtype)
def testIssue1789(self):
  """Differentiating through a vmapped gamma sampler must not raise."""
  def sample_gamma(a):
    return random.gamma(random.PRNGKey(0), a)

  def summed_batch(a_batch):
    return jnp.sum(vmap(sample_gamma)(a_batch))

  grad(summed_batch)(jnp.ones(2))
def testNoOpByOpUnderHash(self):
  """`threefry_2x32` must not call `xla.apply_primitive`."""
  def fail(*args, **kwargs):
    # Raise explicitly: `assert False` is a no-op under `python -O`, which
    # would let this test pass vacuously.
    raise AssertionError("xla.apply_primitive was called unexpectedly")

  # Swap in the failing stub and always restore the real function afterwards.
  apply_primitive, xla.apply_primitive = xla.apply_primitive, fail
  try:
    _ = random.threefry_2x32(np.zeros(2, np.uint32), np.arange(10, dtype=np.uint32))
  finally:
    xla.apply_primitive = apply_primitive
def testPRNGValues(self):
  """Pin concrete outputs of `randint`, `split` and `fold_in` for key 0."""
  # Test to ensure consistent random values between JAX versions
  k = random.PRNGKey(0)

  # The pinned integers differ depending on whether x64 is enabled.
  if FLAGS.jax_enable_x64:
    expected_ints = np.array([[7, 2, 6],
                              [2, 1, 0],
                              [6, 7, 7]], dtype='int64')
  else:
    expected_ints = np.array([[2, 1, 3],
                              [6, 1, 5],
                              [6, 3, 4]], dtype='int32')
  self.assertAllClose(random.randint(k, (3, 3), 0, 8), expected_ints)

  expected_split = np.array([[2285895361, 1501764800],
                             [1518642379, 4090693311],
                             [ 433833334, 4221794875],
                             [ 839183663, 3740430601]], dtype='uint32')
  self.assertAllClose(random.split(k, 4), expected_split)

  expected_fold = np.array([2285895361, 433833334], dtype='uint32')
  self.assertAllClose(random.fold_in(k, 4), expected_fold)
def testDtypeErrorMessage(self):
  """Requesting an integer dtype from `random.normal` raises a ValueError."""
  self.assertRaisesRegex(
      ValueError, r"dtype argument to.*",
      lambda: random.normal(random.PRNGKey(0), (), dtype=jnp.int32))
def testRandomBroadcast(self):
  """Issue 4033"""
  # test for broadcast issue in https://github.com/google/jax/issues/4033
  key = random.PRNGKey(0)
  shape = (10, 2)
  # Array-valued bounds must broadcast against the requested shape. TestCase
  # assertions report both shapes on failure and survive `python -O`.
  x = random.uniform(key, shape, minval=jnp.zeros(2), maxval=jnp.ones(2))
  self.assertEqual(x.shape, shape)
  x = random.randint(key, shape, jnp.array([0, 1]), jnp.array([1, 2]))
  self.assertEqual(x.shape, shape)
def testMaxwellSample(self):
  """Maxwell draws must match the reference mean, std and CDF."""
  num_samples = 10**5
  rng = random.PRNGKey(0)

  def draw(k):
    return random.maxwell(k, (num_samples, ))

  jitted_draw = api.jit(draw)

  expected_mean = scipy.stats.maxwell.mean()
  expected_std = scipy.stats.maxwell.std()

  eager_samples = draw(rng)
  jitted_samples = jitted_draw(rng)

  for samples in (eager_samples, jitted_samples):
    # Check first and second moments.
    self.assertEqual((num_samples,), samples.shape)
    self.assertAllClose(np.mean(samples), expected_mean, atol=0., rtol=0.1)
    self.assertAllClose(np.std(samples), expected_std, atol=0., rtol=0.1)
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.maxwell().cdf)
@parameterized.named_parameters(
    ('test1', 4.0, 1.0),
    ('test2', 2.0, 3.0))
def testWeibullSample(self, concentration, scale):
  """Weibull draws must match the reference mean, std and CDF."""
  num_samples = 10**5
  rng = random.PRNGKey(0)

  def draw(k):
    return random.weibull_min(k, scale, concentration, (num_samples,))

  jitted_draw = api.jit(draw)

  expected_mean = scipy.stats.weibull_min.mean(c=concentration, scale=scale)
  expected_std = scipy.stats.weibull_min.std(c=concentration, scale=scale)

  eager_samples = draw(rng)
  jitted_samples = jitted_draw(rng)

  for samples in (eager_samples, jitted_samples):
    # Check first and second moments.
    self.assertEqual((num_samples,), samples.shape)
    self.assertAllClose(np.mean(samples), expected_mean, atol=0., rtol=0.1)
    self.assertAllClose(np.std(samples), expected_std, atol=0., rtol=0.1)
    reference = scipy.stats.weibull_min(c=concentration, scale=scale)
    self._CheckKolmogorovSmirnovCDF(samples, reference.cdf)
@parameterized.named_parameters(
    ('test1', 4.0, 1.0),
    ('test2', 2.0, 3.0))
def testDoublesidedMaxwellSample(self, loc, scale):
  """Double-sided Maxwell draws must match the expected moments and CDF.

  Args:
    loc: location parameter of the distribution.
    scale: scale parameter of the distribution.
  """
  num_samples = 10**5
  rng = random.PRNGKey(0)

  # Sample with the lambda's own `key` argument (the original closed over
  # `rng` and ignored it), so the jitted variant really traces the key.
  rand = lambda key: random.double_sided_maxwell(
      key, loc, scale, (num_samples,))
  crand = api.jit(rand)

  mean = loc
  std = np.sqrt(3.) * scale

  uncompiled_samples = rand(rng)
  compiled_samples = crand(rng)

  # Compute the double sided maxwell CDF through the one sided maxwell cdf.
  # This is done as follows:
  # P(DSM <= x) = P (loc + scale * rademacher_sample * one_sided_sample <=x) =
  # P (rademacher_sample * one_sided_sample <= (x - loc) / scale) =
  # 1/2 P(one_sided_sample <= (x - loc) / scale)
  #    + 1/2 P( - one_sided_sample <= (x - loc) / scale) =
  #  1/2 P(one_sided_sample <= (x - loc) / scale)
  #    + 1/2 P(one_sided_sample >= - (x - loc) / scale) =
  # 1/2 CDF_one_maxwell((x - loc) / scale))
  #   + 1/2 (1 - CDF_one_maxwell(- (x - loc) / scale)))
  def double_sided_maxwell_cdf(x, loc, scale):
    pos = scipy.stats.maxwell().cdf((x - loc)/ scale)
    neg = (1 - scipy.stats.maxwell().cdf((-x + loc)/ scale))
    return (pos + neg) / 2

  for samples in [uncompiled_samples, compiled_samples]:
    # Check first and second moments.
    self.assertEqual((num_samples,), samples.shape)
    self.assertAllClose(np.mean(samples), mean, atol=0., rtol=0.1)
    self.assertAllClose(np.std(samples), std, atol=0., rtol=0.1)

    self._CheckKolmogorovSmirnovCDF(
        samples, lambda x: double_sided_maxwell_cdf(x, loc, scale))
def testRadamacher(self):
  """Rademacher draws take exactly two values, each about half the time."""
  rng = random.PRNGKey(0)
  num_samples = 10**5

  rand = lambda x: random.rademacher(x, (num_samples,))
  crand = api.jit(rand)

  uncompiled_samples = rand(rng)
  compiled_samples = crand(rng)

  for samples in [uncompiled_samples, compiled_samples]:
    unique_values, counts = np.unique(samples, return_counts=True)
    # TestCase assertions (not bare asserts): they report the offending
    # lengths on failure and are never stripped under `python -O`.
    self.assertEqual(len(unique_values), 2)
    self.assertEqual(len(counts), 2)

    self.assertAllClose(
        counts[0]/ num_samples, 0.5, rtol=1e-02, atol=1e-02)
    self.assertAllClose(
        counts[1]/ num_samples, 0.5, rtol=1e-02, atol=1e-02)
def testChoiceShapeIsNotSequenceError(self):
  """Passing an integer as `shape` to `random.choice` raises TypeError."""
  key = random.PRNGKey(0)
  # Same expectation with and without replacement.
  for with_replacement in (False, True):
    with self.assertRaises(TypeError):
      random.choice(key, 5, 2, replace=with_replacement)
def test_eval_shape_big_random_array(self):
  """`eval_shape` of a sampler with 1e12 elements must not raise."""
  def huge_normal(seed):
    return random.normal(random.PRNGKey(seed), (int(1e12),))

  with core.skipping_checks():  # check_jaxpr will materialize array
    api.eval_shape(huge_normal, 0)  # doesn't error
if __name__ == "__main__":
  # Run the tests in this file through absl, using the JAX test loader.
  jax_test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=jax_test_loader)