# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from functools import partial
from unittest import SkipTest, skipIf
from typing import Any, Tuple, NamedTuple, Optional
import zlib

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np
import scipy.linalg
import scipy.special
import scipy.stats

import jax
from jax import core
from jax import dtypes
from jax import grad
from jax import lax
from jax import numpy as jnp
from jax import prng
from jax import random
from jax._src import test_util as jtu
from jax import vmap
from jax.interpreters import xla
import jax._src.random

from jax.config import config
config.parse_flags_with_absl()

# dtype families used to parameterize the statistical tests below.
float_dtypes = jtu.dtypes.all_floating
complex_dtypes = jtu.dtypes.complex
int_dtypes = jtu.dtypes.all_integer
uint_dtypes = jtu.dtypes.all_unsigned


def _prng_key_as_array(key):
  """Return the raw key data of `key` regardless of the custom-PRNG mode.

  When `config.jax_enable_custom_prng` is set, keys are wrapper objects and
  their raw array is obtained via `unsafe_raw_array()`; otherwise `key` is
  already the raw array and is returned unchanged.
  """
  # TODO(frostig): remove once we upgrade to always enable_custom_prng
  return key.unsafe_raw_array() if config.jax_enable_custom_prng else key


# (name, implementation) pairs for every PRNG implementation under test. The
# name is what `jax.default_prng_impl(name)` accepts in the tests below.
PRNG_IMPLS = [('threefry2x32', prng.threefry_prng_impl),
              ('rbg', prng.rbg_prng_impl),
              ('unsafe_rbg', prng.unsafe_rbg_prng_impl)]


class RandomValuesCase(NamedTuple):
  """One golden-value case for `testRandomDistributionValues`.

  The test calls `getattr(random, name)(key, **params, shape=shape[, dtype])`
  under `jax.default_prng_impl(prng_impl)` and compares the result with
  `expected` using `atol`/`rtol`.
  """
  name: str  # attribute of `jax.random` to call, e.g. "normal"
  prng_impl: str  # PRNG implementation name, e.g. "threefry2x32" or "rbg"
  shape: Tuple[int, ...]  # output shape passed as `shape=`
  dtype: Any  # passed as `dtype=` only when truthy; None omits it
  params: dict  # extra keyword arguments for the sampler
  expected: np.ndarray  # golden output for the seed from `_seed()`
  skip_on_x64: bool = False  # values differ when jax_enable_x64=True
  atol: Optional[float] = None
  rtol: Optional[float] = None

  def _testname(self):
    # Build the parameterized test-name suffix. Spaces and newlines are removed
    # from param reprs so array-valued params stay a single token.
    if self.dtype is None:
      shape_dtype = str(self.shape)
    else:
      shape_dtype = jtu.format_shape_dtype_string(self.shape, self.dtype)
    name = f"_{self.name}_{self.prng_impl}_{shape_dtype}"
    if self.params:
      fmt = lambda x: str(x).replace(' ', '').replace('\n', '')
      name += "_" + "_".join(f"{k}={fmt(v)}" for k, v in self.params.items())
    return name

  def _seed(self):
    # Generate a deterministic unique 32-bit seed given the name and prng impl
    # (adler32 is stable across processes, unlike the builtin str hash).
    return zlib.adler32((self.name + self.prng_impl).encode())


