2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
from absl.testing import absltest
|
|
|
|
from absl.testing import parameterized
|
|
|
|
|
|
|
|
import numpy as onp
|
|
|
|
import scipy.special
|
|
|
|
import scipy.stats
|
|
|
|
|
|
|
|
from jax import api
|
|
|
|
from jax import lax
|
|
|
|
from jax import random
|
|
|
|
from jax import test_util as jtu
|
2018-11-29 12:30:34 -08:00
|
|
|
from jax.config import config
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-11-29 12:30:34 -08:00
|
|
|
FLAGS = config.FLAGS
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
class LaxRandomTest(jtu.JaxTestCase):
|
|
|
|
|
|
|
|
def _CheckCollisions(self, samples, nbits):
|
|
|
|
fail_prob = 0.01 # conservative bound on statistical fail prob by Chebyshev
|
|
|
|
nitems = len(samples)
|
|
|
|
nbins = 2 ** nbits
|
|
|
|
nexpected = nbins * (1 - ((nbins - 1) / nbins) ** nitems)
|
|
|
|
ncollisions = len(onp.unique(samples))
|
|
|
|
sq_percent_deviation = ((ncollisions - nexpected) / nexpected) ** 2
|
|
|
|
self.assertLess(sq_percent_deviation, 1 / onp.sqrt(nexpected * fail_prob))
|
|
|
|
|
|
|
|
def _CheckKolmogorovSmirnovCDF(self, samples, cdf):
|
|
|
|
fail_prob = 0.01 # conservative bound on statistical fail prob by Kolmo CDF
|
|
|
|
statistic = scipy.stats.kstest(samples, cdf).statistic
|
|
|
|
self.assertLess(1. - scipy.special.kolmogorov(statistic), fail_prob)
|
|
|
|
|
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
|
|
|
|
for dtype in [onp.float32, onp.float64])
|
|
|
|
def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype):
|
|
|
|
if not FLAGS.jax_enable_x64 and onp.issubdtype(dtype, onp.float64):
|
|
|
|
return absltest.unittest.skip("can't test float64 agreement")
|
|
|
|
|
|
|
|
bits_dtype = onp.uint32 if onp.finfo(dtype).bits == 32 else onp.uint64
|
|
|
|
numpy_bits = onp.array(1., dtype).view(bits_dtype)
|
|
|
|
xla_bits = api.jit(
|
|
|
|
lambda: lax.bitcast_convert_type(onp.array(1., dtype), bits_dtype))()
|
|
|
|
self.assertEqual(numpy_bits, xla_bits)
|
|
|
|
|
|
|
|
def testThreefry2x32(self):
|
|
|
|
# We test the hash by comparing to known values provided in the test code of
|
|
|
|
# the original reference implementation of Threefry. For the values, see
|
|
|
|
# https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32
|
2018-11-21 13:20:44 -08:00
|
|
|
def result_to_hex(result):
|
|
|
|
return tuple([hex(x.copy()).rstrip("L") for x in result])
|
|
|
|
|
|
|
|
expected = ("0x6b200159", "0x99ba4efe")
|
2018-11-17 18:03:33 -08:00
|
|
|
result = random.threefry_2x32(onp.uint32([0, 0]), onp.uint32([0, 0]))
|
|
|
|
|
2018-11-21 13:20:44 -08:00
|
|
|
self.assertEqual(expected, result_to_hex(result))
|
|
|
|
|
|
|
|
expected = ("0x1cb996fc", "0xbb002be7")
|
2018-11-17 18:03:33 -08:00
|
|
|
result = random.threefry_2x32(onp.uint32([-1, -1]), onp.uint32([-1, -1]))
|
2018-11-21 13:20:44 -08:00
|
|
|
self.assertEqual(expected, result_to_hex(result))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-11-21 13:20:44 -08:00
|
|
|
expected = ("0xc4923a9c", "0x483df7a0")
|
2018-11-17 18:03:33 -08:00
|
|
|
result = random.threefry_2x32(
|
|
|
|
onp.uint32([0x13198a2e, 0x03707344]),
|
|
|
|
onp.uint32([0x243f6a88, 0x85a308d3]))
|
2018-11-21 13:20:44 -08:00
|
|
|
self.assertEqual(expected, result_to_hex(result))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
|
|
|
|
for dtype in [onp.float32, onp.float64])
|
|
|
|
def testRngUniform(self, dtype):
|
|
|
|
key = random.PRNGKey(0)
|
|
|
|
rand = lambda key: random.uniform(key, (10000,), dtype)
|
|
|
|
crand = api.jit(rand)
|
|
|
|
|
|
|
|
uncompiled_samples = rand(key)
|
|
|
|
compiled_samples = crand(key)
|
|
|
|
|
|
|
|
for samples in [uncompiled_samples, compiled_samples]:
|
|
|
|
self._CheckCollisions(samples, onp.finfo(dtype).nmant)
|
|
|
|
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)
|
|
|
|
|
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
|
|
|
|
for dtype in [onp.int32, onp.int64])
|
|
|
|
def testRngRandint(self, dtype):
|
|
|
|
lo = 5
|
|
|
|
hi = 10
|
|
|
|
|
|
|
|
key = random.PRNGKey(0)
|
|
|
|
rand = lambda key: random.randint(key, (10000,), lo, hi, dtype)
|
|
|
|
crand = api.jit(rand)
|
|
|
|
|
|
|
|
uncompiled_samples = rand(key)
|
|
|
|
compiled_samples = crand(key)
|
|
|
|
|
|
|
|
for samples in [uncompiled_samples, compiled_samples]:
|
|
|
|
self.assertTrue(onp.all(lo <= samples))
|
|
|
|
self.assertTrue(onp.all(samples < hi))
|
|
|
|
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.randint(lo, hi).cdf)
|
|
|
|
|
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
|
|
|
|
for dtype in [onp.float32, onp.float64])
|
|
|
|
def testNormal(self, dtype):
|
|
|
|
key = random.PRNGKey(0)
|
|
|
|
rand = lambda key: random.normal(key, (10000,), dtype)
|
|
|
|
crand = api.jit(rand)
|
|
|
|
|
|
|
|
uncompiled_samples = rand(key)
|
|
|
|
compiled_samples = crand(key)
|
|
|
|
|
|
|
|
for samples in [uncompiled_samples, compiled_samples]:
|
|
|
|
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.norm().cdf)
|
|
|
|
|
2018-11-19 07:43:23 -08:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
|
|
|
|
for dtype in [onp.float32, onp.float64, onp.int32, onp.int64])
|
|
|
|
def testShuffle(self, dtype):
|
|
|
|
key = random.PRNGKey(0)
|
|
|
|
x = onp.arange(100).astype(dtype)
|
|
|
|
rand = lambda key: random.shuffle(key, x)
|
|
|
|
crand = api.jit(rand)
|
|
|
|
|
|
|
|
perm1 = rand(key)
|
|
|
|
perm2 = crand(key)
|
|
|
|
|
|
|
|
self.assertTrue(onp.all(perm1 == perm2))
|
|
|
|
self.assertTrue(onp.all(perm1.dtype == perm2.dtype))
|
|
|
|
self.assertFalse(onp.all(perm1 == x)) # seems unlikely!
|
|
|
|
self.assertTrue(onp.all(onp.sort(perm1) == x))
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
if __name__ == "__main__":
|
2018-11-29 12:30:34 -08:00
|
|
|
config.config_with_absl()
|
2018-11-17 18:03:33 -08:00
|
|
|
absltest.main()
|