# Golden outputs per (distribution, PRNG impl). Changing any `expected` array
# here means the sampler's output changed for a fixed seed.
_RANDOM_VALUES_CASES = [
  # TODO(jakevdp) add coverage for other distributions.
  RandomValuesCase("bernoulli", "threefry2x32", (5,), None, {'p': 0.5},
                   np.array([False, True, True, True, False]), skip_on_x64=True),
  RandomValuesCase("bernoulli", "rbg", (5,), None, {'p': 0.5},
                   np.array([True, True, True, True, True]), skip_on_x64=True),
  RandomValuesCase("beta", "threefry2x32", (5,), np.float32, {'a': 0.8, 'b': 0.9},
                   np.array([0.533685, 0.843179, 0.063495, 0.573444, 0.459514], dtype='float32')),
  RandomValuesCase("beta", "rbg", (5,), np.float32, {'a': 0.8, 'b': 0.9},
                   np.array([0.841308, 0.669989, 0.731763, 0.985127, 0.022745], dtype='float32')),
  RandomValuesCase("cauchy", "threefry2x32", (5,), np.float32, {},
                   np.array([ -0.088416, -10.169713, 3.49677, -1.18056, 0.34556], dtype='float32'),
                   rtol=1E-5),
  RandomValuesCase("cauchy", "rbg", (5,), np.float32, {},
                   np.array([0.008389, 0.108793, -0.031826, -0.01876, 0.963218], dtype='float32')),
  RandomValuesCase("dirichlet", "threefry2x32", (2,), np.float32,
                   {'alpha': np.array([0.5, 0.6, 0.7], dtype='float32')},
                   np.array([[0.556287, 0.304219, 0.139494],
                             [0.15221 , 0.632251, 0.21554]], dtype='float32')),
  RandomValuesCase("dirichlet", "rbg", (2,), np.float32,
                   {'alpha': np.array([0.5, 0.6, 0.7], dtype='float32')},
                   np.array([[0.024769, 0.002189, 0.973041],
                             [0.326, 0.00244, 0.67156]], dtype='float32')),
  RandomValuesCase("double_sided_maxwell", "threefry2x32", (5,), np.float32,
                   {"loc": 1, "scale": 2},
                   np.array([-2.408914, -3.370437, 3.235352, -0.907734, -1.708732], dtype='float32'),
                   skip_on_x64=True),
  RandomValuesCase("double_sided_maxwell", "rbg", (5,), np.float32,
                   {"loc": 1, "scale": 2},
                   np.array([4.957495, 3.003086, 5.33935, 2.942878, -1.203524], dtype='float32'),
                   skip_on_x64=True),
  RandomValuesCase("exponential", "threefry2x32", (5,), np.float32, {},
                   np.array([0.526067, 0.043046, 0.039932, 0.46427 , 0.123886], dtype='float32')),
  RandomValuesCase("exponential", "rbg", (5,), np.float32, {},
                   np.array([0.231303, 0.684814, 0.017181, 0.089552, 0.345087], dtype='float32')),
  RandomValuesCase("gamma", "threefry2x32", (5,), np.float32, {'a': 0.8},
                   np.array([0.332641, 0.10187 , 1.816109, 0.023457, 0.487853], dtype='float32')),
  RandomValuesCase("gamma", "rbg", (5,), np.float32, {'a': 0.8},
                   np.array([0.235293, 0.446747, 0.146372, 0.79252 , 0.294762], dtype='float32')),
  RandomValuesCase("gumbel", "threefry2x32", (5,), np.float32, {},
                   np.array([2.06701, 0.911726, 0.145736, 0.185427, -0.00711], dtype='float32')),
  RandomValuesCase("gumbel", "rbg", (5,), np.float32, {},
                   np.array([-0.099308, -1.123809, 1.007618, -0.077968, 3.421349], dtype='float32')),
  RandomValuesCase("laplace", "threefry2x32", (5,), np.float32, {},
                   np.array([0.578939, -0.204902, 0.555733, 0.911053, -0.96456], dtype='float32')),
  RandomValuesCase("laplace", "rbg", (5,), np.float32, {},
                   np.array([-2.970422, 1.925082, -0.757887, -4.444797, 0.561983], dtype='float32')),
  RandomValuesCase("loggamma", "threefry2x32", (5,), np.float32, {'a': 0.8},
                   np.array([-0.899633, -0.424083, 0.631593, 0.102374, -1.07189], dtype='float32')),
  RandomValuesCase("loggamma", "rbg", (5,), np.float32, {'a': 0.8},
                   np.array([-1.333825, 0.287259, -0.343074, -0.998258, -0.773598], dtype='float32')),
  RandomValuesCase("logistic", "threefry2x32", (5,), np.float32, {},
                   np.array([0.19611, -1.709053, -0.274093, -0.208322, -1.675489], dtype='float32')),
  RandomValuesCase("logistic", "rbg", (5,), np.float32, {},
                   np.array([-0.234923, -0.545184, 0.700992, -0.708609, -1.474884], dtype='float32')),
  RandomValuesCase("maxwell", "threefry2x32", (5,), np.float32, {},
                   np.array([3.070779, 0.908479, 1.521317, 0.875551,
                             1.306137], dtype='float32')),
  RandomValuesCase("maxwell", "rbg", (5,), np.float32, {},
                   np.array([2.048746, 0.470027, 1.053105, 1.01969, 2.710645], dtype='float32')),
  RandomValuesCase("multivariate_normal", "threefry2x32", (2,), np.float32,
                   {"mean": np.ones((1, 3)), "cov": np.eye(3)},
                   np.array([[ 1.067826, 1.215599, 0.234166],
                             [-0.237534, 1.32591, 1.413987]], dtype='float32'),
                   skip_on_x64=True),
  RandomValuesCase("multivariate_normal", "rbg", (2,), np.float32,
                   {"mean": np.ones((1, 3)), "cov": np.eye(3)},
                   np.array([[-0.036897, 0.770969, 0.756959],
                             [1.755091, 2.350553, 0.627142]], dtype='float32'),
                   skip_on_x64=True),
  RandomValuesCase("normal", "threefry2x32", (5,), np.float32, {},
                   np.array([-1.173234, -1.511662, 0.070593, -0.099764, 1.052845], dtype='float32')),
  RandomValuesCase("normal", "rbg", (5,), np.float32, {},
                   np.array([-0.479658, 0.565747, -1.065106, 0.997962, -1.478002], dtype='float32')),
  RandomValuesCase("pareto", "threefry2x32", (5,), np.float32, {"b": 0.5},
                   np.array([2.751398, 1.281863, 87.85448, 1.254542, 2.824487], dtype='float32')),
  RandomValuesCase("pareto", "rbg", (5,), np.float32, {"b": 0.5},
                   np.array([1.241914, 1.521864, 5.615384, 1911.502, 1.816702], dtype='float32')),
  RandomValuesCase("poisson", "threefry2x32", (5,), np.int32, {"lam": 5},
                   np.array([7, 3, 6, 11, 6], dtype='int32')),
  # Note: poisson not implemented for rbg sampler.
RandomValuesCase("rademacher", "threefry2x32", (5,), np.int32, {}, np.array([-1, -1, -1, -1, 1], dtype='int32'), skip_on_x64=True), RandomValuesCase("rademacher", "rbg", (5,), np.int32, {}, np.array([1, 1, 1, -1, -1], dtype='int32'), skip_on_x64=True), RandomValuesCase("randint", "threefry2x32", (5,), np.int32, {"minval": 0, "maxval": 10}, np.array([0, 5, 7, 7, 5], dtype='int32')), RandomValuesCase("randint", "rbg", (5,), np.int32, {"minval": 0, "maxval": 10}, np.array([7, 1, 8, 5, 8], dtype='int32')), RandomValuesCase("truncated_normal", "threefry2x32", (5,), np.float32, {"lower": 0, "upper": 2}, np.array([0.582807, 1.709771, 0.159513, 0.861376, 0.36148], dtype='float32')), RandomValuesCase("truncated_normal", "rbg", (5,), np.float32, {"lower": 0, "upper": 2}, np.array([0.770068, 1.516464, 0.710406, 0.762801, 1.305324], dtype='float32')), RandomValuesCase("uniform", "threefry2x32", (5,), np.float32, {}, np.array([0.298671, 0.073213, 0.873356, 0.260549, 0.412797], dtype='float32')), RandomValuesCase("uniform", "rbg", (5,), np.float32, {}, np.array([0.477161, 0.706508, 0.656261, 0.432547, 0.057772], dtype='float32')), RandomValuesCase("weibull_min", "threefry2x32", (5,), np.float32, {"scale": 1, "concentration": 1}, np.array([1.605863, 0.841809, 0.224218, 0.4826 , 0.027901], dtype='float32')), RandomValuesCase("weibull_min", "rbg", (5,), np.float32, {"scale": 1, "concentration": 1}, np.array([1.370903, 0.086532, 0.061688, 3.407599, 0.215077], dtype='float32')), ] class PrngTest(jtu.JaxTestCase): def testThreefry2x32(self): # We test the hash by comparing to known values provided in the test code of # the original reference implementation of Threefry. 
For the values, see # https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32 def result_to_hex(result): return tuple([hex(x.copy()).rstrip("L") for x in result]) expected = ("0x6b200159", "0x99ba4efe") result = prng.threefry_2x32(np.uint32([0, 0]), np.uint32([0, 0])) self.assertEqual(expected, result_to_hex(result)) expected = ("0x1cb996fc", "0xbb002be7") result = prng.threefry_2x32(np.uint32([-1, -1]), np.uint32([-1, -1])) self.assertEqual(expected, result_to_hex(result)) expected = ("0xc4923a9c", "0x483df7a0") result = prng.threefry_2x32( np.uint32([0x13198a2e, 0x03707344]), np.uint32([0x243f6a88, 0x85a308d3])) self.assertEqual(expected, result_to_hex(result)) def testThreefry2x32Large(self): n = 10000000 result = prng.threefry_2x32( (np.uint32(0x13198a2e), np.uint32(0x03707344)), jnp.concatenate([ jnp.full((n,), 0x243f6a88, jnp.uint32), jnp.full((n,), 0x85a308d3, jnp.uint32) ])) np.testing.assert_equal(result[:n], np.full((n,), 0xc4923a9c, dtype=np.uint32)) np.testing.assert_equal(result[n:], np.full((n,), 0x483df7a0, dtype=np.uint32)) def testThreefry2x32Empty(self): # Regression test for an op-by-op crash for empty arrays in CUDA mode. with jax.disable_jit(): result = prng.threefry_2x32( (np.uint32(0x13198a2e), np.uint32(0x03707344)), jnp.ones((10, 0,), jnp.uint32)) np.testing.assert_equal(result, np.zeros((10, 0,), dtype=np.uint32)) def testNoOpByOpUnderHash(self): def fail(*args, **kwargs): assert False apply_primitive, xla.apply_primitive = xla.apply_primitive, fail try: _ = prng.threefry_2x32(np.zeros(2, np.uint32), np.arange(10, dtype=np.uint32)) finally: xla.apply_primitive = apply_primitive def testRngRandomBits(self): # Test specific outputs to ensure consistent random values between JAX versions. 
# PrngTest.testRngRandomBits (continued): golden bits for key 1701 per width.
    key = random.PRNGKey(1701)

    bits8 = jax._src.random._random_bits(key, 8, (3,))
    expected8 = np.array([216, 115, 43], dtype=np.uint8)
    self.assertArraysEqual(bits8, expected8)

    bits16 = jax._src.random._random_bits(key, 16, (3,))
    expected16 = np.array([41682, 1300, 55017], dtype=np.uint16)
    self.assertArraysEqual(bits16, expected16)

    bits32 = jax._src.random._random_bits(key, 32, (3,))
    expected32 = np.array([56197195, 4200222568, 961309823], dtype=np.uint32)
    self.assertArraysEqual(bits32, expected32)

    # Without x64 the 64-bit request yields uint32 values (see the else branch),
    # so the "Explicitly requested dtype" warning is silenced here.
    with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
      bits64 = jax._src.random._random_bits(key, 64, (3,))
    if config.x64_enabled:
      expected64 = np.array([3982329540505020460, 16822122385914693683,
                             7882654074788531506], dtype=np.uint64)
    else:
      expected64 = np.array([676898860, 3164047411, 4010691890], dtype=np.uint32)
    self.assertArraysEqual(bits64, expected64)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_" + name, "prng_name": name}
      for name, _ in PRNG_IMPLS))
  def testRngRandomBitsShapeDtype(self, prng_name):
    # Like testRngRandomBits, but only meant to exercise random_bits
    # on every PRNG implementation. Instead of values, only checks
    # that shapes/dtypes are as expected.
with jax.default_prng_impl(prng_name): key = random.PRNGKey(1701) bits8 = jax._src.random._random_bits(key, 8, (3,)) self.assertEqual(bits8.shape, (3,)) self.assertEqual(bits8.dtype, np.dtype('uint8')) bits16 = jax._src.random._random_bits(key, 16, (3,)) self.assertEqual(bits16.shape, (3,)) self.assertEqual(bits16.dtype, np.dtype('uint16')) bits32 = jax._src.random._random_bits(key, 32, (3,)) self.assertEqual(bits32.shape, (3,)) self.assertEqual(bits32.dtype, np.dtype('uint32')) with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"): bits64 = jax._src.random._random_bits(key, 64, (3,)) expected_dtype = np.dtype('uint64' if config.x64_enabled else 'uint32') self.assertEqual(bits64.shape, (3,)) self.assertEqual(bits64.dtype, expected_dtype) def testRngRandomBitsViewProperty(self): # TODO: add 64-bit if it ever supports this property. # TODO: will this property hold across endian-ness? N = 10 key = random.PRNGKey(1701) nbits = [8, 16, 32] rand_bits = [jax._src.random._random_bits(key, n, (N * 64 // n,)) for n in nbits] rand_bits_32 = np.array([np.array(r).view(np.uint32) for r in rand_bits]) assert np.all(rand_bits_32 == rand_bits_32[0]) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": case._testname(), "case": case} for case in _RANDOM_VALUES_CASES)) @jtu.skip_on_devices("tpu") # TPU precision causes issues. def testRandomDistributionValues(self, case): """ Tests values output by various distributions. This will catch any unintentional changes to the implementations that could result in different random sequences. 
Any refactoring of random distributions that leads to non-trivial differences in this test should involve a deprecation cycle following the procedures outlined at https://jax.readthedocs.io/en/latest/api_compatibility.html """ if config.x64_enabled and case.skip_on_x64: self.skipTest("test produces different values when jax_enable_x64=True") with jax.default_prng_impl(case.prng_impl): func = getattr(random, case.name) key = random.PRNGKey(case._seed()) if case.dtype: actual = func(key, **case.params, shape=case.shape, dtype=case.dtype) else: actual = func(key, **case.params, shape=case.shape) self.assertAllClose(actual, case.expected, atol=case.atol, rtol=case.rtol) def testPRNGValues(self): # Test to ensure consistent random values between JAX versions k = random.PRNGKey(0) self.assertEqual(random.randint(k, (3, 3), 0, 8).dtype, dtypes.canonicalize_dtype(jnp.int_)) if config.x64_enabled: self.assertAllClose( random.randint(k, (3, 3), 0, 8, dtype='int64'), np.array([[7, 2, 6], [2, 1, 0], [6, 7, 7]], dtype='int64')) self.assertAllClose( random.randint(k, (3, 3), 0, 8, dtype='int32'), np.array([[2, 1, 3], [6, 1, 5], [6, 3, 4]], dtype='int32')) self.assertAllClose( _prng_key_as_array(random.split(k, 4)), np.array([[2285895361, 1501764800], [1518642379, 4090693311], [ 433833334, 4221794875], [ 839183663, 3740430601]], dtype='uint32')) self.assertAllClose( _prng_key_as_array(random.fold_in(k, 4)), np.array([2285895361, 433833334], dtype='uint32')) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "seed={seed}_type={type}_jit={jit}".format(**dct), **dct} for dct in [ {"seed": 0, "type": int, "jit": True, "key": [0, 0]}, {"seed": 0, "type": int, "jit": False, "key": [0, 0]}, {"seed": 1, "type": np.int32, "jit": True, "key": [0, 1]}, {"seed": 1, "type": np.int32, "jit": False, "key": [0, 1]}, {"seed": 2, "type": np.uint32, "jit": True, "key": [0, 2]}, {"seed": 2, "type": np.uint32, "jit": False, "key": [0, 2]}, {"seed": 3, "type": np.int64, "jit": 
True, "key": [0, 3]}, {"seed": 3, "type": np.int64, "jit": False, "key": [0, 3]}, {"seed": -1, "type": int, "jit": True, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]}, {"seed": -1, "type": int, "jit": False, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]}, {"seed": -2, "type": np.int32, "jit": True, "key": [0, 4294967294]}, {"seed": -2, "type": np.int32, "jit": False, "key": [0, 4294967294]}, {"seed": -3, "type": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]}, {"seed": -3, "type": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]}, {"seed": np.iinfo(np.int32).max + 100, "type": int, "jit": True, "key": [0, 2147483747]}, {"seed": np.iinfo(np.int32).max + 100, "type": int, "jit": False, "key": [0, 2147483747]}, {"seed": np.iinfo(np.int32).max + 101, "type": np.uint32, "jit": True, "key": [0, 2147483748]}, {"seed": np.iinfo(np.int32).max + 101, "type": np.uint32, "jit": False, "key": [0, 2147483748]}, {"seed": np.iinfo(np.int32).min - 100, "type": int, "jit": True, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]}, {"seed": np.iinfo(np.int32).min - 100, "type": int, "jit": False, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]}, {"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]}, {"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]}, ] )) def test_prng_seeds_and_keys(self, seed, type, jit, key): if (jit and type is int and not config.x64_enabled and (seed < np.iinfo('int32').min or seed > np.iinfo('int32').max)): self.skipTest("Expected failure: integer out of range for jit.") seed = type(seed) if jit: actual = _prng_key_as_array(jax.jit(random.PRNGKey)(seed)) else: actual = 
_prng_key_as_array(random.PRNGKey(seed)) expected = jnp.array(key, dtype=jnp.uint32) self.assertArraysEqual(actual, expected) def test_default_prng_selection(self): if not config.jax_enable_custom_prng: self.skipTest("test requires config.jax_enable_custom_prng") for name, impl in PRNG_IMPLS: with jax.default_prng_impl(name): self.assertIs(random.default_prng_impl(), impl) key = random.PRNGKey(42) self.assertIs(key.impl, impl) k1, k2 = random.split(key, 2) self.assertIs(k1.impl, impl) self.assertIs(k2.impl, impl) def test_default_prng_selection_without_custom_prng_mode(self): if config.jax_enable_custom_prng: self.skipTest("test requires that config.jax_enable_custom_prng is False") for name, impl in PRNG_IMPLS: with jax.default_prng_impl(name): self.assertIs(random.default_prng_impl(), impl) key = random.PRNGKey(42) self.assertEqual(key.shape, impl.key_shape) k1, k2 = random.split(key, 2) self.assertEqual(k1.shape, impl.key_shape) self.assertEqual(k2.shape, impl.key_shape) def test_explicit_threefry2x32_key(self): if not config.jax_enable_custom_prng: self.skipTest("test requires config.jax_enable_custom_prng") key = random.threefry2x32_key(42) self.assertIs(key.impl, prng.threefry_prng_impl) def test_explicit_rbg_key(self): if not config.jax_enable_custom_prng: self.skipTest("test requires config.jax_enable_custom_prng") key = random.rbg_key(42) self.assertIs(key.impl, prng.rbg_prng_impl) def test_explicit_unsafe_rbg_key(self): if not config.jax_enable_custom_prng: self.skipTest("test requires config.jax_enable_custom_prng") key = random.unsafe_rbg_key(42) self.assertIs(key.impl, prng.unsafe_rbg_prng_impl) def test_key_array_indexing_0d(self): if not config.jax_enable_custom_prng: self.skipTest("test requires config.jax_enable_custom_prng") key = random.PRNGKey(1701) self.assertEqual(key.shape, ()) self.assertEqual(key[None].shape, (1,)) self.assertRaisesRegex(IndexError, 'Too many indices for PRNGKeyArray.*', lambda: key[0]) def test_key_array_indexing_nd(self): 
# PrngTest.test_key_array_indexing_nd (continued): a (2, 3) batch of keys
# indexes like an array of that shape; trailing key data is not indexable.
    if not config.jax_enable_custom_prng:
      self.skipTest("test requires config.jax_enable_custom_prng")
    keys = vmap(vmap(random.PRNGKey))(jnp.arange(6).reshape((2, 3)))
    self.assertEqual(keys.shape, (2, 3))
    self.assertEqual(keys[0, 0].shape, ())
    self.assertEqual(keys[0, 1].shape, ())
    self.assertEqual(keys[0].shape, (3,))
    self.assertEqual(keys[1, :].shape, (3,))
    self.assertEqual(keys[:, 1].shape, (2,))
    self.assertEqual(keys[None].shape, (1, 2, 3))
    self.assertEqual(keys[None, None].shape, (1, 1, 2, 3))
    self.assertEqual(keys[None, :, None].shape, (1, 2, 1, 3))
    self.assertEqual(keys[None, None, None, 0, None, None, None, 1].shape,
                     (1,) * 6)
    self.assertEqual(keys[..., 1:, None].shape, (2, 2, 1))
    self.assertEqual(keys[None, 0, ..., 1, None].shape, (1, 1))
    self.assertRaisesRegex(IndexError, 'Too many indices for PRNGKeyArray.*',
                           lambda: keys[0, 1, 2])
    self.assertRaisesRegex(IndexError, 'Too many indices for PRNGKeyArray.*',
                           lambda: keys[0, 1, None, 2])


class LaxRandomTest(jtu.JaxTestCase):
  """Statistical tests (collision, Kolmogorov-Smirnov, chi-squared) of samplers."""

  def _CheckCollisions(self, samples, nbits):
    # Compare the number of distinct values in `samples` with the expected
    # number of occupied bins when len(samples) items are thrown uniformly into
    # 2**nbits bins.
    fail_prob = 0.01  # conservative bound on statistical fail prob by Chebyshev
    nitems = len(samples)
    nbins = 2 ** nbits
    nexpected = nbins * (1 - ((nbins - 1) / nbins) ** nitems)
    # Despite its name, this counts distinct values (occupied bins).
    ncollisions = len(np.unique(samples))
    sq_percent_deviation = ((ncollisions - nexpected) / nexpected) ** 2
    self.assertLess(sq_percent_deviation, 1 / np.sqrt(nexpected * fail_prob))

  def _CheckKolmogorovSmirnovCDF(self, samples, cdf):
    # Fails if the KS test rejects `cdf` as the samples' distribution at p <= 0.01.
    fail_prob = 0.01  # conservative bound on statistical fail prob by Kolmo CDF
    self.assertGreater(scipy.stats.kstest(samples, cdf).pvalue, fail_prob)

  def _CheckChiSquared(self, samples, pmf):
    # Chi-squared goodness-of-fit of integer `samples` against `pmf`, which is
    # evaluated on the integer grid 0 .. samples.max() + extra_values - 1.
    alpha = 0.01  # significance level, threshold for p-value

    # scipy.stats.chisquare requires the sum of expected and actual to
    # match; this is only the case if we compute the expected frequency
    # at *all* nonzero values of the pmf. We don't know this a priori,
    # so we add extra values past the largest observed value. The number
    # below is empirically enough to get full coverage for the current set
    # of tests. If a new test is added where this is not enough, chisquare()
    # below will error due to the sums of the inputs not matching.
    extra_values = 100
    actual_freq = np.bincount(samples, minlength=samples.max() + extra_values)
    values = np.arange(len(actual_freq))

    expected_freq = pmf(values) * samples.size

    # Drop zero-probability bins (the pmf is zero there).
    valid = expected_freq > 0
    actual_freq = actual_freq[valid]
    expected_freq = expected_freq[valid]

    _, p_value = scipy.stats.chisquare(actual_freq, expected_freq)
    self.assertGreater(
        p_value, alpha,
        msg=f'Failed chi-squared test with p={p_value}.\n'
            'Expected vs. actual frequencies:\n'
            f'{expected_freq}\n{actual_freq}')

  def seed_prng(self, seed):
    # All statistical tests in this class use explicit threefry2x32 keys.
    return random.threefry2x32_key(seed)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
      for dtype in jtu.dtypes.floating))
  def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype):
    bits_dtype = np.uint32 if jnp.finfo(dtype).bits == 32 else np.uint64
    numpy_bits = np.array(1., dtype).view(bits_dtype)
    xla_bits = jax.jit(
        lambda: lax.bitcast_convert_type(np.array(1., dtype), bits_dtype))()
    self.assertEqual(numpy_bits, xla_bits)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
      for dtype in float_dtypes))
  def testRngUniform(self, dtype):
    key = self.seed_prng(0)
    rand = lambda key: random.uniform(key, (10000,), dtype)
    crand = jax.jit(rand)

    uncompiled_samples = rand(key)
    compiled_samples = crand(key)

    for samples in [uncompiled_samples, compiled_samples]:
      self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
      self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
      for dtype in int_dtypes + uint_dtypes))
  def testRngRandint(self, dtype):
    lo = 5
    hi = 10

    key = self.seed_prng(0)
    rand = lambda key: random.randint(key, (10000,), lo, hi, dtype)
    crand = jax.jit(rand)

    uncompiled_samples = rand(key)
    compiled_samples = crand(key)

    # randint's range is half-open: lo <= sample < hi.
    for samples in [uncompiled_samples, compiled_samples]:
      self.assertTrue(np.all(lo <= samples))
      self.assertTrue(np.all(samples < hi))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
      for dtype in float_dtypes))
  def testNormal(self, dtype):
    key = self.seed_prng(0)
    rand = lambda key: random.normal(key, (10000,), dtype)
    crand = jax.jit(rand)

    uncompiled_samples = rand(key)
    compiled_samples = crand(key)

    for samples in [uncompiled_samples, compiled_samples]:
      self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.norm().cdf)

  def testNormalBfloat16(self):
    # Passing bfloat16 as dtype string.
    # https://github.com/google/jax/issues/6813
    res_bfloat16_str = random.normal(self.seed_prng(0), dtype='bfloat16')
    res_bfloat16 = random.normal(self.seed_prng(0), dtype=jnp.bfloat16)
    self.assertAllClose(res_bfloat16, res_bfloat16_str)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
      for dtype in complex_dtypes))
  def testNormalComplex(self, dtype):
    key = self.seed_prng(0)
    rand = lambda key: random.normal(key, (10000,), dtype)
    crand = jax.jit(rand)

    uncompiled_samples = rand(key)
    compiled_samples = crand(key)

    # Real and imaginary parts are each checked against N(0, 1/2).
    for samples in [uncompiled_samples, compiled_samples]:
      self._CheckKolmogorovSmirnovCDF(jnp.real(samples), scipy.stats.norm(scale=1/np.sqrt(2)).cdf)
      self._CheckKolmogorovSmirnovCDF(jnp.imag(samples), scipy.stats.norm(scale=1/np.sqrt(2)).cdf)
      self.assertEqual(dtype, samples.dtype)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
      for dtype in float_dtypes))
  def testTruncatedNormal(self, dtype):
    key = self.seed_prng(0)
    rand = lambda key: random.truncated_normal(key, -0.3, 0.3, (10000,), dtype)
crand = jax.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) min_val = np.min(uncompiled_samples) max_val = np.max(uncompiled_samples) self.assertTrue(min_val > -0.3) self.assertTrue(max_val < 0.3) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.truncnorm(-0.3, 0.3).cdf) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype} for dtype in jtu.dtypes.floating + jtu.dtypes.integer)) def testShuffle(self, dtype): key = self.seed_prng(0) x = np.arange(100).astype(dtype) rand = lambda key: random.shuffle(key, x) crand = jax.jit(rand) with self.assertWarns(FutureWarning): perm1 = rand(key) with self.assertWarns(FutureWarning): perm2 = crand(key) self.assertAllClose(perm1, perm2) self.assertFalse(np.all(perm1 == x)) # seems unlikely! self.assertAllClose(np.sort(perm1), x, check_dtypes=False) @parameterized.named_parameters(jtu.cases_from_list( dict( testcase_name= f"_{np.dtype(dtype).name}_input_range_or_shape={input_range_or_shape}" f"_shape={shape}_replace={replace}_weighted={weighted}_axis={axis}", dtype=dtype, input_range_or_shape=input_range_or_shape, shape=shape, replace=replace, weighted=weighted, axis=axis) for dtype in jtu.dtypes.floating + jtu.dtypes.integer for shape in [(), (5,), (4, 5)] for replace in [True, False] for weighted in [True, False] for input_range_or_shape in [100, (10, 10), (10, 5, 2), 1, (1, 5)] for is_range in [type(input_range_or_shape) is int] for ndim in [1 if is_range else len(input_range_or_shape)] for axis in range(-ndim, ndim or 1) for ninputs in [input_range_or_shape if is_range else input_range_or_shape[axis]] if replace or np.prod(shape) <= ninputs)) def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis): # This is the function API that we test against (note that self.rng().choice differs) np_choice = np.random.default_rng(0).choice key = 
self.seed_prng(0) is_range = type(input_range_or_shape) is int x = (input_range_or_shape if is_range else self.rng().permutation(jnp.arange(np.prod( input_range_or_shape), dtype=dtype)).reshape(input_range_or_shape)) N = x if is_range else x.shape[axis] p = None if not weighted else (np.arange(N) + 1) / np.sum(np.arange(N) + 1) rand = lambda key, x: random.choice(key, x, shape, replace, p, axis) sample = rand(key, x) if not is_range: self.assertEqual(dtype, sample.dtype) np_shape = np.shape(np_choice(x, shape or None, replace, p, axis)) self.assertEqual(np_shape, sample.shape) if not replace and shape: def lsort(x): if not np.prod(x.shape): return x ind = np.lexsort(np.swapaxes(x, axis, -1).reshape((-1, x.shape[axis]))) return jnp.take(x, ind, axis) self.assertArraysEqual(lsort(sample), lsort(np.unique(sample, axis=axis))) self.assertArraysEqual(sample, rand(key, np.array(x))) self.assertArraysEqual(sample, jax.jit(rand, static_argnames= 'x' if is_range else None)(key, x)) @parameterized.named_parameters(jtu.cases_from_list( dict( testcase_name=f"_dtype={dtype}_range_or_shape={range_or_shape}" f"_axis={axis}_independent={independent}", dtype=dtype, range_or_shape=range_or_shape, axis=axis, independent=independent) for dtype in jtu.dtypes.floating + jtu.dtypes.integer for range_or_shape in [0, 1, 100, (0,), (1,), (100,), (10, 10), (10, 5, 2), (0, 5), (1, 5)] for ndim in [1 if type(range_or_shape) is int else len(range_or_shape)] for axis in range(-ndim, ndim or 1) for independent in [True, False])) def testPermutation(self, dtype, range_or_shape, axis, independent): key = self.seed_prng(0) is_range = type(range_or_shape) is int x = (range_or_shape if is_range else self.rng().permutation(jnp.arange( np.prod(range_or_shape), dtype=dtype)).reshape(range_or_shape)) shape = ((range_or_shape,) if is_range else range_or_shape) x_ = np.copy(x) rand = lambda key, x: random.permutation(key, x, axis, independent=independent) perm = rand(key, x) if shape[axis] >= 10: 
self.assertFalse(np.all(perm == x)) # seems unlikely! arr = np.arange(x) if is_range else x def lsort(x): if not np.prod(x.shape): return x ind = np.lexsort(np.swapaxes(x, axis, -1).reshape((-1, x.shape[axis]))) return jnp.take(x, ind, axis) if not independent: self.assertArraysEqual(lsort(arr), lsort(perm), check_dtypes=not is_range) if independent and (arr.shape[axis] > 4) and (arr.size // arr.shape[axis] > 4): # Check for independent shuffling if there are >4 vectors of size >4. # Chance of false positive is 1 in (5!)^4 with self.assertRaises(AssertionError): self.assertArraysEqual(lsort(arr), lsort(perm), check_dtypes=not is_range) self.assertArraysEqual(x_, x) self.assertArraysEqual(perm, rand(key, np.array(x))) self.assertArraysEqual(perm, jax.jit(rand, static_argnames= 'x' if is_range else None)(key, x)) def testPermutationErrors(self): key = self.seed_prng(0) with self.assertRaises(ValueError): random.permutation(key, 10, axis=3) with self.assertRaises(TypeError): random.permutation(key, 10.) 
with self.assertRaises(core.ConcretizationTypeError): jax.jit(random.permutation)(key, 10) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_p={}_dtype={}".format(p, np.dtype(dtype).name), "p": p, "dtype": dtype} for p in [0.1, 0.5, 0.9] for dtype in jtu.dtypes.floating)) def testBernoulli(self, p, dtype): key = self.seed_prng(0) p = np.array(p, dtype=dtype) rand = lambda key, p: random.bernoulli(key, p, (10000,)) crand = jax.jit(rand) uncompiled_samples = rand(key, p) compiled_samples = crand(key, p) for samples in [uncompiled_samples, compiled_samples]: self._CheckChiSquared(samples, scipy.stats.bernoulli(p).pmf) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_p={}_{}_{}".format(p, np.dtype(dtype).name, sample_shape), "p": p, "axis": axis, "dtype": dtype, 'sample_shape': sample_shape} for (p, axis) in [ ([.25] * 4, -1), ([.1, .2, .3, .4], -1), ([[.5, .5], [.1, .9]], 1), ([[.5, .1], [.5, .9]], 0), ] for sample_shape in [(10000,), (5000, 2)] for dtype in jtu.dtypes.floating)) def testCategorical(self, p, axis, dtype, sample_shape): key = self.seed_prng(0) p = np.array(p, dtype=dtype) logits = np.log(p) - 42 # test unnormalized out_shape = tuple(np.delete(logits.shape, axis)) shape = sample_shape + out_shape rand = partial(random.categorical, shape=shape, axis=axis) crand = jax.jit(rand) uncompiled_samples = rand(key, logits) compiled_samples = crand(key, logits) if axis < 0: axis += len(logits.shape) for samples in [uncompiled_samples, compiled_samples]: assert samples.shape == shape samples = jnp.reshape(samples, (10000,) + out_shape) if len(p.shape[:-1]) > 0: ps = np.transpose(p, (1, 0)) if axis == 0 else p for cat_samples, cat_p in zip(samples.transpose(), ps): pmf = lambda x: np.where(x < len(cat_p), cat_p[np.minimum(len(cat_p) - 1, x)], 0.0) self._CheckChiSquared(cat_samples, pmf=pmf) else: pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0) self._CheckChiSquared(samples, pmf=pmf) def 
testBernoulliShape(self): key = self.seed_prng(0) with jax.numpy_rank_promotion('allow'): x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_a={}_b={}_dtype={}".format(a, b, np.dtype(dtype).name), "a": a, "b": b, "dtype": dtype} for a in [0.2, 5.] for b in [0.2, 5.] for dtype in [np.float64])) # NOTE: KS test fails with float32 def testBeta(self, a, b, dtype): if not config.x64_enabled: raise SkipTest("skip test except on X64") key = self.seed_prng(0) rand = lambda key, a, b: random.beta(key, a, b, (10000,), dtype) crand = jax.jit(rand) uncompiled_samples = rand(key, a, b) compiled_samples = crand(key, a, b) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.beta(a, b).cdf) def testBetaSmallParameters(self, dtype=np.float32): # Regression test for beta version of https://github.com/google/jax/issues/9896 key = self.seed_prng(0) a, b = 0.0001, 0.0002 samples = random.beta(key, a, b, shape=(100,), dtype=dtype) # With such small parameters, all samples should be exactly zero or one. 
zeros = samples[samples < 0.5] self.assertAllClose(zeros, jnp.zeros_like(zeros)) ones = samples[samples >= 0.5] self.assertAllClose(ones, jnp.ones_like(ones)) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype} for dtype in float_dtypes)) def testCauchy(self, dtype): key = self.seed_prng(0) rand = lambda key: random.cauchy(key, (10000,), dtype) crand = jax.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.cauchy().cdf) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_alpha={}_dtype={}".format(alpha, np.dtype(dtype).name), "alpha": alpha, "dtype": dtype} for alpha in [ np.array([0.2, 1., 5.]), ] for dtype in jtu.dtypes.floating)) @jtu.skip_on_devices("tpu") # TODO(mattjj): slow compilation times def testDirichlet(self, alpha, dtype): key = self.seed_prng(0) rand = lambda key, alpha: random.dirichlet(key, alpha, (10000,), dtype) crand = jax.jit(rand) uncompiled_samples = rand(key, alpha) compiled_samples = crand(key, alpha) for samples in [uncompiled_samples, compiled_samples]: self.assertAllClose(samples.sum(-1), np.ones(10000, dtype=dtype)) alpha_sum = sum(alpha) for i, a in enumerate(alpha): self._CheckKolmogorovSmirnovCDF(samples[..., i], scipy.stats.beta(a, alpha_sum - a).cdf) def testDirichletSmallAlpha(self, dtype=np.float32): # Regression test for https://github.com/google/jax/issues/9896 key = self.seed_prng(0) alpha = 0.0001 * jnp.ones(3) samples = random.dirichlet(key, alpha, shape=(100,), dtype=dtype) # Check that results lie on the simplex. self.assertAllClose(samples.sum(1), jnp.ones(samples.shape[0]), check_dtypes=False, rtol=1E-5) # Check that results contain 1 in one of the dimensions: # this is highly likely to be true when alpha is small. 
self.assertAllClose(samples.max(1), jnp.ones(samples.shape[0]), check_dtypes=False, rtol=1E-5) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype} for dtype in float_dtypes)) def testExponential(self, dtype): key = self.seed_prng(0) rand = lambda key: random.exponential(key, (10000,), dtype) crand = jax.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.expon().cdf) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_a={}_dtype={}".format(a, np.dtype(dtype).name), "a": a, "dtype": dtype} for a in [0.1, 1., 10.] for dtype in jtu.dtypes.floating)) def testGammaVsLogGamma(self, a, dtype): key = self.seed_prng(0) rand_gamma = lambda key, a: random.gamma(key, a, (10000,), dtype) rand_loggamma = lambda key, a: random.loggamma(key, a, (10000,), dtype) crand_loggamma = jax.jit(rand_loggamma) self.assertAllClose(rand_gamma(key, a), jnp.exp(rand_loggamma(key, a))) self.assertAllClose(rand_gamma(key, a), jnp.exp(crand_loggamma(key, a))) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_a={}_dtype={}".format(a, np.dtype(dtype).name), "a": a, "dtype": dtype} for a in [0.1, 1., 10.] 
for dtype in jtu.dtypes.floating)) def testGamma(self, a, dtype): key = self.seed_prng(0) rand = lambda key, a: random.gamma(key, a, (10000,), dtype) crand = jax.jit(rand) uncompiled_samples = rand(key, a) compiled_samples = crand(key, a) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gamma(a).cdf) def testGammaShape(self): key = self.seed_prng(0) x = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_a={}_logspace={}".format(alpha, log_space), "alpha": alpha, "log_space": log_space} for log_space in [True, False] for alpha in [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4])) def testGammaGrad(self, log_space, alpha): rng = self.seed_prng(0) alphas = np.full((100,), alpha) z = random.gamma(rng, alphas) if log_space: actual_grad = jax.grad(lambda x: lax.exp(random.loggamma(rng, x)).sum())(alphas) # TODO(jakevdp): this NaN correction is required because we generate negative infinities # in the log-space computation; see related TODO in the source of random._gamma_one(). actual_grad = jnp.where(jnp.isnan(actual_grad), 0.0, actual_grad) else: actual_grad = jax.grad(lambda x: random.gamma(rng, x).sum())(alphas) eps = 0.01 * alpha / (1.0 + np.sqrt(alpha)) cdf_dot = (scipy.stats.gamma.cdf(z, alpha + eps) - scipy.stats.gamma.cdf(z, alpha - eps)) / (2 * eps) with np.errstate(over='ignore'): pdf = scipy.stats.gamma.pdf(z, alpha) expected_grad = -cdf_dot / pdf rtol = 2e-2 if jtu.device_under_test() == "tpu" else 7e-4 self.assertAllClose(actual_grad, expected_grad, check_dtypes=True, rtol=rtol) def testGammaGradType(self): # Regression test for https://github.com/google/jax/issues/2130 key = self.seed_prng(0) a = jnp.array(1., dtype=jnp.float32) b = jnp.array(3., dtype=jnp.float32) f = lambda x, y: random.gamma(key=key, a=x, dtype=jnp.float32) / y # Should not crash with a type error. 
jax.vjp(f, a, b) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lam={}_dtype={}".format(lam, np.dtype(dtype).name), "lam": lam, "dtype": np.dtype(dtype)} for lam in [0.5, 3, 9, 11, 50, 500] for dtype in [np.int16, np.int32, np.int64])) def testPoisson(self, lam, dtype): key = self.seed_prng(0) rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype) crand = jax.jit(rand) uncompiled_samples = rand(key, lam) compiled_samples = crand(key, lam) for samples in [uncompiled_samples, compiled_samples]: self._CheckChiSquared(samples, scipy.stats.poisson(lam).pmf) # TODO(shoyer): determine error bounds for moments more rigorously (e.g., # based on the central limit theorem). self.assertAllClose(samples.mean(), lam, rtol=0.01, check_dtypes=False) self.assertAllClose(samples.var(), lam, rtol=0.03, check_dtypes=False) def testPoissonBatched(self): key = self.seed_prng(1) lam = jnp.concatenate([2 * jnp.ones(10000), 20 * jnp.ones(10000)]) samples = random.poisson(key, lam, shape=(20000,)) self._CheckChiSquared(samples[:10000], scipy.stats.poisson(2.0).pmf) self._CheckChiSquared(samples[10000:], scipy.stats.poisson(20.0).pmf) def testPoissonWithoutShape(self): key = self.seed_prng(1) lam = 2 * jnp.ones(10000) samples = random.poisson(key, lam) self._CheckChiSquared(samples, scipy.stats.poisson(2.0).pmf) def testPoissonShape(self): key = self.seed_prng(0) x = random.poisson(key, np.array([2.0, 20.0]), shape=(3, 2)) assert x.shape == (3, 2) def testPoissonZeros(self): key = self.seed_prng(0) lam = jnp.concatenate([jnp.zeros(10), 20 * jnp.ones(10)]) samples = random.poisson(key, lam, shape=(2, 20)) self.assertArraysEqual(samples[:, :10], jnp.zeros_like(samples[:, :10])) def testPoissonCornerCases(self): key = self.seed_prng(0) lam = jnp.array([-1, 0, jnp.nan]) samples = random.poisson(key, lam, shape=(3,)) self.assertArraysEqual(samples, jnp.array([-1, 0, -1])) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": 
"_dtype={}".format(np.dtype(dtype).name), "dtype": dtype} for dtype in jtu.dtypes.floating)) def testGumbel(self, dtype): key = self.seed_prng(0) rand = lambda key: random.gumbel(key, (10000,), dtype) crand = jax.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gumbel_r().cdf) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype} for dtype in float_dtypes)) def testLaplace(self, dtype): key = self.seed_prng(0) rand = lambda key: random.laplace(key, (10000,), dtype) crand = jax.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.laplace().cdf) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype} for dtype in float_dtypes)) def testLogistic(self, dtype): key = self.seed_prng(0) rand = lambda key: random.logistic(key, (10000,), dtype) crand = jax.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_n={}_shape={}"\ .format(n, jtu.format_shape_dtype_string(shape, dtype)), "n": n, "shape": shape, "dtype": dtype} for n in range(1, 5) for shape in [(), (5,), (10, 5)] for dtype in jtu.dtypes.floating + jtu.dtypes.complex)) def testOrthogonal(self, n, shape, dtype): key = self.seed_prng(0) q = random.orthogonal(key, n, shape, dtype) self.assertEqual(q.shape, (*shape, n, n)) self.assertEqual(q.dtype, dtype) with jax.numpy_rank_promotion('allow'): self.assertAllClose( jnp.einsum('...ij,...jk->...ik', q, jnp.conj(q).swapaxes(-2, -1)), jnp.broadcast_to(jnp.eye(n, 
dtype=dtype), (*shape, n, n)) ) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_b={}_dtype={}".format(b, np.dtype(dtype).name), "b": b, "dtype": dtype} for b in [0.1, 1., 10.] for dtype in jtu.dtypes.floating)) def testPareto(self, b, dtype): key = self.seed_prng(0) rand = lambda key, b: random.pareto(key, b, (10000,), dtype) crand = jax.jit(rand) uncompiled_samples = rand(key, b) compiled_samples = crand(key, b) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.pareto(b).cdf) def testParetoShape(self): key = self.seed_prng(0) with jax.numpy_rank_promotion('allow'): x = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_df={}_dtype={}".format(df, np.dtype(dtype).name), "df": df, "dtype": dtype} for df in [0.1, 1., 10.] for dtype in jtu.dtypes.floating)) @jtu.skip_on_devices("cpu", "tpu") # TODO(phawkins): slow compilation times def testT(self, df, dtype): key = self.seed_prng(0) rand = lambda key, df: random.t(key, df, (10000,), dtype) crand = jax.jit(rand) uncompiled_samples = rand(key, df) compiled_samples = crand(key, df) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.t(df).cdf) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_dim={}_dtype={}_method={}".format( dim, np.dtype(dtype), method), "dim": dim, "dtype": dtype, "method": method} for dim in [1, 3, 5] for dtype in float_dtypes for method in ['svd', 'eigh', 'cholesky'])) @jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1 def testMultivariateNormal(self, dim, dtype, method): r = self.rng() mean = r.randn(dim) cov_factor = r.randn(dim, dim) cov = np.dot(cov_factor, cov_factor.T) + dim * np.eye(dim) key = self.seed_prng(0) rand = partial(random.multivariate_normal, mean=mean, cov=cov, shape=(10000,), method=method) crand = jax.jit(rand) 
with jax.numpy_rank_promotion('allow'): uncompiled_samples = np.asarray(rand(key), np.float64) compiled_samples = np.asarray(crand(key), np.float64) inv_scale = scipy.linalg.lapack.dtrtri(np.linalg.cholesky(cov), lower=True)[0] for samples in [uncompiled_samples, compiled_samples]: centered = samples - mean whitened = np.einsum('nj,ij->ni', centered, inv_scale) # This is a quick-and-dirty multivariate normality check that tests that a # uniform mixture of the marginals along the covariance matrix's # eigenvectors follow a standard normal distribution. self._CheckKolmogorovSmirnovCDF(whitened.ravel(), scipy.stats.norm().cdf) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_dim={}_mean_batch_size={}_cov_batch_size={}_shape={}_method={}"\ .format(dim, mean_batch_size, cov_batch_size, shape, method), "dim": dim, "mean_batch_size": mean_batch_size, "cov_batch_size": cov_batch_size, "shape": shape, "method": method} for dim in [1, 2, 4] for mean_batch_size in [(), (3,), (2, 3)] for cov_batch_size in [(), (3,), (2, 3)] for shape in [(), (1,), (5,)] for method in ['cholesky', 'svd', 'eigh'])) @jtu.skip_on_devices("rocm") # will be solved in rocm-5.1 def testMultivariateNormalShapes(self, dim, mean_batch_size, cov_batch_size, shape, method): r = self.rng() key = self.seed_prng(0) eff_batch_size = mean_batch_size \ if len(mean_batch_size) > len(cov_batch_size) else cov_batch_size mean = r.randn(*(mean_batch_size + (dim,))) cov_factor = r.randn(*(cov_batch_size + (dim, dim))) cov = np.einsum('...ij,...kj->...ik', cov_factor, cov_factor) cov += 1e-3 * np.eye(dim) shape = shape + eff_batch_size with jax.numpy_rank_promotion('allow'): samples = random.multivariate_normal(key, mean, cov, shape=shape, method=method) assert samples.shape == shape + (dim,) def testMultivariateNormalCovariance(self): # test code based on https://github.com/google/jax/issues/1869 N = 100000 cov = jnp.array([[ 0.19, 0.00, -0.13, 0.00], [ 0.00, 0.29, 0.00, -0.23], [ -0.13, 0.00, 
0.39, 0.00], [ 0.00, -0.23, 0.00, 0.49]]) mean = jnp.zeros(4) out_np = self.rng().multivariate_normal(mean, cov, N) key = self.seed_prng(0) with jax.numpy_rank_promotion('allow'): out_jnp = random.multivariate_normal(key, mean=mean, cov=cov, shape=(N,)) var_np = out_np.var(axis=0) var_jnp = out_jnp.var(axis=0) self.assertAllClose(var_np, var_jnp, rtol=1e-2, atol=1e-2, check_dtypes=False) var_np = np.cov(out_np, rowvar=False) var_jnp = np.cov(out_jnp, rowvar=False) self.assertAllClose(var_np, var_jnp, rtol=1e-2, atol=1e-2, check_dtypes=False) def testIssue222(self): x = random.randint(self.seed_prng(10003), (), 0, 0) assert x == 0 def testFoldIn(self): key = self.seed_prng(0) keys = [_prng_key_as_array(random.fold_in(key, i)) for i in range(10)] assert np.unique(keys, axis=0).shape[0] == 10 def testFoldInBig(self): key = self.seed_prng(0) seeds = [2 ** 32 - 2, 2 ** 32 - 1] keys = [_prng_key_as_array(random.fold_in(key, seed)) for seed in seeds] assert np.unique(keys, axis=0).shape[0] == 2 def testStaticShapeErrors(self): if config.jax_disable_jit: raise SkipTest("test only relevant when jit enabled") @jax.jit def feature_map(n, d, sigma=1.0, seed=123): key = self.seed_prng(seed) W = random.normal(key, (d, n)) / sigma w = random.normal(key, (d, )) / sigma b = 2 * jnp.pi * random.uniform(key, (d, )) phi = lambda x, t: jnp.sqrt(2.0 / d) * jnp.cos(jnp.matmul(W, x) + w*t + b) return phi self.assertRaisesRegex(TypeError, 'Shapes must be 1D.*', lambda: feature_map(5, 3)) def testIssue756(self): key = self.seed_prng(0) w = random.normal(key, ()) self.assertEqual(w.dtype, dtypes.canonicalize_dtype(jnp.float_)) def testIssue1789(self): def f(x): return random.gamma(self.seed_prng(0), x) grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2)) def testDtypeErrorMessage(self): with self.assertRaisesRegex(ValueError, r"dtype argument to.*"): random.normal(self.seed_prng(0), (), dtype=jnp.int32) def testRandomBroadcast(self): """Issue 4033""" # test for broadcast issue in 
https://github.com/google/jax/issues/4033 key = self.seed_prng(0) shape = (10, 2) with jax.numpy_rank_promotion('allow'): x1 = random.uniform(key, shape, minval=jnp.zeros(2), maxval=jnp.ones(2)) x2 = random.randint(key, shape, jnp.array([0, 1]), jnp.array([1, 2])) assert x1.shape == shape assert x2.shape == shape def testMaxwellSample(self): num_samples = 10**5 rng = self.seed_prng(0) rand = lambda x: random.maxwell(x, (num_samples, )) crand = jax.jit(rand) loc = jtu.to_default_dtype(scipy.stats.maxwell.mean()) std = jtu.to_default_dtype(scipy.stats.maxwell.std()) uncompiled_samples = rand(rng) compiled_samples = crand(rng) for samples in [uncompiled_samples, compiled_samples]: # Check first and second moments. self.assertEqual((num_samples,), samples.shape) self.assertAllClose(np.mean(samples), loc, atol=0., rtol=0.1) self.assertAllClose(np.std(samples), std, atol=0., rtol=0.1) self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.maxwell().cdf) @parameterized.named_parameters( ('test1', 4.0, 1.0), ('test2', 2.0, 3.0)) def testWeibullSample(self, concentration, scale): num_samples = 10**5 rng = self.seed_prng(0) rand = lambda x: random.weibull_min(x, scale, concentration, (num_samples,)) crand = jax.jit(rand) loc = jtu.to_default_dtype(scipy.stats.weibull_min.mean(c=concentration, scale=scale)) std = jtu.to_default_dtype(scipy.stats.weibull_min.std(c=concentration, scale=scale)) uncompiled_samples = rand(rng) compiled_samples = crand(rng) for samples in [uncompiled_samples, compiled_samples]: # Check first and second moments. 
self.assertEqual((num_samples,), samples.shape) self.assertAllClose(np.mean(samples), loc, atol=0., rtol=0.1) self.assertAllClose(np.std(samples), std, atol=0., rtol=0.1) self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.weibull_min( c=concentration, scale=scale).cdf) @parameterized.named_parameters( ('test1', 4.0, 1.0), ('test2', 2.0, 3.0)) def testDoublesidedMaxwellSample(self, loc, scale): num_samples = 10**5 rng = self.seed_prng(0) rand = lambda key: random.double_sided_maxwell( rng, loc, scale, (num_samples,)) crand = jax.jit(rand) mean = loc std = np.sqrt(3.) * scale uncompiled_samples = rand(rng) compiled_samples = crand(rng) # Compute the double sided maxwell CDF through the one sided maxwell cdf. # This is done as follows: # P(DSM <= x) = P (loc + scale * radamacher_sample * one_sided_sample <=x) = # P (radamacher_sample * one_sided_sample <= (x - loc) / scale) = # 1/2 P(one_sided_sample <= (x - loc) / scale) # + 1/2 P( - one_sided_sample <= (x - loc) / scale) = # 1/2 P(one_sided_sample <= (x - loc) / scale) # + 1/2 P(one_sided_sample >= - (x - loc) / scale) = # 1/2 CDF_one_maxwell((x - loc) / scale)) # + 1/2 (1 - CDF_one_maxwell(- (x - loc) / scale))) def double_sided_maxwell_cdf(x, loc, scale): pos = scipy.stats.maxwell().cdf((x - loc) / scale) neg = (1 - scipy.stats.maxwell().cdf((-x + loc) / scale)) return (pos + neg) / 2 for samples in [uncompiled_samples, compiled_samples]: # Check first and second moments. 
self.assertEqual((num_samples,), samples.shape) self.assertAllClose(samples.mean(), jtu.to_default_dtype(mean), atol=0., rtol=0.1) self.assertAllClose(samples.std(), jtu.to_default_dtype(std), atol=0., rtol=0.1) self._CheckKolmogorovSmirnovCDF( samples, lambda x: double_sided_maxwell_cdf(x, loc, scale)) def testRadamacher(self): rng = self.seed_prng(0) num_samples = 10**5 rand = lambda x: random.rademacher(x, (num_samples,)) crand = jax.jit(rand) uncompiled_samples = rand(rng) compiled_samples = crand(rng) for samples in [uncompiled_samples, compiled_samples]: unique_values, counts = np.unique(samples, return_counts=True) assert len(unique_values) == 2 assert len(counts) == 2 self.assertAllClose( counts[0] / num_samples, 0.5, rtol=1e-02, atol=1e-02) self.assertAllClose( counts[1] / num_samples, 0.5, rtol=1e-02, atol=1e-02) def testChoiceShapeIsNotSequenceError(self): key = self.seed_prng(0) with self.assertRaises(TypeError): random.choice(key, 5, 2, replace=False) with self.assertRaises(TypeError): random.choice(key, 5, 2, replace=True) def test_eval_shape_big_random_array(self): def f(x): return random.normal(self.seed_prng(x), (int(1e12),)) with jax.enable_checks(False): # check_jaxpr will materialize array jax.eval_shape(f, 0) # doesn't error @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": f"_seed={seed}_type={type}", "seed": seed, "type": type} for type in ["int", "np.array", "jnp.array"] for seed in [-1, 0, 1, (1 << 32) - 1, (1 << 63) - 1, np.uint64((1 << 64) - 1)])) def test_prng_jit_invariance(self, seed, type): if type == "int" and seed == (1 << 64) - 1: self.skipTest("Expected failure: Python int too large.") if not config.x64_enabled and seed > np.iinfo(np.int32).max: self.skipTest("Expected failure: Python int too large.") type = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type] args_maker = lambda: [type(seed)] make_prng = lambda seed: _prng_key_as_array(self.seed_prng(seed)) self._CompileAndCheck(make_prng, 
args_maker) def test_prng_errors(self): seed = np.iinfo(np.int64).max + 1 with self.assertRaises(OverflowError): self.seed_prng(seed) with self.assertRaises(OverflowError): jax.jit(self.seed_prng)(seed) def test_random_split_doesnt_device_put_during_tracing(self): key = _prng_key_as_array(self.seed_prng(1)).block_until_ready() with jtu.count_device_put() as count: jax.jit(random.split)(key) self.assertEqual(count[0], 1) # 1 for the argument device_put @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": f"_dtype={dtype}", "dtype": dtype} for dtype in int_dtypes + uint_dtypes)) def test_randint_bounds(self, dtype): min = np.iinfo(dtype).min max = np.iinfo(dtype).max key = self.seed_prng(1701) shape = (10,) if np.iinfo(dtype).bits < np.iinfo(dtypes.canonicalize_dtype(int)).bits: expected = random.randint(key, shape, min, max, dtype) self.assertArraysEqual(expected, random.randint(key, shape, min - 12345, max + 12345, dtype)) else: self.assertRaises(OverflowError, random.randint, key, shape, min - 12345, max + 12345, dtype) def test_randint_out_of_range(self): key = self.seed_prng(0) r = random.randint(key, (10,), 255, 256, np.uint8) self.assertAllClose(r, jnp.full_like(r, 255)) r = random.randint(key, (1000,), -128, 128, np.int8) self.assertGreater((r == -128).sum(), 0) self.assertGreater((r == 127).sum(), 0) r = random.randint(key, (1000,), -1000, 1000, np.uint8) self.assertGreater((r == 0).sum(), 0) self.assertGreater((r == 255).sum(), 0) threefry_seed = jax._src.prng.threefry_seed threefry_split = jax._src.prng.threefry_split threefry_random_bits = jax._src.prng.threefry_random_bits threefry_fold_in = jax._src.prng.threefry_fold_in def _double_threefry_seed(seed): int_t = seed.dtype.type if hasattr(seed, 'dtype') else type(seed) s1, s2 = seed ^ int_t(1), seed ^ int_t(3) return jnp.vstack([threefry_seed(s1), threefry_seed(s2)]) def _double_threefry_split(key, num): split0 = threefry_split(key[0], num) split1 = threefry_split(key[1], num) merge = 
jnp.vstack([jnp.expand_dims(split0.T, axis=0), jnp.expand_dims(split1.T, axis=0)]) return merge.transpose((2, 0, 1)) def _double_threefry_random_bits(key, bit_width, shape): bits0 = threefry_random_bits(key[0], bit_width, shape) bits1 = threefry_random_bits(key[1], bit_width, shape) return bits0 * bits1 def _double_threefry_fold_in(key, data): return jnp.vstack([threefry_fold_in(key[0], data), threefry_fold_in(key[1], data)]) double_threefry_prng_impl = prng.PRNGImpl( key_shape=(2, 2), seed=_double_threefry_seed, split=_double_threefry_split, random_bits=_double_threefry_random_bits, fold_in=_double_threefry_fold_in) @skipIf(not config.jax_enable_custom_prng, 'custom PRNG tests require config.jax_enable_custom_prng') class LaxRandomWithCustomPRNGTest(LaxRandomTest): def seed_prng(self, seed): return prng.seed_with_impl(double_threefry_prng_impl, seed) def test_split_shape(self): key = self.seed_prng(73) keys = random.split(key, 10) self.assertEqual(keys.shape, (10,)) def test_vmap_fold_in_shape(self): key = self.seed_prng(73) keys = vmap(lambda i: random.fold_in(key, i))(jnp.arange(3)) self.assertEqual(keys.shape, (3,)) def test_cannot_add(self): key = self.seed_prng(73) self.assertRaisesRegex( TypeError, r'unsupported operand type\(s\) for \+*', lambda: key + 47) @skipIf(np.__version__ == "1.21.0", "https://github.com/numpy/numpy/issues/19305") def test_grad_of_prng_key(self): key = self.seed_prng(73) jax.grad(lambda x: 1., allow_int=True)(key) # does not crash @skipIf(not config.jax_enable_custom_prng, 'custom PRNG tests require config.jax_enable_custom_prng') class LaxRandomWithRBGPRNGTest(LaxRandomTest): def seed_prng(self, seed): return random.rbg_key(seed) def test_split_shape(self): key = self.seed_prng(73) keys = random.split(key, 10) self.assertEqual(keys.shape, (10,)) def test_vmap_fold_in_shape(self): key = self.seed_prng(73) keys = vmap(lambda i: random.fold_in(key, i))(jnp.arange(3)) self.assertEqual(keys.shape, (3,)) def 
test_vmap_split_not_mapped_key(self): key = self.seed_prng(73) single_split_key = random.split(key) vmapped_keys = vmap(lambda _: random.split(key))(jnp.zeros(3,)) self.assertEqual(vmapped_keys.shape, (3, 2)) for vk in vmapped_keys: self.assertArraysEqual(vk.unsafe_raw_array(), single_split_key.unsafe_raw_array()) def test_vmap_split_mapped_key(self): key = self.seed_prng(73) mapped_keys = random.split(key, num=3) forloop_keys = [random.split(k) for k in mapped_keys] vmapped_keys = vmap(random.split)(mapped_keys) self.assertEqual(vmapped_keys.shape, (3, 2)) for fk, vk in zip(forloop_keys, vmapped_keys): self.assertArraysEqual(fk.unsafe_raw_array(), vk.unsafe_raw_array()) def test_vmap_random_bits(self): rand_fun = lambda key: random.randint(key, (), 0, 100) key = self.seed_prng(73) mapped_keys = random.split(key, num=3) forloop_rand_nums = [rand_fun(k) for k in mapped_keys] rand_nums = vmap(rand_fun)(mapped_keys) self.assertEqual(rand_nums.shape, (3,)) self.assertArraysEqual(rand_nums, jnp.array(forloop_rand_nums)) def test_cannot_add(self): key = self.seed_prng(73) self.assertRaisesRegex( TypeError, r'unsupported operand type\(s\) for \+*', lambda: key + 47) @skipIf(np.__version__ == "1.21.0", "https://github.com/numpy/numpy/issues/19305") def test_grad_of_prng_key(self): key = self.seed_prng(73) jax.grad(lambda x: 1., allow_int=True)(key) # does not crash def test_random_split_doesnt_device_put_during_tracing(self): return # this test doesn't apply to the RBG PRNG def test_randint_out_of_range(self): # TODO(mattjj): enable this test if/when RngBitGenerator supports it raise SkipTest('8-bit types not supported with RBG PRNG') class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest): def seed_prng(self, seed): return prng.seed_with_impl(prng.unsafe_rbg_prng_impl, seed) def like(keys): return jnp.ones(keys.shape) @skipIf(not config.jax_enable_custom_prng, 'custom PRNG tests require config.jax_enable_custom_prng') class JnpWithPRNGKeyArrayTest(jtu.JaxTestCase): 
def test_reshape(self): key = random.PRNGKey(123) keys = random.split(key, 4) out = jnp.reshape(keys, (2, 2)) ref = jnp.reshape(like(keys), (2, 2)) self.assertEqual(out.shape, ref.shape) self.assertEqual(out.shape, (2, 2)) def test_tile(self): key = random.PRNGKey(123) out = jnp.tile(key, 3) ref = jnp.tile(like(key), 3) self.assertEqual(out.shape, ref.shape) self.assertEqual(out.shape, (3,)) def test_concatenate(self): key = random.PRNGKey(123) keys = random.split(key, 2) out = jnp.concatenate([keys, keys, keys], axis=0) ref = jnp.concatenate([like(keys)] * 3, axis=0) self.assertEqual(out.shape, ref.shape) self.assertEqual(out.shape, (6,)) def test_broadcast_to(self): key = random.PRNGKey(123) out = jnp.broadcast_to(key, (3,)) ref = jnp.broadcast_to(like(key), (3,)) self.assertEqual(out.shape, ref.shape) self.assertEqual(out.shape, (3,)) out = jnp.broadcast_to(key, 3) self.assertEqual(out.shape, ref.shape) self.assertEqual(out.shape, (3,)) def test_expand_dims(self): key = random.PRNGKey(123) keys = random.split(key, 6) keys = jnp.reshape(keys, (2, 3)) out = jnp.expand_dims(keys, 1) ref = jnp.expand_dims(like(keys), 1) self.assertEqual(out.shape, ref.shape) self.assertEqual(out.shape, (2, 1, 3)) def test_broadcast_arrays(self): key = random.PRNGKey(123) keys = jax.random.split(key, 3) out, _ = jnp.broadcast_arrays(key, keys) ref, _ = jnp.broadcast_arrays(like(key), like(keys)) self.assertEqual(out.shape, ref.shape) self.assertEqual(out.shape, (3,)) def test_append(self): key = random.PRNGKey(123) out = jnp.append(key, key) ref = jnp.append(like(key), like(key)) self.assertEqual(out.shape, ref.shape) self.assertEqual(out.shape, (2,)) out_ = jnp.append(out, out) ref_ = jnp.append(like(out), like(out)) self.assertEqual(out_.shape, ref_.shape) self.assertEqual(out_.shape, (4,)) def test_ravel(self): key = random.PRNGKey(123) keys = jax.random.split(key, 4) keys = jnp.reshape(keys, (2, 2)) out = jnp.ravel(keys) ref = jnp.ravel(like(keys)) self.assertEqual(out.shape, 
ref.shape) self.assertEqual(out.shape, (4,)) def test_stack(self): key = random.PRNGKey(123) keys = jax.random.split(key, 2) out = jnp.stack([keys, keys, keys], axis=0) ref = jnp.stack([like(keys)] * 3, axis=0) self.assertEqual(out.shape, ref.shape) self.assertEqual(out.shape, (3, 2)) def _sampler_unimplemented_with_custom_prng(*args, **kwargs): raise SkipTest('sampler only implemented for default RNG') for test_prefix in [ 'testPoisson', 'testPoissonBatched', 'testPoissonShape', 'testPoissonZeros', ]: for attr in dir(LaxRandomTest): if attr.startswith(test_prefix): setattr(LaxRandomWithCustomPRNGTest, attr, _sampler_unimplemented_with_custom_prng) setattr(LaxRandomWithRBGPRNGTest, attr, _sampler_unimplemented_with_custom_prng) setattr(LaxRandomWithUnsafeRBGPRNGTest, attr, _sampler_unimplemented_with_custom_prng) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())