mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00

Before this change, JAX could dispatch compiled functions over new-style (typed) RNG key arrays, but it would always do so off of the fast (C++-based) dispatch path. In other words, switching from old-style `uint32` RNG keys to new-style keys would regress dispatch times. With this change, dispatch happens on the fast path again and performance regressions ought to be minimal. We currently maintain only one pytree registry, for all registered pytree node types. We want RNG key arrays to also be treated as pytree leaves everywhere *except* during dispatch. In other words: we want operations on (typed) RNG key arrays to appear in Jaxpr, but we want to unravel those arrays into their underlying `uint32` arrays only during dispatch. To do this, we add a new internal pytree registry that dispatch respects uniquely. This registry includes all items in the default registry, but also the RNG key array type. Co-authored-by: Matthew Johnson <mattjj@google.com> PiperOrigin-RevId: 565077758
2637 lines · 100 KiB · Python
# Copyright 2018 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import copy
|
|
import enum
|
|
from functools import partial
|
|
import math
|
|
from unittest import SkipTest, skipIf
|
|
from typing import Any, NamedTuple, Optional
|
|
import zlib
|
|
|
|
from absl.testing import absltest
|
|
from absl.testing import parameterized
|
|
|
|
import numpy as np
|
|
import scipy.linalg
|
|
import scipy.special
|
|
import scipy.stats
|
|
|
|
import jax
|
|
from jax import grad
|
|
from jax import lax
|
|
from jax import numpy as jnp
|
|
from jax import random
|
|
from jax import tree_util
|
|
from jax._src import core
|
|
from jax._src import dtypes
|
|
from jax._src import test_util as jtu
|
|
from jax import vmap
|
|
from jax.interpreters import xla
|
|
|
|
from jax._src import random as jax_random
|
|
from jax._src import prng as prng_internal
|
|
|
|
from jax import config
|
|
config.parse_flags_with_absl()
|
|
|
|
# Convenience dtype groups used by the sampler tests below.
float_dtypes = jtu.dtypes.all_floating
complex_dtypes = jtu.dtypes.complex
int_dtypes = jtu.dtypes.all_integer
uint_dtypes = jtu.dtypes.all_unsigned
|
|
|
|
|
|
def _prng_key_as_array(key):
|
|
# TODO(frostig): remove some day when we deprecate "raw" key arrays
|
|
if jnp.issubdtype(key.dtype, dtypes.prng_key):
|
|
return key.unsafe_raw_array()
|
|
else:
|
|
return key
|
|
|
|
def _maybe_unwrap(key):
|
|
# TODO(frostig): remove some day when we deprecate "raw" key arrays
|
|
unwrap = prng_internal.random_unwrap
|
|
return unwrap(key) if jnp.issubdtype(key, dtypes.prng_key) else key
|
|
|
|
|
|
# (name, implementation) pairs for every PRNG implementation under test.
PRNG_IMPLS = [('threefry2x32', prng_internal.threefry_prng_impl),
              ('rbg', prng_internal.rbg_prng_impl),
              ('unsafe_rbg', prng_internal.unsafe_rbg_prng_impl)]
|
|
|
|
|
|
class OnX64(enum.Enum):
  """How a golden-value case interacts with the jax_enable_x64 mode."""
  ALSO = enum.auto()  # valid with and without x64
  SKIP = enum.auto()  # skip when x64 is enabled
  ONLY = enum.auto()  # valid only when x64 is enabled
|
|
|
|
class RandomValuesCase(NamedTuple):
  """A golden-value test case for one sampler under one PRNG impl."""
  name: str                 # name of the sampler in jax.random
  prng_impl: str            # PRNG implementation name, e.g. 'threefry2x32'
  shape: tuple[int, ...]    # requested sample shape
  dtype: Any                # requested dtype, or None for samplers without one
  params: dict              # extra keyword arguments for the sampler
  expected: np.ndarray      # expected sample values
  on_x64: OnX64 = OnX64.ALSO
  atol: Optional[float] = None
  rtol: Optional[float] = None

  def _testname(self):
    """Build a unique parameterized-test suffix for this case."""
    if self.dtype is None:
      shape_dtype = str(self.shape)
    else:
      shape_dtype = jtu.format_shape_dtype_string(self.shape, self.dtype)
    pieces = [f"_{self.name}_{self.prng_impl}_{shape_dtype}"]
    if self.params:
      def fmt(v):
        return str(v).replace(' ', '').replace('\n', '')
      pieces.extend(f"{k}={fmt(v)}" for k, v in self.params.items())
    return "_".join(pieces)

  def _seed(self):
    # Generate a deterministic unique 32-bit seed given the name and prng impl
    return zlib.adler32((self.name + self.prng_impl).encode())
|
|
|
|
|
|
# Golden values: one entry per (sampler, PRNG impl) pair, consumed by
# testRandomDistributionValues below. Do not edit values casually — see
# that test's docstring for the procedure when numerics must change.
_RANDOM_VALUES_CASES = [
  # TODO(jakevdp) add coverage for other distributions.
  RandomValuesCase("bernoulli", "threefry2x32", (5,), None, {'p': 0.5},
    np.array([False, True, True, True, False]), on_x64=OnX64.SKIP),
  RandomValuesCase("bernoulli", "rbg", (5,), None, {'p': 0.5},
    np.array([True, True, True, True, True]), on_x64=OnX64.SKIP),
  RandomValuesCase("beta", "threefry2x32", (5,), np.float32, {'a': 0.8, 'b': 0.9},
    np.array([0.13259 , 0.824893, 0.948363, 0.964155, 0.235448], dtype='float32')),
  RandomValuesCase("beta", "rbg", (5,), np.float32, {'a': 0.8, 'b': 0.9},
    np.array([0.93215 , 0.833959, 0.121902, 0.270003, 0.429541], dtype='float32')),
  # TODO(frostig,jakevdp) add coverage for non-threefry bits
  RandomValuesCase("bits", "threefry2x32", (5,), np.uint8, {},
    np.array([10, 158, 82, 54, 158], dtype='uint8')),
  RandomValuesCase("bits", "threefry2x32", (5,), np.uint16, {},
    np.array([6738, 38161, 50695, 57337, 61600], dtype='uint16')),
  RandomValuesCase("bits", "threefry2x32", (5,), np.uint32, {},
    np.array([1978747883, 4134381225, 3628107870, 689687174, 2788938207], dtype='uint32')),
  RandomValuesCase("bits", "threefry2x32", (5,), np.uint64, {},
    np.array([17649965731882839947, 1415307058040849897, 8282622628079774249,
              14024425113645909402, 2012979996110532418], dtype='uint64'),
    on_x64=OnX64.ONLY),
  RandomValuesCase("cauchy", "threefry2x32", (5,), np.float32, {},
    np.array([ -0.088416, -10.169713, 3.49677, -1.18056, 0.34556], dtype='float32'), rtol=1E-5),
  RandomValuesCase("cauchy", "rbg", (5,), np.float32, {},
    np.array([0.008389, 0.108793, -0.031826, -0.01876, 0.963218], dtype='float32')),
  RandomValuesCase("dirichlet", "threefry2x32", (2,), np.float32, {'alpha': np.array([0.5, 0.6, 0.7], dtype='float32')},
    np.array([[0.003128, 0.009694, 0.987178], [0.025938, 0.479091, 0.494971]], dtype='float32')),
  RandomValuesCase("dirichlet", "rbg", (2,), np.float32, {'alpha': np.array([0.5, 0.6, 0.7], dtype='float32')},
    np.array([[0.080742, 0.525493, 0.393765], [0.006837, 0.804796, 0.188366]], dtype='float32')),
  RandomValuesCase("double_sided_maxwell", "threefry2x32", (5,), np.float32, {"loc": 1, "scale": 2},
    np.array([-2.408914, -3.370437, 3.235352, -0.907734, -1.708732], dtype='float32'), on_x64=OnX64.SKIP),
  RandomValuesCase("double_sided_maxwell", "rbg", (5,), np.float32, {"loc": 1, "scale": 2},
    np.array([4.957495, 3.003086, 5.33935, 2.942878, -1.203524], dtype='float32'), on_x64=OnX64.SKIP),
  RandomValuesCase("exponential", "threefry2x32", (5,), np.float32, {},
    np.array([0.526067, 0.043046, 0.039932, 0.46427 , 0.123886], dtype='float32')),
  RandomValuesCase("exponential", "rbg", (5,), np.float32, {},
    np.array([0.231303, 0.684814, 0.017181, 0.089552, 0.345087], dtype='float32')),
  RandomValuesCase("gamma", "threefry2x32", (5,), np.float32, {'a': 0.8},
    np.array([0.824221, 1.724476, 0.502882, 5.386132, 0.685543], dtype='float32')),
  RandomValuesCase("gamma", "rbg", (5,), np.float32, {'a': 0.8},
    np.array([0.994946, 0.519941, 1.754347, 0.479223, 1.16932 ], dtype='float32')),
  RandomValuesCase("gumbel", "threefry2x32", (5,), np.float32, {},
    np.array([2.06701, 0.911726, 0.145736, 0.185427, -0.00711], dtype='float32')),
  RandomValuesCase("gumbel", "rbg", (5,), np.float32, {},
    np.array([-0.099308, -1.123809, 1.007618, -0.077968, 3.421349], dtype='float32')),
  RandomValuesCase("laplace", "threefry2x32", (5,), np.float32, {},
    np.array([0.578939, -0.204902, 0.555733, 0.911053, -0.96456], dtype='float32')),
  RandomValuesCase("laplace", "rbg", (5,), np.float32, {},
    np.array([-2.970422, 1.925082, -0.757887, -4.444797, 0.561983], dtype='float32')),
  RandomValuesCase("loggamma", "threefry2x32", (5,), np.float32, {'a': 0.8},
    np.array([ 0.240559, -3.575443, -0.450946, -2.161372, -2.943277], dtype='float32')),
  RandomValuesCase("loggamma", "rbg", (5,), np.float32, {'a': 0.8},
    np.array([-0.107021, -0.809968, -0.25546 , -1.212273, -1.946579], dtype='float32')),
  RandomValuesCase("logistic", "threefry2x32", (5,), np.float32, {},
    np.array([0.19611, -1.709053, -0.274093, -0.208322, -1.675489], dtype='float32')),
  RandomValuesCase("logistic", "rbg", (5,), np.float32, {},
    np.array([-0.234923, -0.545184, 0.700992, -0.708609, -1.474884], dtype='float32')),
  RandomValuesCase("maxwell", "threefry2x32", (5,), np.float32, {},
    np.array([3.070779, 0.908479, 1.521317, 0.875551, 1.306137], dtype='float32')),
  RandomValuesCase("maxwell", "rbg", (5,), np.float32, {},
    np.array([2.048746, 0.470027, 1.053105, 1.01969, 2.710645], dtype='float32')),
  RandomValuesCase("multivariate_normal", "threefry2x32", (2,), np.float32, {"mean": np.ones((1, 3)), "cov": np.eye(3)},
    np.array([[ 1.067826, 1.215599, 0.234166], [-0.237534, 1.32591, 1.413987]], dtype='float32'), on_x64=OnX64.SKIP),
  RandomValuesCase("multivariate_normal", "rbg", (2,), np.float32, {"mean": np.ones((1, 3)), "cov": np.eye(3)},
    np.array([[-0.036897, 0.770969, 0.756959], [1.755091, 2.350553, 0.627142]], dtype='float32'), on_x64=OnX64.SKIP),
  RandomValuesCase("normal", "threefry2x32", (5,), np.float32, {},
    np.array([-1.173234, -1.511662, 0.070593, -0.099764, 1.052845], dtype='float32')),
  RandomValuesCase("normal", "rbg", (5,), np.float32, {},
    np.array([-0.479658, 0.565747, -1.065106, 0.997962, -1.478002], dtype='float32')),
  RandomValuesCase("pareto", "threefry2x32", (5,), np.float32, {"b": 0.5},
    np.array([2.751398, 1.281863, 87.85448, 1.254542, 2.824487], dtype='float32')),
  RandomValuesCase("pareto", "rbg", (5,), np.float32, {"b": 0.5},
    np.array([1.241914, 1.521864, 5.615384, 1911.502, 1.816702], dtype='float32')),
  RandomValuesCase("poisson", "threefry2x32", (5,), np.int32, {"lam": 5},
    np.array([7, 3, 6, 11, 6], dtype='int32')),
  # Note: poisson not implemented for rbg sampler.
  RandomValuesCase("rademacher", "threefry2x32", (5,), np.int32, {},
    np.array([-1, -1, -1, -1, 1], dtype='int32'), on_x64=OnX64.SKIP),
  RandomValuesCase("rademacher", "rbg", (5,), np.int32, {},
    np.array([1, 1, 1, -1, -1], dtype='int32'), on_x64=OnX64.SKIP),
  RandomValuesCase("randint", "threefry2x32", (5,), np.int32, {"minval": 0, "maxval": 10},
    np.array([0, 5, 7, 7, 5], dtype='int32')),
  RandomValuesCase("randint", "rbg", (5,), np.int32, {"minval": 0, "maxval": 10},
    np.array([7, 1, 8, 5, 8], dtype='int32')),
  RandomValuesCase("truncated_normal", "threefry2x32", (5,), np.float32, {"lower": 0, "upper": 2},
    np.array([0.582807, 1.709771, 0.159513, 0.861376, 0.36148], dtype='float32')),
  RandomValuesCase("truncated_normal", "rbg", (5,), np.float32, {"lower": 0, "upper": 2},
    np.array([0.770068, 1.516464, 0.710406, 0.762801, 1.305324], dtype='float32')),
  RandomValuesCase("uniform", "threefry2x32", (5,), np.float32, {},
    np.array([0.298671, 0.073213, 0.873356, 0.260549, 0.412797], dtype='float32')),
  RandomValuesCase("uniform", "rbg", (5,), np.float32, {},
    np.array([0.477161, 0.706508, 0.656261, 0.432547, 0.057772], dtype='float32')),
  RandomValuesCase("weibull_min", "threefry2x32", (5,), np.float32, {"scale": 1, "concentration": 1},
    np.array([1.605863, 0.841809, 0.224218, 0.4826 , 0.027901], dtype='float32')),
  RandomValuesCase("weibull_min", "rbg", (5,), np.float32, {"scale": 1, "concentration": 1},
    np.array([1.370903, 0.086532, 0.061688, 3.407599, 0.215077], dtype='float32')),
]
|
|
|
|
|
|
# Most tests run against both the typed-key and the legacy key constructor.
KEY_CTORS = [random.key, random.PRNGKey]
|
|
|
|
@jtu.with_config(jax_legacy_prng_key='allow')
class PrngTest(jtu.JaxTestCase):
  """Tests of PRNG key construction and the core threefry hash."""

  def check_key_has_impl(self, key, impl):
    # Typed key arrays carry their implementation directly; legacy keys
    # are plain uint32 arrays whose shape must match the impl's key shape.
    if jnp.issubdtype(key.dtype, dtypes.prng_key):
      self.assertIs(key.impl, impl)
    else:
      self.assertEqual(key.dtype, jnp.dtype('uint32'))
      self.assertEqual(key.shape, impl.key_shape)
|
|
|
|
def testThreefry2x32(self):
|
|
# We test the hash by comparing to known values provided in the test code of
|
|
# the original reference implementation of Threefry. For the values, see
|
|
# https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32
|
|
def result_to_hex(result):
|
|
return tuple(hex(x.copy()).rstrip("L") for x in result)
|
|
|
|
expected = ("0x6b200159", "0x99ba4efe")
|
|
result = prng_internal.threefry_2x32(np.uint32([0, 0]), np.uint32([0, 0]))
|
|
|
|
self.assertEqual(expected, result_to_hex(result))
|
|
|
|
expected = ("0x1cb996fc", "0xbb002be7")
|
|
u32_max = np.iinfo(np.uint32).max
|
|
result = prng_internal.threefry_2x32(np.uint32([u32_max, u32_max]), np.uint32([u32_max, u32_max]))
|
|
self.assertEqual(expected, result_to_hex(result))
|
|
|
|
expected = ("0xc4923a9c", "0x483df7a0")
|
|
result = prng_internal.threefry_2x32(
|
|
np.uint32([0x13198a2e, 0x03707344]),
|
|
np.uint32([0x243f6a88, 0x85a308d3]))
|
|
self.assertEqual(expected, result_to_hex(result))
|
|
|
|
def testThreefry2x32Large(self):
|
|
n = 10000000
|
|
result = prng_internal.threefry_2x32(
|
|
(np.uint32(0x13198a2e), np.uint32(0x03707344)),
|
|
jnp.concatenate([
|
|
jnp.full((n,), 0x243f6a88, jnp.uint32),
|
|
jnp.full((n,), 0x85a308d3, jnp.uint32)
|
|
]))
|
|
np.testing.assert_equal(result[:n], np.full((n,), 0xc4923a9c, dtype=np.uint32))
|
|
np.testing.assert_equal(result[n:], np.full((n,), 0x483df7a0, dtype=np.uint32))
|
|
|
|
def testThreefry2x32Empty(self):
|
|
# Regression test for an op-by-op crash for empty arrays in CUDA mode.
|
|
with jax.disable_jit():
|
|
result = prng_internal.threefry_2x32(
|
|
(np.uint32(0x13198a2e), np.uint32(0x03707344)),
|
|
jnp.ones((10, 0,), jnp.uint32))
|
|
np.testing.assert_equal(result, np.zeros((10, 0,), dtype=np.uint32))
|
|
|
|
def testNoOpByOpUnderHash(self):
|
|
def fail(*args, **kwargs): assert False
|
|
apply_primitive, xla.apply_primitive = xla.apply_primitive, fail
|
|
try:
|
|
_ = prng_internal.threefry_2x32(np.zeros(2, np.uint32), np.arange(10, dtype=np.uint32))
|
|
finally:
|
|
xla.apply_primitive = apply_primitive
|
|
|
|
  @skipIf(config.jax_threefry_partitionable, 'changed random bit values')
  @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
  def testRngRandomBits(self, make_key):
    # Test specific outputs to ensure consistent random values between JAX versions.

    def random_bits(key, width, shape):
      # TODO(frostig): Use random.bits, as in:
      #
      # def random_bits(key, width, shape):
      #   dtype = jnp.dtype(f'uint{width}')
      #   return jax.random.bits(key, shape, dtype)
      #
      # Doing so doesn't work in width 64 at present due to
      # normalization in random.bits.
      key, _ = jax_random._check_prng_key(key)
      return jax_random._random_bits(key, width, shape)

    key = make_key(1701)

    bits8 = random_bits(key, 8, (3,))
    expected8 = np.array([216, 115, 43], dtype=np.uint8)
    self.assertArraysEqual(bits8, expected8)

    bits16 = random_bits(key, 16, (3,))
    expected16 = np.array([41682, 1300, 55017], dtype=np.uint16)
    self.assertArraysEqual(bits16, expected16)

    bits32 = random_bits(key, 32, (3,))
    expected32 = np.array([56197195, 4200222568, 961309823], dtype=np.uint32)
    self.assertArraysEqual(bits32, expected32)

    # A 64-bit request warns when x64 is disabled, hence the filter here.
    with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
      bits64 = random_bits(key, 64, (3,))
    # Without x64 the expected values are uint32 (see the dtype below).
    if config.x64_enabled:
      expected64 = np.array([3982329540505020460, 16822122385914693683,
                             7882654074788531506], dtype=np.uint64)
    else:
      expected64 = np.array([676898860, 3164047411, 4010691890], dtype=np.uint32)
    self.assertArraysEqual(bits64, expected64)
|
|
|
|
@jtu.sample_product(prng_name=[name for name, _ in PRNG_IMPLS],
|
|
make_key=KEY_CTORS)
|
|
def testRngRandomBitsShapeDtype(self, prng_name, make_key):
|
|
# Like testRngRandomBits, but only meant to exercise random_bits
|
|
# on every PRNG implementation. Instead of values, only checks
|
|
# that shapes/dtypes are as expected.
|
|
|
|
def random_bits(key, width, shape):
|
|
dtype = jnp.dtype(f'uint{width}')
|
|
return jax.random.bits(key, shape, dtype)
|
|
|
|
with jax.default_prng_impl(prng_name):
|
|
key = make_key(1701)
|
|
|
|
bits8 = random_bits(key, 8, (3,))
|
|
self.assertEqual(bits8.shape, (3,))
|
|
self.assertEqual(bits8.dtype, np.dtype('uint8'))
|
|
|
|
bits16 = random_bits(key, 16, (3,))
|
|
self.assertEqual(bits16.shape, (3,))
|
|
self.assertEqual(bits16.dtype, np.dtype('uint16'))
|
|
|
|
bits32 = random_bits(key, 32, (3,))
|
|
self.assertEqual(bits32.shape, (3,))
|
|
self.assertEqual(bits32.dtype, np.dtype('uint32'))
|
|
|
|
with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
|
|
bits64 = random_bits(key, 64, (3,))
|
|
expected_dtype = np.dtype('uint64' if config.x64_enabled else 'uint32')
|
|
self.assertEqual(bits64.shape, (3,))
|
|
self.assertEqual(bits64.dtype, expected_dtype)
|
|
|
|
@skipIf(config.jax_threefry_partitionable, 'changed random bit values')
|
|
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
|
def testRngRandomBitsViewProperty(self, make_key):
|
|
# TODO: add 64-bit if it ever supports this property.
|
|
# TODO: will this property hold across endian-ness?
|
|
|
|
def random_bits(key, width, shape):
|
|
dtype = jnp.dtype(f'uint{width}')
|
|
return jax.random.bits(key, shape, dtype)
|
|
|
|
N = 10
|
|
key = make_key(1701)
|
|
nbits = [8, 16, 32]
|
|
rand_bits = [random_bits(key, n, (N * 64 // n,)) for n in nbits]
|
|
rand_bits_32 = np.array([np.array(r).view(np.uint32) for r in rand_bits])
|
|
assert np.all(rand_bits_32 == rand_bits_32[0])
|
|
|
|
|
|
@jtu.sample_product(case=_RANDOM_VALUES_CASES, make_key=KEY_CTORS)
|
|
@skipIf(config.jax_threefry_partitionable, 'changed random bit values')
|
|
@jtu.skip_on_devices("tpu") # TPU precision causes issues.
|
|
def testRandomDistributionValues(self, case, make_key):
|
|
"""
|
|
Tests values output by various distributions. This will catch any
|
|
unintentional changes to the implementations that could result in
|
|
different random sequences.
|
|
|
|
Any refactoring of random distributions that leads to non-trivial
|
|
differences in this test should follow the procedure outlined at
|
|
https://jax.readthedocs.io/en/latest/api_compatibility.html#numerics-and-randomness
|
|
|
|
This includes:
|
|
* Announcing the change in the CHANGELOG.md
|
|
* Considering adding a flag that reverts the new behavior, made
|
|
available for a deprecation window's amount of time.
|
|
"""
|
|
if config.x64_enabled and case.on_x64 == OnX64.SKIP:
|
|
self.skipTest("test produces different values when jax_enable_x64=True")
|
|
if not config.x64_enabled and case.on_x64 == OnX64.ONLY:
|
|
self.skipTest("test only valid when jax_enable_x64=True")
|
|
with jax.default_prng_impl(case.prng_impl):
|
|
func = getattr(random, case.name)
|
|
key = make_key(case._seed())
|
|
if case.dtype:
|
|
actual = func(key, **case.params, shape=case.shape, dtype=case.dtype)
|
|
else:
|
|
actual = func(key, **case.params, shape=case.shape)
|
|
self.assertAllClose(actual, case.expected, atol=case.atol, rtol=case.rtol)
|
|
|
|
  @skipIf(config.jax_threefry_partitionable, 'changed random bit values')
  @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
  def testPRNGValues(self, make_key):
    # Test to ensure consistent random values between JAX versions
    k = make_key(0)

    self.assertEqual(random.randint(k, (3, 3), 0, 8).dtype,
                     dtypes.canonicalize_dtype(jnp.int_))
    # int64 draws are only available under x64.
    if config.x64_enabled:
      self.assertAllClose(
          random.randint(k, (3, 3), 0, 8, dtype='int64'),
          np.array([[7, 2, 6],
                    [2, 1, 0],
                    [6, 7, 7]], dtype='int64'))
    self.assertAllClose(
        random.randint(k, (3, 3), 0, 8, dtype='int32'),
        np.array([[2, 1, 3],
                  [6, 1, 5],
                  [6, 3, 4]], dtype='int32'))

    # split and fold_in are checked on the raw uint32 key data.
    self.assertAllClose(
        _prng_key_as_array(random.split(k, 4)),
        np.array([[2285895361, 1501764800],
                  [1518642379, 4090693311],
                  [ 433833334, 4221794875],
                  [ 839183663, 3740430601]], dtype='uint32'))

    self.assertAllClose(
        _prng_key_as_array(random.fold_in(k, 4)),
        np.array([2285895361, 433833334], dtype='uint32'))
|
|
|
|
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
|
def test_random_bits_error(self, make_key):
|
|
msg = 'dtype argument .* must be an unsigned int dtype'
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
random.bits(make_key(0), (3, 4), np.dtype('int8'))
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
random.bits(make_key(0), (3, 4), np.dtype('float16'))
|
|
|
|
@skipIf(not config.jax_threefry_partitionable, 'enable after upgrade')
|
|
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
|
def test_threefry_split_fold_in_symmetry(self, make_key):
|
|
with jax.default_prng_impl('threefry2x32'):
|
|
key = make_key(72)
|
|
f1, f2, f3 = (random.fold_in(key, i) for i in range(3))
|
|
s1, s2, s3 = random.split(key, 3)
|
|
f1, f2, f3 = map(_prng_key_as_array, [f1, f2, f3])
|
|
s1, s2, s3 = map(_prng_key_as_array, [s1, s2, s3])
|
|
self.assertArraysEqual(f1, s1)
|
|
self.assertArraysEqual(f2, s2)
|
|
self.assertArraysEqual(f3, s3)
|
|
|
|
@skipIf(not config.jax_threefry_partitionable, 'enable after upgrade')
|
|
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
|
def test_threefry_split_vmapped_fold_in_symmetry(self, make_key):
|
|
# See https://github.com/google/jax/issues/7708
|
|
with jax.default_prng_impl('threefry2x32'):
|
|
key = make_key(72)
|
|
f1, f2, f3 = vmap(lambda k, _: random.fold_in(k, lax.axis_index('batch')),
|
|
in_axes=(None, 0), axis_name='batch')(key, jnp.ones(3))
|
|
s1, s2, s3 = random.split(key, 3)
|
|
f1, f2, f3 = map(_prng_key_as_array, [f1, f2, f3])
|
|
s1, s2, s3 = map(_prng_key_as_array, [s1, s2, s3])
|
|
self.assertArraysEqual(f1, s1)
|
|
self.assertArraysEqual(f2, s2)
|
|
self.assertArraysEqual(f3, s3)
|
|
|
|
  # Each entry maps a (seed, seed type, jit-or-not) combination to the raw
  # uint32 key data expected from the constructor; several expectations
  # depend on whether x64 is enabled.
  @parameterized.parameters([params
      for d in [
          {"seed": 0, "typ": int, "jit": True, "key": [0, 0]},
          {"seed": 0, "typ": int, "jit": False, "key": [0, 0]},
          {"seed": 1, "typ": np.int32, "jit": True, "key": [0, 1]},
          {"seed": 1, "typ": np.int32, "jit": False, "key": [0, 1]},
          {"seed": 2, "typ": np.uint32, "jit": True, "key": [0, 2]},
          {"seed": 2, "typ": np.uint32, "jit": False, "key": [0, 2]},
          {"seed": 3, "typ": np.int64, "jit": True, "key": [0, 3]},
          {"seed": 3, "typ": np.int64, "jit": False, "key": [0, 3]},
          {"seed": -1, "typ": int, "jit": True, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
          {"seed": -1, "typ": int, "jit": False, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
          {"seed": -2, "typ": np.int32, "jit": True, "key": [0, 4294967294]},
          {"seed": -2, "typ": np.int32, "jit": False, "key": [0, 4294967294]},
          {"seed": -3, "typ": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
          {"seed": -3, "typ": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
          {"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": True, "key": [0, 2147483747]},
          {"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": False, "key": [0, 2147483747]},
          {"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": True, "key": [0, 2147483748]},
          {"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": False, "key": [0, 2147483748]},
          {"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": True, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
          {"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": False, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
          {"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
          {"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
      ]
      for params in [dict(**d, make_key=ctor) for ctor in KEY_CTORS]
  ])
  def test_prng_seeds_and_keys(self, seed, typ, jit, key, make_key):
    seed = typ(seed)
    if jit:
      maker = lambda k: _prng_key_as_array(jax.jit(make_key)(k))
    else:
      maker = lambda k: _prng_key_as_array(make_key(k))
    if (jit and typ is int and not config.x64_enabled and
        (seed < np.iinfo('int32').min or seed > np.iinfo('int32').max)):
      # We expect an error to be raised.
      # NOTE: we check 'if jit' because some people rely on builtin int seeds
      # (e.g. from PRNGKey(hash("altair is best plotting library"))) outside jit

      # First check with no cache entry (note lambda above).
      with self.assertRaises(OverflowError):
        maker(seed)

      # Then populate a cache entry.
      maker(typ(0)).block_until_ready()

      # Then check now that we have a cache entry.
      with self.assertRaises(OverflowError):
        maker(seed)

    else:
      # Otherwise we expect no error.
      actual = maker(seed)
      expected = jnp.array(key, dtype=jnp.uint32)
      self.assertArraysEqual(actual, expected)
|
|
|
|
@parameterized.parameters([
|
|
{'make_key': ctor, 'name': name, 'impl': impl}
|
|
for ctor in KEY_CTORS
|
|
for name, impl in PRNG_IMPLS])
|
|
def test_default_prng_selection(self, make_key, name, impl):
|
|
with jax.default_prng_impl(name):
|
|
self.assertIs(random.default_prng_impl(), impl)
|
|
key = make_key(42)
|
|
self.check_key_has_impl(key, impl)
|
|
k1, k2 = random.split(key, 2)
|
|
self.check_key_has_impl(k1, impl)
|
|
self.check_key_has_impl(k2, impl)
|
|
|
|
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
|
def test_explicit_threefry2x32_key(self):
|
|
self.check_key_has_impl(random.threefry2x32_key(42),
|
|
prng_internal.threefry_prng_impl)
|
|
|
|
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
|
def test_explicit_rbg_key(self):
|
|
self.check_key_has_impl(random.rbg_key(42),
|
|
prng_internal.rbg_prng_impl)
|
|
|
|
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
|
def test_explicit_unsafe_rbg_key(self):
|
|
self.check_key_has_impl(random.unsafe_rbg_key(42),
|
|
prng_internal.unsafe_rbg_prng_impl)
|
|
|
|
@parameterized.parameters([{'make_key': ctor, 'name': name, 'impl': impl}
|
|
for ctor in KEY_CTORS
|
|
for name, impl in PRNG_IMPLS])
|
|
def test_key_construction_with_explicit_impl_name(self, make_key, name, impl):
|
|
key = make_key(42, impl=name)
|
|
self.check_key_has_impl(key, impl)
|
|
|
|
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
|
def test_isinstance(self, make_key):
|
|
key = make_key(0)
|
|
self.assertIsInstance(key, jax.Array)
|
|
|
|
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
|
def test_key_output_vjp(self, make_key):
|
|
# See https://github.com/google/jax/issues/14856
|
|
def f(seed): return make_key(seed)
|
|
jax.vjp(f, 1) # doesn't crash
|
|
|
|
  def test_legacy_prng_key_flag(self):
    # Exercises the three jax_legacy_prng_key modes ('allow'/'warn'/'error')
    # when a legacy raw uint32 key is passed to a jax.random function.
    raw_key = jnp.zeros(2, dtype='uint32')
    invalid_key = jnp.zeros(1, dtype='float32')
    msg = "Legacy uint32 key array passed as key to jax.random function."

    with jax.legacy_prng_key('allow'):
      # TODO(jakevdp): remove when enable_custom_prng no longer issues warnings
      with jax.enable_custom_prng(False):
        with self.assertNoWarnings():
          random.uniform(raw_key)

    with jax.legacy_prng_key('warn'):
      with self.assertWarnsRegex(UserWarning, msg):
        random.uniform(raw_key)

    with jax.legacy_prng_key('error'):
      with self.assertRaisesRegex(ValueError, msg):
        random.uniform(raw_key)

      # Invalid key error should take precedence.
      with self.assertRaisesRegex(TypeError, "JAX encountered invalid PRNG key data"):
        random.uniform(invalid_key)
|
|
|
|
|
|
class ThreefryPrngTest(jtu.JaxTestCase):
  """Threefry-specific key construction behavior."""

  @parameterized.parameters([{'make_key': ctor} for ctor in [
      random.threefry2x32_key,
      partial(random.PRNGKey, impl='threefry2x32'),
      partial(random.key, impl='threefry2x32')]])
  def test_seed_no_implicit_transfers(self, make_key):
    # Seeding from an on-device value must not trigger a device transfer.
    # See https://github.com/google/jax/issues/15613
    with jax.transfer_guard('disallow'):
      make_key(jax.device_put(42))  # doesn't crash
|
|
|
|
|
|
@jtu.with_config(jax_legacy_prng_key='allow')
class LaxRandomTest(jtu.JaxTestCase):
  """Statistical tests of jax.random samplers (threefry2x32 keys here)."""

  def _CheckCollisions(self, samples, nbits):
    # Compare the observed number of distinct values against the count
    # expected for uniform draws over 2**nbits bins.
    fail_prob = 0.01  # conservative bound on statistical fail prob by Chebyshev
    nitems = len(samples)
    nbins = 2 ** nbits
    nexpected = nbins * (1 - ((nbins - 1) / nbins) ** nitems)
    ncollisions = len(np.unique(samples))
    sq_percent_deviation = ((ncollisions - nexpected) / nexpected) ** 2
    self.assertLess(sq_percent_deviation, 1 / np.sqrt(nexpected * fail_prob))
|
|
|
|
def _CheckKolmogorovSmirnovCDF(self, samples, cdf):
|
|
# conservative bound on statistical fail prob by Kolmo CDF
|
|
# bfloat16 quantization creates much lower p-values in large distributions
|
|
fail_prob = 0.003 if samples.dtype == jnp.bfloat16 else 0.01
|
|
# TODO(frostig): This reads enable_custom_prng as a proxy for
|
|
# whether RBG keys may be involved, but that's no longer exact.
|
|
if config.jax_enable_custom_prng and samples.dtype == jnp.bfloat16:
|
|
return
|
|
self.assertGreater(scipy.stats.kstest(samples, cdf).pvalue, fail_prob)
|
|
|
|
  def _CheckChiSquared(self, samples, pmf):
    """Chi-squared goodness-of-fit check of integer samples against pmf."""
    if samples.dtype == bool:
      samples = samples.astype(int)
    alpha = 0.01  # significance level, threshold for p-value

    # scipy.stats.chisquare requires the sum of expected and actual to
    # match; this is only the case if we compute the expected frequency
    # at *all* nonzero values of the pmf. We don't know this a priori,
    # so we add extra values past the largest observed value. The number
    # below is empirically enough to get full coverage for the current set
    # of tests. If a new test is added where this is not enough, chisquare()
    # below will error due to the sums of the inputs not matching.
    extra_values = 100
    actual_freq = np.bincount(samples, minlength=samples.max() + extra_values)
    values = np.arange(len(actual_freq))

    expected_freq = pmf(values) * samples.size

    # Restrict to support of the pmf to keep chisquare well-defined.
    valid = expected_freq > 0
    actual_freq = actual_freq[valid]
    expected_freq = expected_freq[valid]

    _, p_value = scipy.stats.chisquare(actual_freq, expected_freq)
    self.assertGreater(
        p_value, alpha,
        msg=f'Failed chi-squared test with p={p_value}.\n'
        'Expected vs. actual frequencies:\n'
        f'{expected_freq}\n{actual_freq}')
|
|
|
|
  def make_key(self, seed):
    # Key constructor used by every test in this class; pins the
    # threefry2x32 implementation. (Presumably overridden by variant
    # subclasses elsewhere in the file — confirm against full source.)
    return random.threefry2x32_key(seed)
|
|
|
|
@jtu.sample_product(
|
|
num=(None, 6, (6,), (2, 3), (2, 3, 4)),
|
|
)
|
|
def test_split_size_shape(self, num):
|
|
key = self.make_key(0)
|
|
if num is None:
|
|
key_split = jax.random.split(key)
|
|
else:
|
|
key_split = jax.random.split(key, num)
|
|
|
|
if num is None:
|
|
self.assertEqual(key_split.shape, (2, *key.shape))
|
|
elif type(num) is tuple:
|
|
self.assertEqual(key_split.shape, (*num, *key.shape))
|
|
else:
|
|
self.assertEqual(key_split.shape, (num, *key.shape))
|
|
|
|
@jtu.sample_product(dtype=jtu.dtypes.floating)
|
|
def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype):
|
|
bits_dtype = np.uint32 if jnp.finfo(dtype).bits == 32 else np.uint64
|
|
numpy_bits = np.array(1., dtype).view(bits_dtype)
|
|
xla_bits = jax.jit(
|
|
lambda: lax.bitcast_convert_type(np.array(1., dtype), bits_dtype))()
|
|
self.assertEqual(numpy_bits, xla_bits)
|
|
|
|
@jtu.sample_product(dtype=float_dtypes)
def testRngUniform(self, dtype):
  """uniform() samples pass collision and KS tests, jitted and unjitted."""
  key = self.make_key(0)
  rand = lambda key: random.uniform(key, (10000,), dtype)
  crand = jax.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)
|
|
|
|
@jtu.sample_product(dtype=int_dtypes + uint_dtypes)
def testRngRandint(self, dtype):
  """randint() samples must lie in the half-open interval [lo, hi)."""
  lo = 5
  hi = 10

  key = self.make_key(0)
  rand = lambda key: random.randint(key, (10000,), lo, hi, dtype)
  crand = jax.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  for samples in [uncompiled_samples, compiled_samples]:
    self.assertTrue(np.all(lo <= samples))
    self.assertTrue(np.all(samples < hi))
|
|
|
|
@jtu.sample_product(dtype=float_dtypes)
def testNormal(self, dtype):
  """normal() samples pass a KS test against the standard normal CDF."""
  key = self.make_key(0)
  rand = lambda key: random.normal(key, (10000,), dtype)
  crand = jax.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.norm().cdf)
|
|
|
|
def testNormalBfloat16(self):
  """A dtype given as the string 'bfloat16' must match jnp.bfloat16."""
  # Passing bfloat16 as dtype string.
  # https://github.com/google/jax/issues/6813
  res_bfloat16_str = random.normal(self.make_key(0), dtype='bfloat16')
  res_bfloat16 = random.normal(self.make_key(0), dtype=jnp.bfloat16)
  self.assertAllClose(res_bfloat16, res_bfloat16_str)
|
|
|
|
@jtu.sample_product(dtype=complex_dtypes)
def testNormalComplex(self, dtype):
  """Complex normal(): real/imag parts each N(0, 1/2) so total variance is 1."""
  key = self.make_key(0)
  rand = lambda key: random.normal(key, (10000,), dtype)
  crand = jax.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  for samples in [uncompiled_samples, compiled_samples]:
    # Each component has scale 1/sqrt(2) so that |z| has unit variance.
    self._CheckKolmogorovSmirnovCDF(jnp.real(samples), scipy.stats.norm(scale=1/np.sqrt(2)).cdf)
    self._CheckKolmogorovSmirnovCDF(jnp.imag(samples), scipy.stats.norm(scale=1/np.sqrt(2)).cdf)
    self.assertEqual(dtype, samples.dtype)
|
|
|
|
@jtu.sample_product(dtype=float_dtypes)
def testTruncatedNormal(self, dtype):
  """truncated_normal() stays strictly inside its bounds and matches truncnorm."""
  key = self.make_key(0)
  rand = lambda key: random.truncated_normal(key, -0.3, 0.3, (10000,), dtype)
  crand = jax.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  min_val = np.min(uncompiled_samples)
  max_val = np.max(uncompiled_samples)
  self.assertTrue(min_val > -0.3)
  self.assertTrue(max_val < 0.3)
  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.truncnorm(-0.3, 0.3).cdf)
|
|
|
|
@jtu.sample_product(dtype=jtu.dtypes.floating + jtu.dtypes.integer)
def testShuffle(self, dtype):
  """Deprecated shuffle(): warns FutureWarning and returns a true permutation."""
  key = self.make_key(0)
  x = np.arange(100).astype(dtype)
  rand = lambda key: random.shuffle(key, x)
  crand = jax.jit(rand)

  # shuffle() is deprecated, so each call must emit a FutureWarning.
  with self.assertWarns(FutureWarning):
    perm1 = rand(key)
  with self.assertWarns(FutureWarning):
    perm2 = crand(key)

  self.assertAllClose(perm1, perm2)
  self.assertFalse(np.all(perm1 == x))  # seems unlikely!
  self.assertAllClose(np.sort(perm1), x, check_dtypes=False)
|
|
|
|
@jtu.sample_product(
  [dict(shape=shape, replace=replace, axis=axis,
        input_range_or_shape=input_range_or_shape)
   for shape in [(), (5,), (4, 5)]
   for replace in [True, False]
   for input_range_or_shape in [100, (10, 10), (10, 5, 2), 1, (1, 5)]
   for is_range in [type(input_range_or_shape) is int]
   for ndim in [1 if is_range else len(input_range_or_shape)]
   for axis in range(-ndim, ndim or 1)
   for ninputs in [input_range_or_shape if is_range else input_range_or_shape[axis]]
   if replace or math.prod(shape) <= ninputs
  ],
  dtype=jtu.dtypes.floating + jtu.dtypes.integer,
  weighted=[True, False],
)
def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis):
  """choice(): shape/dtype match NumPy, no-replace draws are unique, jit agrees."""
  # This is the function API that we test against (note that self.rng().choice differs)
  np_choice = np.random.default_rng(0).choice
  p_dtype = dtypes.to_inexact_dtype(dtype)

  key = self.make_key(0)
  is_range = type(input_range_or_shape) is int
  x = (input_range_or_shape if is_range else
       self.rng().permutation(np.arange(math.prod(
         input_range_or_shape), dtype=dtype)).reshape(input_range_or_shape))
  N = x if is_range else x.shape[axis]
  if weighted:
    # Unnormalized increasing weights, normalized to a probability vector.
    p = np.arange(N, dtype=p_dtype) + 1
    p /= p.sum()
  else:
    p = None
  rand = lambda key, x: random.choice(key, x, shape, replace, p, axis)
  sample = rand(key, x)
  if not is_range:
    self.assertEqual(dtype, sample.dtype)
  expected_shape = np.shape(np_choice(x, shape or None, replace, p, axis))
  self.assertEqual(expected_shape, sample.shape)
  expected_dtype = dtypes.result_type(int if is_range else x)
  self.assertEqual(expected_dtype, sample.dtype)
  if not replace and shape:
    # Without replacement, sorted samples must equal the sorted unique set.
    def lsort(x):
      if not math.prod(x.shape): return x
      ind = np.lexsort(np.swapaxes(x, axis, -1).reshape((-1, x.shape[axis])))
      return jnp.take(x, ind, axis)
    self.assertArraysEqual(lsort(sample), lsort(np.unique(sample, axis=axis)))
  # Results must be identical for array input and under jit.
  self.assertArraysEqual(sample, rand(key, np.array(x)))
  self.assertArraysEqual(sample, jax.jit(rand, static_argnames=
      'x' if is_range else None)(key, x))
|
|
|
|
@jtu.sample_product(
  [dict(range_or_shape=range_or_shape, axis=axis)
   for range_or_shape in [0, 1, 100, (0,), (1,), (100,),
                          (10, 10), (10, 5, 2), (0, 5), (1, 5)]
   for ndim in [1 if type(range_or_shape) is int else len(range_or_shape)]
   for axis in range(-ndim, ndim or 1)
  ],
  dtype=jtu.dtypes.floating + jtu.dtypes.integer,
  independent=[True, False],
)
def testPermutation(self, dtype, range_or_shape, axis, independent):
  """permutation(): output is a reordering of the input; input is not mutated."""
  key = self.make_key(0)
  is_range = type(range_or_shape) is int
  x = (range_or_shape if is_range else
       self.rng().permutation(np.arange(
         math.prod(range_or_shape), dtype=dtype)).reshape(range_or_shape))
  shape = ((range_or_shape,) if is_range else range_or_shape)
  x_ = np.copy(x)  # kept to verify the input array is left untouched
  rand = lambda key, x: random.permutation(key, x, axis, independent=independent)
  perm = rand(key, x)
  if shape[axis] >= 10:
    self.assertFalse(np.all(perm == x))  # seems unlikely!
  arr = np.arange(x) if is_range else x
  def lsort(x):
    if not math.prod(x.shape): return x
    ind = np.lexsort(np.swapaxes(x, axis, -1).reshape((-1, x.shape[axis])))
    return jnp.take(x, ind, axis)
  if not independent:
    self.assertArraysEqual(lsort(arr), lsort(perm), check_dtypes=not is_range)
  if independent and (arr.shape[axis] > 4) and (arr.size // arr.shape[axis] > 4):
    # Check for independent shuffling if there are >4 vectors of size >4.
    # Chance of false positive is 1 in (5!)^4
    with self.assertRaises(AssertionError):
      self.assertArraysEqual(lsort(arr), lsort(perm), check_dtypes=not is_range)
  self.assertArraysEqual(x_, x)
  # Results must be identical for array input and under jit.
  self.assertArraysEqual(perm, rand(key, np.array(x)))
  self.assertArraysEqual(perm, jax.jit(rand, static_argnames=
      'x' if is_range else None)(key, x))
|
|
|
|
def testPermutationErrors(self):
  """permutation(): bad axis, float input, and traced ints all raise."""
  key = self.make_key(0)
  with self.assertRaises(ValueError):
    random.permutation(key, 10, axis=3)
  with self.assertRaises(TypeError):
    random.permutation(key, 10.)
  # An integer `x` must be concrete; under jit it becomes a tracer.
  with self.assertRaises(core.ConcretizationTypeError):
    jax.jit(random.permutation)(key, 10)
|
|
|
|
@jtu.sample_product(
    p=[0.1, 0.5, 0.9],
    dtype=jtu.dtypes.floating,
)
def testBernoulli(self, p, dtype):
  """bernoulli() samples pass a chi-squared test against the Bernoulli pmf."""
  key = self.make_key(0)
  p = np.array(p, dtype=dtype)
  rand = lambda key, p: random.bernoulli(key, p, (10000,))
  crand = jax.jit(rand)

  uncompiled_samples = rand(key, p)
  compiled_samples = crand(key, p)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckChiSquared(samples, scipy.stats.bernoulli(p).pmf)
|
|
|
|
@jtu.sample_product(
  [dict(p=p, axis=axis)
   for (p, axis) in [
     ([.25] * 4, -1),
     ([.1, .2, .3, .4], -1),
     ([[.5, .5], [.1, .9]], 1),
     ([[.5, .1], [.5, .9]], 0),
   ]
  ],
  sample_shape=[(10000,), (5000, 2)],
  dtype=jtu.dtypes.floating,
)
def testCategorical(self, p, axis, dtype, sample_shape):
  """categorical() draws match the given class probabilities (chi-squared)."""
  key = self.make_key(0)
  p = np.array(p, dtype=dtype)
  logits = np.log(p) - 42 # test unnormalized
  out_shape = tuple(np.delete(logits.shape, axis))
  shape = sample_shape + out_shape
  rand = partial(random.categorical, shape=shape, axis=axis)
  crand = jax.jit(rand)

  uncompiled_samples = rand(key, logits)
  compiled_samples = crand(key, logits)

  if axis < 0:
    axis += len(logits.shape)

  for samples in [uncompiled_samples, compiled_samples]:
    assert samples.shape == shape
    samples = jnp.reshape(samples, (10000,) + out_shape)
    if len(p.shape[:-1]) > 0:
      # Batched case: check each batch element against its own pmf.
      ps = np.transpose(p, (1, 0)) if axis == 0 else p
      for cat_samples, cat_p in zip(samples.transpose(), ps):
        pmf = lambda x: np.where(x < len(cat_p), cat_p[np.minimum(len(cat_p) - 1, x)], 0.0)
        self._CheckChiSquared(cat_samples, pmf=pmf)
    else:
      pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0)
      self._CheckChiSquared(samples, pmf=pmf)
|
|
|
|
def testBernoulliShape(self):
  """bernoulli() broadcasts its `p` argument against an explicit shape."""
  key = self.make_key(0)
  probs = np.array([0.2, 0.3])
  with jax.numpy_rank_promotion('allow'):
    samples = random.bernoulli(key, probs, shape=(3, 2))
  assert samples.shape == (3, 2)
|
|
|
|
@jtu.sample_product(
  a=[0.2, 5.],
  b=[0.2, 5.],
  dtype=[np.float64],  # NOTE: KS test fails with float32
)
def testBeta(self, a, b, dtype):
  """beta() samples pass a KS test against scipy's beta CDF (x64 only)."""
  if not config.x64_enabled:
    raise SkipTest("skip test except on X64")
  key = self.make_key(0)
  rand = lambda key, a, b: random.beta(key, a, b, (10000,), dtype)
  crand = jax.jit(rand)

  uncompiled_samples = rand(key, a, b)
  compiled_samples = crand(key, a, b)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.beta(a, b).cdf)
|
|
|
|
@jtu.skip_on_devices("tpu")  # TPU precision causes issues.
def testBetaSmallParameters(self, dtype=np.float32):
  """beta() with tiny a, b concentrates all mass at 0 and 1 without NaNs."""
  # Regression test for beta version of https://github.com/google/jax/issues/9896
  key = self.make_key(0)
  a, b = 0.0001, 0.0002
  samples = random.beta(key, a, b, shape=(100,), dtype=dtype)

  # With such small parameters, all samples should be exactly zero or one.
  tol = 5E-2 if jtu.device_under_test() == "tpu" else 1E-3

  zeros = samples[samples < 0.5]
  self.assertAllClose(zeros, jnp.zeros_like(zeros), atol=tol)

  ones = samples[samples >= 0.5]
  self.assertAllClose(ones, jnp.ones_like(ones), atol=tol)
|
|
|
|
@jtu.sample_product(dtype=float_dtypes)
def testCauchy(self, dtype):
  """cauchy() samples pass a KS test against the standard Cauchy CDF."""
  key = self.make_key(0)
  rand = lambda key: random.cauchy(key, (10000,), dtype)
  crand = jax.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.cauchy().cdf)
|
|
|
|
@jtu.sample_product(
  alpha=[np.array([0.2, 1., 5.]),],
  dtype=jtu.dtypes.floating,
)
@jtu.skip_on_devices("tpu")  # TODO(mattjj): slow compilation times
def testDirichlet(self, alpha, dtype):
  """dirichlet(): samples sum to 1 and each marginal is Beta-distributed."""
  key = self.make_key(0)
  rand = lambda key, alpha: random.dirichlet(key, alpha, (10000,), dtype)
  crand = jax.jit(rand)

  uncompiled_samples = rand(key, alpha)
  compiled_samples = crand(key, alpha)

  for samples in [uncompiled_samples, compiled_samples]:
    self.assertAllClose(samples.sum(-1), np.ones(10000, dtype=dtype))
    alpha_sum = sum(alpha)
    # Marginal i of Dirichlet(alpha) is Beta(alpha_i, alpha_sum - alpha_i).
    for i, a in enumerate(alpha):
      self._CheckKolmogorovSmirnovCDF(samples[..., i], scipy.stats.beta(a, alpha_sum - a).cdf)
|
|
|
|
@jtu.skip_on_devices("tpu")  # lower accuracy leads to failures.
def testDirichletSmallAlpha(self, dtype=np.float32):
  """dirichlet() with tiny alpha stays on the simplex and is near one-hot."""
  # Regression test for https://github.com/google/jax/issues/9896
  key = self.make_key(0)
  alpha = 0.00001 * jnp.ones(3)
  samples = random.dirichlet(key, alpha, shape=(100,), dtype=dtype)

  # Check that results lie on the simplex.
  self.assertAllClose(samples.sum(1), jnp.ones(samples.shape[0]),
                      check_dtypes=False, rtol=1E-5)

  # Check that results contain 1 in one of the dimensions:
  # this is highly likely to be true when alpha is small.
  self.assertAllClose(samples.max(1), jnp.ones(samples.shape[0]),
                      check_dtypes=False, rtol=1E-4)
|
|
|
|
@jtu.sample_product(dtype=float_dtypes)
def testExponential(self, dtype):
  """exponential() samples pass a KS test against the Exp(1) CDF."""
  key = self.make_key(0)
  rand = lambda key: random.exponential(key, (10000,), dtype)
  crand = jax.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.expon().cdf)
|
|
|
|
@jtu.sample_product(
  a=[0.1, 1., 10.],
  dtype=jtu.dtypes.floating,
)
@jtu.skip_on_devices("tpu")  # low accuracy leads to failures.
def testGammaVsLogGamma(self, a, dtype):
  # Test that gamma() and loggamma() produce equivalent samples.
  key = self.make_key(0)
  rand_gamma = lambda key, a: random.gamma(key, a, (100,), dtype)
  rand_loggamma = lambda key, a: random.loggamma(key, a, (100,), dtype)
  crand_loggamma = jax.jit(rand_loggamma)
  tol = {np.float32: 1E-6, np.float64: 1E-12}

  # gamma(key, a) should equal exp(loggamma(key, a)) for the same key.
  self.assertAllClose(rand_gamma(key, a), jnp.exp(rand_loggamma(key, a)),
                      atol=tol, rtol=tol)
  self.assertAllClose(rand_gamma(key, a), jnp.exp(crand_loggamma(key, a)),
                      atol=tol, rtol=tol)
|
|
|
|
@jtu.sample_product(
  a=[0.1, 1., 10.],
  dtype=jtu.dtypes.floating,
)
def testGamma(self, a, dtype):
  """gamma() samples pass a KS test against scipy's gamma CDF."""
  key = self.make_key(1)
  rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
  crand = jax.jit(rand)

  uncompiled_samples = rand(key, a)
  compiled_samples = crand(key, a)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gamma(a).cdf)
|
|
|
|
def testGammaShape(self):
  """gamma() broadcasts its `a` parameter against an explicit shape."""
  key = self.make_key(0)
  samples = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2))
  assert samples.shape == (3, 2)
|
|
|
|
@jtu.sample_product(
  log_space=[True, False],
  alpha=[1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4],
)
def testGammaGrad(self, log_space, alpha):
  """d/d(alpha) of gamma samples matches the implicit-gradient formula."""
  rng = self.make_key(0)
  alphas = np.full((100,), alpha)
  z = random.gamma(rng, alphas)
  if log_space:
    actual_grad = jax.grad(lambda x: lax.exp(random.loggamma(rng, x)).sum())(alphas)
  else:
    actual_grad = jax.grad(lambda x: random.gamma(rng, x).sum())(alphas)

  # Finite-difference estimate of d(CDF)/d(alpha), then implicit
  # differentiation: dz/d(alpha) = -(dCDF/d(alpha)) / pdf(z).
  eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
  cdf_dot = (scipy.stats.gamma.cdf(z, alpha + eps)
             - scipy.stats.gamma.cdf(z, alpha - eps)) / (2 * eps)
  with np.errstate(over='ignore'):
    pdf = scipy.stats.gamma.pdf(z, alpha)
  expected_grad = -cdf_dot / pdf

  rtol = 2e-2 if jtu.device_under_test() == "tpu" else 7e-4
  self.assertAllClose(actual_grad, expected_grad, check_dtypes=True,
                      rtol=rtol)
|
|
|
|
def testGammaGradType(self):
  """VJP through gamma() must not raise a dtype mismatch error."""
  # Regression test for https://github.com/google/jax/issues/2130
  key = self.make_key(0)
  a = jnp.array(1., dtype=jnp.float32)
  b = jnp.array(3., dtype=jnp.float32)
  f = lambda x, y: random.gamma(key=key, a=x, dtype=jnp.float32) / y
  # Should not crash with a type error.
  jax.vjp(f, a, b)
|
|
|
|
@jtu.sample_product(
  lam=[0.5, 3, 9, 11, 50, 500],
  dtype=jtu.dtypes.supported([np.int16, np.int32, np.int64]),
)
def testPoisson(self, lam, dtype):
  """poisson() samples pass a chi-squared test and have mean/var ≈ lam."""
  key = self.make_key(0)
  rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype)
  crand = jax.jit(rand)

  uncompiled_samples = rand(key, lam)
  compiled_samples = crand(key, lam)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckChiSquared(samples, scipy.stats.poisson(lam).pmf)
    # TODO(shoyer): determine error bounds for moments more rigorously (e.g.,
    # based on the central limit theorem).
    self.assertAllClose(samples.mean(), lam, rtol=0.02, check_dtypes=False)
    self.assertAllClose(samples.var(), lam, rtol=0.03, check_dtypes=False)
|
|
|
|
def testPoissonBatched(self):
  """poisson() with a per-element lam vector matches each rate separately."""
  key = self.make_key(1)
  lam = jnp.concatenate([2 * jnp.ones(10000), 20 * jnp.ones(10000)])
  samples = random.poisson(key, lam, shape=(20000,))
  self._CheckChiSquared(samples[:10000], scipy.stats.poisson(2.0).pmf)
  self._CheckChiSquared(samples[10000:], scipy.stats.poisson(20.0).pmf)
|
|
|
|
def testPoissonWithoutShape(self):
  """poisson() without an explicit shape takes its shape from lam."""
  key = self.make_key(1)
  lam = 2 * jnp.ones(10000)
  samples = random.poisson(key, lam)
  self._CheckChiSquared(samples, scipy.stats.poisson(2.0).pmf)
|
|
|
|
def testPoissonShape(self):
  """poisson() broadcasts its `lam` parameter against an explicit shape."""
  key = self.make_key(0)
  samples = random.poisson(key, np.array([2.0, 20.0]), shape=(3, 2))
  assert samples.shape == (3, 2)
|
|
|
|
def testPoissonZeros(self):
  """poisson() with lam == 0 returns exactly zero samples."""
  key = self.make_key(0)
  lam = jnp.concatenate([jnp.zeros(10), 20 * jnp.ones(10)])
  samples = random.poisson(key, lam, shape=(2, 20))
  self.assertArraysEqual(samples[:, :10], jnp.zeros_like(samples[:, :10]))
|
|
|
|
def testPoissonCornerCases(self):
  """poisson(): negative and NaN rates yield -1; lam == 0 yields 0."""
  key = self.make_key(0)
  lam = jnp.array([-1, 0, jnp.nan])
  samples = random.poisson(key, lam, shape=(3,))
  self.assertArraysEqual(samples, jnp.array([-1, 0, -1]), check_dtypes=False)
|
|
|
|
@jtu.sample_product(dtype=jtu.dtypes.floating)
def testGumbel(self, dtype):
  """gumbel() samples pass a KS test against the right-skewed Gumbel CDF."""
  key = self.make_key(0)
  rand = lambda key: random.gumbel(key, (10000,), dtype)
  crand = jax.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gumbel_r().cdf)
|
|
|
|
@jtu.sample_product(dtype=float_dtypes)
def testLaplace(self, dtype):
  """laplace() samples pass a KS test against the standard Laplace CDF."""
  key = self.make_key(0)
  rand = lambda key: random.laplace(key, (10000,), dtype)
  crand = jax.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.laplace().cdf)
|
|
|
|
@jtu.sample_product(dtype=float_dtypes)
def testLogistic(self, dtype):
  """logistic() samples pass a KS test against the standard logistic CDF."""
  key = self.make_key(0)
  rand = lambda key: random.logistic(key, (10000,), dtype)
  crand = jax.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf)
|
|
|
|
@jtu.sample_product(
  n=range(1, 5),
  shape=[(), (5,), (10, 5)],
  dtype=jtu.dtypes.floating + jtu.dtypes.complex,
)
@jax.default_matmul_precision("float32")
def testOrthogonal(self, n, shape, dtype):
  """orthogonal() returns matrices with Q @ Q^H == I (unitary for complex)."""
  key = self.make_key(0)
  q = random.orthogonal(key, n, shape, dtype)
  self.assertEqual(q.shape, (*shape, n, n))
  self.assertEqual(q.dtype, dtype)
  with jax.numpy_rank_promotion('allow'):
    self.assertAllClose(
      jnp.einsum('...ij,...jk->...ik', q, jnp.conj(q).swapaxes(-2, -1)),
      jnp.broadcast_to(jnp.eye(n, dtype=dtype), (*shape, n, n))
    )
|
|
|
|
@jtu.sample_product(
  p=[.5, 1., 1.5, 2., 2.5],
  shape=[(), (5,), (10, 5)],
  dtype=jtu.dtypes.floating,
)
def testGeneralizedNormal(self, p, shape, dtype):
  """generalized_normal() has the right shape/dtype and gennorm distribution."""
  key = self.make_key(0)
  rand = lambda key, p: random.generalized_normal(key, p, shape, dtype)
  crand = jax.jit(rand)
  uncompiled_samples = rand(key, p)
  compiled_samples = crand(key, p)
  for samples in [uncompiled_samples, compiled_samples]:
    self.assertEqual(samples.shape, shape)
    self.assertEqual(samples.dtype, dtype)
    self._CheckKolmogorovSmirnovCDF(samples.ravel(), scipy.stats.gennorm(p).cdf)
|
|
|
|
@jtu.sample_product(
  d=range(1, 5),
  p=[.5, 1., 1.5, 2., 2.5],
  shape=[(), (5,), (10, 5)],
  dtype=jtu.dtypes.floating,
)
@jtu.skip_on_devices("tpu")  # TPU precision causes issues.
def testBall(self, d, p, shape, dtype):
  """ball() samples lie in the unit l_p ball with uniform radial mass."""
  key = self.make_key(123)
  rand = lambda key, p: random.ball(key, d, p, shape, dtype)
  crand = jax.jit(rand)
  uncompiled_samples = rand(key, p)
  compiled_samples = crand(key, p)
  for samples in [uncompiled_samples, compiled_samples]:
    self.assertEqual(samples.shape, (*shape, d))
    self.assertEqual(samples.dtype, dtype)
    self.assertTrue(((jnp.abs(samples) ** p).sum(-1) <= 1).all())
    # For a uniform draw from the ball, ||x||_p^d is uniform on [0, 1].
    norms = (jnp.abs(samples) ** p).sum(-1) ** (d / p)
    self._CheckKolmogorovSmirnovCDF(norms.ravel(), scipy.stats.uniform().cdf)
|
|
|
|
@jtu.sample_product(
  b=[0.1, 1., 10.],
  dtype=jtu.dtypes.floating,
)
def testPareto(self, b, dtype):
  """pareto() samples pass a KS test against scipy's Pareto CDF."""
  key = self.make_key(0)
  rand = lambda key, b: random.pareto(key, b, (10000,), dtype)
  crand = jax.jit(rand)

  uncompiled_samples = rand(key, b)
  compiled_samples = crand(key, b)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.pareto(b).cdf)
|
|
|
|
def testParetoShape(self):
  """pareto() broadcasts its `b` parameter against an explicit shape."""
  key = self.make_key(0)
  with jax.numpy_rank_promotion('allow'):
    samples = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2))
  assert samples.shape == (3, 2)
|
|
|
|
@jtu.sample_product(
  df=[0.1, 1., 10.],
  dtype=jtu.dtypes.floating,
)
@jtu.skip_on_devices("cpu", "tpu")  # TODO(phawkins): slow compilation times
def testT(self, df, dtype):
  """t() samples pass a KS test against Student's t CDF."""
  key = self.make_key(1)
  rand = lambda key, df: random.t(key, df, (10000,), dtype)
  crand = jax.jit(rand)

  uncompiled_samples = rand(key, df)
  compiled_samples = crand(key, df)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.t(df).cdf)
|
|
|
|
@jtu.sample_product(
  dim=[1, 3, 5],
  dtype=float_dtypes,
  method=['svd', 'eigh', 'cholesky'],
)
def testMultivariateNormal(self, dim, dtype, method):
  """multivariate_normal(): whitened samples look standard normal."""
  r = self.rng()
  mean = r.randn(dim)
  cov_factor = r.randn(dim, dim)
  # Add dim * I to make the covariance comfortably positive definite.
  cov = np.dot(cov_factor, cov_factor.T) + dim * np.eye(dim)

  key = self.make_key(0)
  rand = partial(random.multivariate_normal, mean=mean, cov=cov,
                 shape=(10000,), method=method)
  crand = jax.jit(rand)

  with jax.numpy_rank_promotion('allow'):
    uncompiled_samples = np.asarray(rand(key), np.float64)
    compiled_samples = np.asarray(crand(key), np.float64)

  inv_scale = scipy.linalg.lapack.dtrtri(np.linalg.cholesky(cov), lower=True)[0]
  for samples in [uncompiled_samples, compiled_samples]:
    centered = samples - mean
    whitened = np.einsum('nj,ij->ni', centered, inv_scale)

    # This is a quick-and-dirty multivariate normality check that tests that a
    # uniform mixture of the marginals along the covariance matrix's
    # eigenvectors follow a standard normal distribution.
    self._CheckKolmogorovSmirnovCDF(whitened.ravel(), scipy.stats.norm().cdf)
|
|
|
|
@jtu.sample_product(
  dim=[1, 2, 4],
  mean_batch_size=[(), (3,), (2, 3)],
  cov_batch_size=[(), (3,), (2, 3)],
  shape=[(), (1,), (5,)],
  method=['cholesky', 'svd', 'eigh'],
)
def testMultivariateNormalShapes(self, dim, mean_batch_size, cov_batch_size,
                                 shape, method):
  """multivariate_normal() broadcasts batched mean/cov to the right shape."""
  r = self.rng()
  key = self.make_key(0)
  # The broader of the two batch shapes dominates the output batch shape.
  eff_batch_size = mean_batch_size \
    if len(mean_batch_size) > len(cov_batch_size) else cov_batch_size
  mean = r.randn(*(mean_batch_size + (dim,)))
  cov_factor = r.randn(*(cov_batch_size + (dim, dim)))
  cov = np.einsum('...ij,...kj->...ik', cov_factor, cov_factor)
  cov += 1e-3 * np.eye(dim)  # jitter for positive definiteness
  shape = shape + eff_batch_size
  with jax.numpy_rank_promotion('allow'):
    samples = random.multivariate_normal(key, mean, cov, shape=shape, method=method)
  assert samples.shape == shape + (dim,)
|
|
|
|
def testMultivariateNormalCovariance(self):
  """Empirical variance/covariance of samples matches NumPy's sampler."""
  # test code based on https://github.com/google/jax/issues/1869
  N = 100000
  mean = jnp.zeros(4)
  cov = jnp.array([[  0.19,  0.00, -0.13,  0.00],
                   [  0.00,  0.29,  0.00, -0.23],
                   [ -0.13,  0.00,  0.39,  0.00],
                   [  0.00, -0.23,  0.00,  0.49]], dtype=mean.dtype)

  out_np = self.rng().multivariate_normal(mean, cov, N)

  key = self.make_key(0)
  with jax.numpy_rank_promotion('allow'):
    out_jnp = random.multivariate_normal(key, mean=mean, cov=cov, shape=(N,))

  var_np = out_np.var(axis=0)
  var_jnp = out_jnp.var(axis=0)
  self.assertAllClose(var_np, var_jnp, rtol=1e-2, atol=1e-2,
                      check_dtypes=False)

  var_np = np.cov(out_np, rowvar=False)
  var_jnp = np.cov(out_jnp, rowvar=False)
  self.assertAllClose(var_np, var_jnp, rtol=1e-2, atol=1e-2,
                      check_dtypes=False)
|
|
|
|
@jtu.sample_product(method=['cholesky', 'eigh', 'svd'])
@jtu.skip_on_devices('gpu', 'tpu')  # Some NaNs on accelerators.
def testMultivariateNormalSingularCovariance(self, method):
  """Singular covariance: eigh/svd stay finite, cholesky goes NaN."""
  # Singular covariance matrix https://github.com/google/jax/discussions/13293
  mu = jnp.zeros((2,))
  sigma = jnp.ones((2, 2))
  key = self.make_key(0)
  result = random.multivariate_normal(key, mean=mu, cov=sigma, shape=(10,), method=method)
  # Rank-1 covariance of all-ones forces the two coordinates to coincide.
  self.assertAllClose(result[:, 0], result[:, 1], atol=1e-3, rtol=1e-3)

  # Cholesky fails for singular inputs.
  if method == 'cholesky':
    self.assertTrue(np.all(np.isnan(result)))
  else:
    self.assertFalse(np.any(np.isnan(result)))
|
|
|
|
def testIssue222(self):
  """Regression test: randint over the empty range [0, 0) must return 0."""
  sample = random.randint(self.make_key(10003), (), 0, 0)
  assert sample == 0
|
|
|
|
def testFoldIn(self):
  """Folding ten distinct data values into one key yields ten distinct keys."""
  key = self.make_key(0)
  folded = [random.fold_in(key, i) for i in range(10)]
  keys = [_prng_key_as_array(k) for k in folded]
  assert len(np.unique(keys, axis=0)) == 10
|
|
|
|
def testFoldInBig(self):
  """fold_in accepts data values near 2**32 and still yields distinct keys."""
  key = self.make_key(0)
  seeds = [2 ** 32 - 2, 2 ** 32 - 1]
  keys = [_prng_key_as_array(random.fold_in(key, s)) for s in seeds]
  assert len(np.unique(keys, axis=0)) == 2
|
|
|
|
def testStaticShapeErrors(self):
  """Sampling with a traced (non-static) shape under jit raises a TypeError."""
  if config.jax_disable_jit:
    raise SkipTest("test only relevant when jit enabled")

  @jax.jit
  def feature_map(n, d, sigma=1.0, seed=123):
    key = self.make_key(seed)
    # `n` and `d` are tracers here, so these shapes are not static.
    W = random.normal(key, (d, n)) / sigma
    w = random.normal(key, (d, )) / sigma
    b = 2 * jnp.pi * random.uniform(key, (d, ))

    phi = lambda x, t: jnp.sqrt(2.0 / d) * jnp.cos(jnp.matmul(W, x) + w*t + b)
    return phi

  self.assertRaisesRegex(TypeError, 'Shapes must be 1D.*',
                         lambda: feature_map(5, 3))
|
|
|
|
def testIssue756(self):
  """normal() with a scalar shape returns the canonical float dtype."""
  key = self.make_key(0)
  w = random.normal(key, ())
  self.assertEqual(w.dtype, dtypes.canonicalize_dtype(jnp.float_))
|
|
|
|
def testIssue1789(self):
  """grad of vmap'ed gamma sampling must not crash."""
  def f(x):
    return random.gamma(self.make_key(0), x)

  grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2))
|
|
|
|
def testDtypeErrorMessage(self):
  """normal() with an integer dtype raises a clear ValueError."""
  with self.assertRaisesRegex(ValueError, r"dtype argument to.*"):
    random.normal(self.make_key(0), (), dtype=jnp.int32)
|
|
|
|
def testRandomBroadcast(self):
  """uniform() and randint() broadcast vector bounds to the output shape."""
  """Issue 4033"""
  # test for broadcast issue in https://github.com/google/jax/issues/4033
  key = self.make_key(0)
  shape = (10, 2)
  with jax.numpy_rank_promotion('allow'):
    x1 = random.uniform(key, shape, minval=jnp.zeros(2), maxval=jnp.ones(2))
    x2 = random.randint(key, shape, jnp.array([0, 1]), jnp.array([1, 2]))
  assert x1.shape == shape
  assert x2.shape == shape
|
|
|
|
def testMaxwellSample(self):
  """maxwell() matches scipy's Maxwell moments and CDF."""
  num_samples = 10**5
  rng = self.make_key(0)

  rand = lambda x: random.maxwell(x, (num_samples, ))
  crand = jax.jit(rand)

  loc = jtu.to_default_dtype(scipy.stats.maxwell.mean())
  std = jtu.to_default_dtype(scipy.stats.maxwell.std())

  uncompiled_samples = rand(rng)
  compiled_samples = crand(rng)

  for samples in [uncompiled_samples, compiled_samples]:
    # Check first and second moments.
    self.assertEqual((num_samples,), samples.shape)
    self.assertAllClose(np.mean(samples), loc, atol=0., rtol=0.1)
    self.assertAllClose(np.std(samples), std, atol=0., rtol=0.1)
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.maxwell().cdf)
|
|
|
|
@parameterized.named_parameters(
    ('test1', 4.0, 1.0),
    ('test2', 2.0, 3.0))
def testWeibullSample(self, concentration, scale):
  """weibull_min() matches scipy's Weibull moments and CDF."""
  num_samples = 10**5
  rng = self.make_key(0)

  rand = lambda x: random.weibull_min(x, scale, concentration, (num_samples,))
  crand = jax.jit(rand)

  loc = jtu.to_default_dtype(scipy.stats.weibull_min.mean(c=concentration, scale=scale))
  std = jtu.to_default_dtype(scipy.stats.weibull_min.std(c=concentration, scale=scale))

  uncompiled_samples = rand(rng)
  compiled_samples = crand(rng)

  for samples in [uncompiled_samples, compiled_samples]:
    # Check first and second moments.
    self.assertEqual((num_samples,), samples.shape)
    self.assertAllClose(np.mean(samples), loc, atol=0., rtol=0.1)
    self.assertAllClose(np.std(samples), std, atol=0., rtol=0.1)
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.weibull_min(
        c=concentration, scale=scale).cdf)
|
|
|
|
@parameterized.named_parameters(
    ('test1', 4.0, 1.0),
    ('test2', 2.0, 3.0))
def testDoublesidedMaxwellSample(self, loc, scale):
  """double_sided_maxwell() matches the expected moments and mixture CDF."""
  num_samples = 10**4
  rng = self.make_key(0)

  # Bug fix: the lambda previously ignored its `key` parameter and closed
  # over `rng`, so jit traced over a baked-in constant key. Behavior is
  # unchanged here because both calls below pass `rng`.
  rand = lambda key: random.double_sided_maxwell(
      key, loc, scale, (num_samples,))
  crand = jax.jit(rand)

  mean = loc
  std = np.sqrt(3.) * scale

  uncompiled_samples = rand(rng)
  compiled_samples = crand(rng)

  # Compute the double sided maxwell CDF through the one sided maxwell cdf.
  # This is done as follows:
  # P(DSM <= x) = P (loc + scale * rademacher_sample * one_sided_sample <=x) =
  # P (rademacher_sample * one_sided_sample <= (x - loc) / scale) =
  # 1/2 P(one_sided_sample <= (x - loc) / scale)
  # + 1/2 P( - one_sided_sample <= (x - loc) / scale) =
  # 1/2 P(one_sided_sample <= (x - loc) / scale)
  # + 1/2 P(one_sided_sample >= - (x - loc) / scale) =
  # 1/2 CDF_one_maxwell((x - loc) / scale))
  # + 1/2 (1 - CDF_one_maxwell(- (x - loc) / scale)))
  def double_sided_maxwell_cdf(x, loc, scale):
    # Equal-weight mixture of the positive and reflected one-sided CDFs.
    pos = scipy.stats.maxwell().cdf((x - loc) / scale)
    neg = (1 - scipy.stats.maxwell().cdf((-x + loc) / scale))
    return (pos + neg) / 2

  for samples in [uncompiled_samples, compiled_samples]:
    # Check first and second moments.
    self.assertEqual((num_samples,), samples.shape)
    self.assertAllClose(samples.mean(), jtu.to_default_dtype(mean), atol=0., rtol=0.1)
    self.assertAllClose(samples.std(), jtu.to_default_dtype(std), atol=0., rtol=0.1)

    self._CheckKolmogorovSmirnovCDF(
        samples, lambda x: double_sided_maxwell_cdf(x, loc, scale))
|
|
|
|
def testRadamacher(self):
  """rademacher() draws the two values ±1 with roughly equal frequency.

  NOTE(review): the method name misspells "Rademacher"; renaming would
  change test discovery, so it is left as-is.
  """
  rng = self.make_key(0)
  num_samples = 10**5

  rand = lambda x: random.rademacher(x, (num_samples,))
  crand = jax.jit(rand)

  uncompiled_samples = rand(rng)
  compiled_samples = crand(rng)

  for samples in [uncompiled_samples, compiled_samples]:
    unique_values, counts = np.unique(samples, return_counts=True)
    assert len(unique_values) == 2
    assert len(counts) == 2

    self.assertAllClose(
        counts[0] / num_samples, 0.5, rtol=1e-02, atol=1e-02)
    self.assertAllClose(
        counts[1] / num_samples, 0.5, rtol=1e-02, atol=1e-02)
|
|
|
|
def testChoiceShapeIsNotSequenceError(self):
  """choice() requires `shape` to be a sequence, not a bare int."""
  key = self.make_key(0)
  for replace in (False, True):
    with self.assertRaises(TypeError):
      random.choice(key, 5, 2, replace=replace)
|
|
|
|
def test_eval_shape_big_random_array(self):
  """eval_shape over a huge sample must not materialize the array."""
  def f(x):
    return random.normal(self.make_key(x), (int(1e12),))
  with jax.enable_checks(False):  # check_jaxpr will materialize array
    jax.eval_shape(f, 0)  # doesn't error
|
|
|
|
@jtu.sample_product(
|
|
type_=["int", "np.array", "jnp.array"],
|
|
seed=[-1, 0, 1, (1 << 32) - 1, (1 << 63) - 1, np.uint64((1 << 64) - 1)],
|
|
)
|
|
def test_prng_jit_invariance(self, seed, type_):
|
|
if type_ == "int" and seed == (1 << 64) - 1:
|
|
self.skipTest("Expected failure: Python int too large.")
|
|
if not config.x64_enabled and seed > np.iinfo(np.int32).max:
|
|
self.skipTest("Expected failure: Python int too large.")
|
|
type_ = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type_]
|
|
args_maker = lambda: [type_(seed)]
|
|
f = lambda s: _maybe_unwrap(self.make_key(s))
|
|
self._CompileAndCheck(f, args_maker)
|
|
|
|
def test_prng_errors(self):
|
|
seed = np.iinfo(np.int64).max + 1
|
|
with self.assertRaises(OverflowError):
|
|
self.make_key(seed)
|
|
with self.assertRaises(OverflowError):
|
|
jax.jit(self.make_key)(seed)
|
|
|
|
def test_random_split_doesnt_device_put_during_tracing(self):
|
|
key = self.make_key(1).block_until_ready()
|
|
with jtu.count_device_put() as count:
|
|
jax.jit(random.split)(key)
|
|
self.assertLessEqual(count[0], 1) # 1 for the argument device_put
|
|
|
|
@jtu.sample_product(dtype=int_dtypes + uint_dtypes)
|
|
def test_randint_bounds(self, dtype):
|
|
min = np.iinfo(dtype).min
|
|
max = np.iinfo(dtype).max
|
|
key = self.make_key(1701)
|
|
shape = (10,)
|
|
if np.iinfo(dtype).bits < np.iinfo(dtypes.canonicalize_dtype(int)).bits:
|
|
expected = random.randint(key, shape, min, max, dtype)
|
|
self.assertArraysEqual(expected, random.randint(key, shape, min - 12345, max + 12345, dtype))
|
|
else:
|
|
self.assertRaises(OverflowError, random.randint, key, shape, min - 12345, max + 12345, dtype)
|
|
|
|
def test_randint_out_of_range(self):
|
|
key = self.make_key(0)
|
|
|
|
r = random.randint(key, (10,), 255, 256, np.uint8)
|
|
self.assertAllClose(r, jnp.full_like(r, 255))
|
|
|
|
r = random.randint(key, (1000,), -128, 128, np.int8)
|
|
self.assertGreater((r == -128).sum(), 0)
|
|
self.assertGreater((r == 127).sum(), 0)
|
|
|
|
r = random.randint(key, (1000,), -1000, 1000, np.uint8)
|
|
self.assertGreater((r == 0).sum(), 0)
|
|
self.assertGreater((r == 255).sum(), 0)
|
|
|
|
def test_large_prng(self):
|
|
# https://github.com/google/jax/issues/11010
|
|
def f():
|
|
return random.uniform(
|
|
self.make_key(3), (308000000, 128), dtype=jnp.bfloat16)
|
|
|
|
# just lower, don't run, takes too long
|
|
jax.jit(f).lower()
|
|
|
|
@jtu.sample_product(shape=[(3, 4)],
|
|
logits_shape_base=[(3, 4), (3, 1), (1, 4)],
|
|
axis=[-3, -2, -1, 0, 1, 2])
|
|
def test_categorical_shape_argument(self, shape, logits_shape_base, axis):
|
|
# https://github.com/google/jax/issues/13124
|
|
logits_shape = list(logits_shape_base)
|
|
logits_shape.insert(axis % (len(logits_shape_base) + 1), 10)
|
|
assert logits_shape[axis] == 10
|
|
logits = jnp.ones(logits_shape)
|
|
samples = random.categorical(self.make_key(0), logits=logits,
|
|
axis=axis, shape=shape)
|
|
self.assertEqual(samples.shape, shape)
|
|
|
|
@jtu.sample_product(
|
|
df = [0.2, 1., 10., 100.],
|
|
dtype=jtu.dtypes.floating)
|
|
def testChisquare(self, df, dtype):
|
|
key = self.make_key(1)
|
|
|
|
def rand(key, df):
|
|
return random.chisquare(key, df, shape=(10000,), dtype=dtype)
|
|
crand = jax.jit(rand)
|
|
|
|
uncompiled_samples = rand(key, df)
|
|
compiled_samples = crand(key, df)
|
|
|
|
for samples in [uncompiled_samples, compiled_samples]:
|
|
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.chi2(df).cdf)
|
|
|
|
@jtu.sample_product(
|
|
dfnum = [1., 2., 10. ,100.],
|
|
dfden = [1. ,2., 10., 100.],
|
|
dtype=jtu.dtypes.floating)
|
|
def testF(self, dfnum, dfden, dtype):
|
|
key = self.make_key(1)
|
|
rand = lambda key: random.f(key, dfnum, dfden, shape = (10000, ), dtype = dtype)
|
|
crand = jax.jit(rand)
|
|
|
|
uncompiled_samples = rand(key)
|
|
compiled_samples = crand(key)
|
|
|
|
for samples in [uncompiled_samples, compiled_samples]:
|
|
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.f(dfnum, dfden).cdf)
|
|
|
|
@jtu.sample_product(
|
|
scale= [0.2, 1., 2., 10. ,100.],
|
|
dtype=jtu.dtypes.floating)
|
|
def testRayleigh(self, scale, dtype):
|
|
key = self.make_key(0)
|
|
rand = lambda key: random.rayleigh(key, scale, shape = (10000, ), dtype = dtype)
|
|
crand = jax.jit(rand)
|
|
|
|
uncompiled_samples = rand(key)
|
|
compiled_samples = crand(key)
|
|
|
|
for samples in [uncompiled_samples, compiled_samples]:
|
|
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.rayleigh(scale=scale).cdf)
|
|
|
|
@jtu.sample_product(
|
|
mean= [0.2, 1., 2., 10. ,100.],
|
|
dtype=jtu.dtypes.floating)
|
|
def testWald(self, mean, dtype):
|
|
key = self.make_key(0)
|
|
rand = lambda key: random.wald(key, mean, shape=(10000, ), dtype=dtype)
|
|
crand = jax.jit(rand)
|
|
|
|
uncompiled_samples = rand(key)
|
|
compiled_samples = crand(key)
|
|
|
|
for samples in [uncompiled_samples, compiled_samples]:
|
|
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.invgauss(mu=mean).cdf)
|
|
|
|
@jtu.sample_product(
|
|
p=[0.2, 0.3, 0.4, 0.5 ,0.6],
|
|
dtype=jtu.dtypes.supported([np.int16, np.int32, np.int64]))
|
|
def testGeometric(self, p, dtype):
|
|
key = self.make_key(1)
|
|
rand = lambda key: random.geometric(key, p, shape=(10000, ), dtype=dtype)
|
|
crand = jax.jit(rand)
|
|
|
|
uncompiled_samples = rand(key)
|
|
compiled_samples = crand(key)
|
|
|
|
for samples in [uncompiled_samples, compiled_samples]:
|
|
self._CheckChiSquared(samples, scipy.stats.geom(p).pmf)
|
|
self.assertAllClose(samples.mean(), 1 / p, rtol=0.02, check_dtypes=False)
|
|
self.assertAllClose(samples.var(), (1 - p) / (p * p) , rtol=0.05, check_dtypes=False)
|
|
|
|
@jtu.sample_product(
|
|
left = [0.2, 0.5, 1., 2.],
|
|
mode = [3., 5., 8., 9.],
|
|
right= [10., 20., 30., 40.],
|
|
dtype= jtu.dtypes.floating)
|
|
def testTriangular(self, left, mode, right, dtype):
|
|
key = self.make_key(1)
|
|
rand = lambda key: random.triangular(key, left, mode, right, shape=(10000, ), dtype=dtype)
|
|
crand = jax.jit(rand)
|
|
|
|
uncompiled_samples = rand(key)
|
|
compiled_samples = crand(key)
|
|
|
|
for samples in [uncompiled_samples, compiled_samples]:
|
|
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.triang((mode - left) / (right - left), loc=left, scale=right - left).cdf)
|
|
|
|
@jtu.sample_product(
|
|
sigma = [0.2, 0.5, 1., 2.],
|
|
dtype=jtu.dtypes.floating)
|
|
def testLogNormal(self, sigma, dtype):
|
|
key = self.make_key(0)
|
|
rand = lambda key: random.lognormal(key, sigma, shape=(10000, ), dtype=dtype)
|
|
crand = jax.jit(rand)
|
|
|
|
uncompiled_samples = rand(key)
|
|
compiled_samples = crand(key)
|
|
|
|
for samples in [uncompiled_samples, compiled_samples]:
|
|
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.lognorm(s=sigma).cdf)
|
|
|
|
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
|
def test_copy(self, make_key):
|
|
key = make_key(8459302)
|
|
self.assertArraysEqual(key, key.copy())
|
|
self.assertArraysEqual(key, copy.copy(key))
|
|
self.assertArraysEqual(key, copy.deepcopy(key))
|
|
self.assertArraysEqual(key, jax.jit(lambda k: k.copy())(key))
|
|
|
|
|
|
class KeyArrayTest(jtu.JaxTestCase):
  # Key arrays involve:
  # * a Python key array type, backed by an underlying uint32 "base" array,
  # * an abstract shaped array with key element type,
  # * primitives that return or operate on such shaped arrays,
  # * compiler lowerings,
  # * a device-side data representation...
  # Test it all!
  #
  # A handful of these tests follow CustomElementTypesTest in
  # lax_tests.py as an example. If you add a test here (e.g. testing
  # lowering of an key-dtyped shaped array), consider whether it
  # might also be a more general test of opaque element types. If
  # so, add a corresponding test to to CustomElementTypesTest as well.

  def test_construction(self):
    key = random.key(42)
    self.assertIsInstance(key, random.PRNGKeyArray)

  def test_issubdtype(self):
    key = random.key(42)
    self.assertTrue(jnp.issubdtype(key.dtype, dtypes.prng_key))
    self.assertFalse(jnp.issubdtype(key.dtype, np.integer))

  @skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
  def test_construction_upgrade_flag(self):
    # With the upgrade flag set, the legacy constructor also yields typed keys.
    key = random.PRNGKey(42)
    self.assertIsInstance(key, random.PRNGKeyArray)

  def make_keys(self, *shape, seed=28):
    # Helper: build a typed key array of the given shape via vmapped seeding.
    seeds = seed + jnp.arange(math.prod(shape), dtype=jnp.uint32)
    return jax.vmap(random.key)(seeds).reshape(shape)

  def test_key_as_seed(self):
    # A key is not a valid seed for either constructor.
    key = self.make_keys()
    with self.assertRaisesRegex(TypeError, "PRNGKey accepts a scalar seed"):
      random.PRNGKey(key)
    with self.assertRaisesRegex(TypeError, "key accepts a scalar seed"):
      random.key(key)

  def test_non_scalar_seed(self):
    seed_arr = np.arange(4)
    with self.assertRaisesRegex(TypeError, "PRNGKey accepts a scalar seed"):
      random.PRNGKey(seed_arr)
    with self.assertRaisesRegex(TypeError, "key accepts a scalar seed"):
      random.key(seed_arr)

  def test_non_integer_seed(self):
    seed = np.pi
    with self.assertRaisesRegex(TypeError, "PRNG key seed must be an integer"):
      random.PRNGKey(seed)
    with self.assertRaisesRegex(TypeError, "PRNG key seed must be an integer"):
      random.key(seed)

  def test_dtype_property(self):
    # Key dtype is preserved by split and observable inside jit tracing.
    k1, k2 = self.make_keys(), self.make_keys()
    self.assertEqual(k1.dtype, k2.dtype)

    k3, k4 = random.split(k1, 2)
    self.assertEqual(k1.dtype, k3.dtype)
    self.assertEqual(k3.dtype, k4.dtype)

    g = []
    def f(k):
      g.append(k.dtype)
      return random.split(k)
    _ = jax.jit(f)(k1)
    self.assertEqual(g[0], k1.dtype)
    self.assertEqual(g[0], k2.dtype)

  def test_key_dtype_attributes(self):
    # Total byte size of the key array matches its raw uint32 backing array.
    key = self.make_keys()
    key_raw = key.unsafe_raw_array()

    self.assertStartsWith(key.dtype.name, "key")
    self.assertEqual(key.size * key.dtype.itemsize,
                     key_raw.size * key_raw.dtype.itemsize)

  def test_isinstance(self):
    @jax.jit
    def f(k):
      self.assertIsInstance(k, random.KeyArray)
      return k

    k1 = self.make_keys()
    k2 = f(k1)
    self.assertIsInstance(k1, random.KeyArray)
    self.assertIsInstance(k2, random.KeyArray)

  def test_cpp_dispatch_normal(self):
    # Ensure we stay on the C++ dispatch path when calling a jitted
    # function with a key array as an argument.

    @jax.jit
    def f(key):
      return jax.random.normal(key)

    key = self.make_keys()
    with jtu.count_pjit_cpp_cache_miss() as count:
      f(key).block_until_ready()
      f(key).block_until_ready()

    # Exactly one cache miss: the second call must hit the C++ fast path.
    self.assertEqual(count[0], 1)

  def test_cpp_dispatch_split(self):
    # Ensure we stay on the C++ dispatch path when calling a jitted
    # function with a key arrays as inputs and as outputs.

    @jax.jit
    def f(key):
      return jax.random.split(key)

    key = self.make_keys()
    with jtu.count_pjit_cpp_cache_miss() as count:
      f(key).block_until_ready()
      f(key).block_until_ready()

    self.assertEqual(count[0], 1)

  def test_cpp_dispatch_aot_normal(self):
    # Ensure we stay on the C++ dispatch path when calling an
    # AOT-compiled function with a key array as an argument.

    key = self.make_keys()
    f = jax.jit(lambda key: jax.random.normal(key)).lower(key).compile()

    with jtu.count_aot_jit_cpp_cache_miss() as count:
      f(key).block_until_ready()
      f(key).block_until_ready()

    self.assertEqual(count[0], 1)

  def test_cpp_dispatch_aot_split(self):
    # Ensure we stay on the C++ dispatch path when calling an
    # AOT-compiled function with a key arrays as inputs and as
    # outputs.

    key = self.make_keys()
    f = jax.jit(lambda key: jax.random.split(key)).lower(key).compile()

    with jtu.count_aot_jit_cpp_cache_miss() as count:
      f(key).block_until_ready()
      f(key).block_until_ready()

    self.assertEqual(count[0], 1)

  # -- prng primitives

  def test_random_wrap_vmap(self):
    # Wrapping a batch of raw uint32 arrays yields a batch of keys,
    # regardless of which axis is mapped.
    f = partial(prng_internal.random_wrap, impl=prng_internal.threefry_prng_impl)
    base_arr = jnp.arange(6, dtype=jnp.uint32).reshape(3, 2)
    keys = jax.vmap(f, in_axes=0)(base_arr)
    self.assertIsInstance(keys, random.KeyArray)
    self.assertEqual(keys.shape, (3,))
    keys = jax.vmap(f, in_axes=1)(base_arr.T)
    self.assertIsInstance(keys, random.KeyArray)
    self.assertEqual(keys.shape, (3,))

  @jtu.sample_product(use_internal=[False, True])
  def test_random_unwrap(self, use_internal):
    # Both the internal primitive and the public key_data API unwrap keys
    # to uint32; only the internal primitive rejects non-key operands.
    unwrap = prng_internal.random_unwrap if use_internal else random.key_data
    def f(k): return unwrap(k)
    k = self.make_keys(3, 4)
    out = f(k)
    self.assertEqual(out.dtype, np.dtype('uint32'))
    self.assertEqual(out.shape[:2], (3, 4))
    out = jax.jit(f)(k)
    self.assertEqual(out.dtype, np.dtype('uint32'))
    self.assertEqual(out.shape[:2], (3, 4))
    out = jax.vmap(f)(k)
    self.assertEqual(out.dtype, np.dtype('uint32'))
    self.assertEqual(out.shape[:2], (3, 4))
    out = jax.vmap(jax.jit(f))(k)
    self.assertEqual(out.dtype, np.dtype('uint32'))
    self.assertEqual(out.shape[:2], (3, 4))

    if not use_internal:
      return

    x = jnp.arange(12, dtype=np.dtype('uint32')).reshape(3, 4)
    self.assertRaisesRegex(
        TypeError, 'random_unwrap takes key array operand, got .*',
        lambda: f(x))
    self.assertRaisesRegex(
        TypeError, 'random_unwrap takes key array operand, got .*',
        lambda: jax.jit(f)(x))
    self.assertRaisesRegex(
        TypeError, 'random_unwrap takes key array operand, got .*',
        lambda: jax.vmap(f)(x))

  def test_eval_shape_keys_in(self):
    def f(key):
      return prng_internal.random_bits(key, bit_width=32, shape=(5,))
    out = jax.eval_shape(f, self.make_keys())
    self.assertEqual(out.shape, (5,))
    self.assertEqual(out.dtype, np.dtype('uint32'))

    def f(key):
      return prng_internal.random_bits(key, bit_width=16, shape=(5,))
    out = jax.eval_shape(f, self.make_keys())
    self.assertEqual(out.shape, (5,))
    self.assertEqual(out.dtype, np.dtype('uint16'))

  def test_eval_shape_keys_out(self):
    def f(seed):
      return self.make_keys(seed=seed)
    out = jax.eval_shape(f, 28)
    self.assertEqual(out.shape, ())
    # TODO(frostig): check dtype too when available

  def test_eval_shape_keys_in_out(self):
    def f(key):
      return random.split(key)
    out = jax.eval_shape(f, self.make_keys())
    self.assertEqual(out.shape, (2,))
    # TODO(frostig): check dtype too when available

  def test_vmap(self):
    ks = self.make_keys(3, 4, 5)
    ys = jax.vmap(jax.jit(lambda k: k.T))(ks)
    self.assertEqual(ys.shape, (3, 5, 4))

  # -- dtype-polymorphic operation (esp. lowerings)

  def test_scan_jaxpr(self):
    # Key-dtyped avals should appear in the jaxpr of a scanned computation.
    ks = self.make_keys(3, 4, 5)
    f = lambda ks: jax.lax.scan(lambda _, k: (None, k.T), None, ks)
    jaxpr = jax.make_jaxpr(f)(ks).jaxpr
    # { lambda ; a:key<fry>[3,4,5]. let
    #     b:key<fry>[3,5,4] = scan[
    #       jaxpr={ lambda ; c:key<fry>[4,5]. let
    #           d:key<fry>[5,4] = transpose[permutation=(1, 0)] c
    #         in (d,) }
    #     ] a
    #   in (b,) }
    self.assertLen(jaxpr.invars, 1)
    a, = jaxpr.invars
    self.assertIsInstance(a.aval, core.ShapedArray)
    self.assertEqual(a.aval.shape, (3, 4, 5))
    self.assertIs(type(a.aval.dtype), prng_internal.KeyTy)
    self.assertLen(jaxpr.eqns, 1)
    e, = jaxpr.eqns
    self.assertLen(e.outvars, 1)
    b, = e.outvars
    self.assertIsInstance(b.aval, core.ShapedArray)
    self.assertEqual(b.aval.shape, (3, 5, 4))
    self.assertIs(type(b.aval.dtype), prng_internal.KeyTy)

  def test_scan_lowering(self):
    ks = self.make_keys(3, 4)
    f = lambda ks: jax.lax.scan(lambda _, k: (None, k.T), None, ks)
    _, out = jax.jit(f)(ks)  # doesn't crash
    self.assertIsInstance(out, random.KeyArray)
    self.assertEqual(out.shape, (3, 4))

  def test_slice(self):
    ks = self.make_keys(3, 4)
    ys = jax.jit(lambda x: lax.slice_in_dim(x, 1, 3))(ks)
    self.assertIsInstance(ys, random.KeyArray)
    self.assertEqual(ys.shape, (2, 4))

  def test_dynamic_slice(self):
    ks = self.make_keys(3, 4)
    index = np.int16(1)  # non-default int type to catch type errors.
    ys = jax.jit(partial(lax.dynamic_slice_in_dim, slice_size=2))(ks, index)
    self.assertIsInstance(ys, random.KeyArray)
    self.assertEqual(ys.shape, (2, 4))

  def test_dynamic_update_slice(self):
    ks = self.make_keys(3, 4)
    k = self.make_keys(1, 4)
    index = np.int16(1)  # non-default int type to catch type errors.
    ys = jax.jit(partial(lax.dynamic_update_slice_in_dim, axis=0))(ks, k, index)
    self.assertIsInstance(ys, random.KeyArray)
    self.assertEqual(ys.shape, (3, 4))

  def test_transpose(self):
    ks = self.make_keys(3, 4)
    ys = jax.jit(lambda x: x.T)(ks)
    self.assertIsInstance(ys, random.KeyArray)
    self.assertEqual(ys.shape, (4, 3))

  def test_gather(self):
    # Indexing a key array lowers to gather and must preserve key-ness.
    ks = self.make_keys(3, 4)
    ys = jax.jit(lambda x: x[1])(ks)
    self.assertIsInstance(ys, random.KeyArray)
    self.assertEqual(ys.shape, (4,))

    ks = self.make_keys(3, 4, 5)

    ys = jax.jit(lambda x: x[1])(ks)
    self.assertIsInstance(ys, random.KeyArray)
    self.assertEqual(ys.shape, (4, 5))

    ys = jax.jit(lambda x: x[1, 2:4])(ks)
    self.assertIsInstance(ys, random.KeyArray)
    self.assertEqual(ys.shape, (2, 5))

    ys = jax.jit(lambda x: x[1, 2:4, 3])(ks)
    self.assertIsInstance(ys, random.KeyArray)
    self.assertEqual(ys.shape, (2,))

    ys = jax.jit(lambda x: x[:, 2:4, 3:4])(ks)
    self.assertIsInstance(ys, random.KeyArray)
    self.assertEqual(ys.shape, (3, 2, 1))

  def test_select(self):
    ks = self.make_keys(3, 2)
    cs = jnp.array([True, False, False, True, False, True]).reshape(3, 2)
    ys = jax.jit(lax.select)(cs, ks, ks)
    self.assertIsInstance(ys, random.KeyArray)
    self.assertEqual(ys.shape, (3, 2))

  def test_select_scalar_cond(self):
    # regression test for https://github.com/google/jax/issues/16422
    ks = self.make_keys(3)
    ys = lax.select(True, ks, ks)
    self.assertIsInstance(ys, random.KeyArray)
    self.assertEqual(ys.shape, (3,))

  def test_vmap_of_cond(self):
    # See https://github.com/google/jax/issues/15869
    def f(x):
      keys = self.make_keys(*x.shape)
      return lax.select(x, keys, keys)
    x = jnp.array([True, False, False])
    f(x)  # doesn't crash

  def test_device_put(self):
    device = jax.devices()[0]
    keys = self.make_keys(4)
    keys_on_device = jax.device_put(keys, device)
    self.assertArraysEqual(keys, keys_on_device)

  def test_device_put_sharded(self):
    devices = jax.devices()
    keys = self.make_keys(len(devices))
    keys_on_device = jax.device_put_sharded(list(keys), devices)
    self.assertArraysEqual(keys, keys_on_device)

  def test_device_put_replicated(self):
    devices = jax.devices()
    key = self.make_keys()
    keys_on_device = jax.device_put_replicated(key, devices)
    self.assertArraysEqual(jnp.broadcast_to(key, keys_on_device.shape), keys_on_device)

  def test_make_array_from_callback(self):
    # Key arrays should be constructible per-shard via the callback API.
    devices = jax.devices()
    shape = (len(devices),)
    mesh = jtu.create_global_mesh((len(devices),), ('x',))
    sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
    def callback(index):
      i = jnp.arange(len(devices))[index[0]]
      return jax.vmap(random.key)(i)
    result = jax.make_array_from_callback(shape, sharding, callback)
    expected = jax.vmap(random.key)(jnp.arange(len(devices)))
    self.assertArraysEqual(result, expected)

  def test_make_array_from_single_device_arrays(self):
    devices = jax.devices()
    shape = (len(devices),)
    mesh = jtu.create_global_mesh((len(devices),), ('x',))
    sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
    keys = random.split(random.key(0), len(devices))
    arrays = [jax.device_put(keys[i:i + 1], device) for i, device in enumerate(devices)]
    result = jax.make_array_from_single_device_arrays(shape, sharding, arrays)
    self.assertArraysEqual(result, keys)

  def test_key_array_custom_jvp(self):
    # A custom_jvp over a key argument should receive a (zero) key tangent.
    def f_raw(x, key):
        return x * random.normal(key, ())

    f = jax.custom_jvp(f_raw)

    @f.defjvp
    def f_jvp(primals, tangents):
      nonlocal key_dot
      x, key = primals
      x_dot, key_dot = tangents
      rand = random.normal(key, ())
      tangent_out = x_dot * rand
      primal_out = x * rand
      return primal_out, tangent_out

    key_dot = None
    key = self.make_keys()
    default_result = jax.grad(f_raw)(0.0, key)
    custom_result = jax.grad(f)(0.0, key)

    self.assertAllClose(default_result, custom_result)
    self.assertIsInstance(key_dot, random.PRNGKeyArray)
    self.assertArraysEqual(random.key_data(key_dot), np.uint32(0))

  def test_key_array_indexing_0d(self):
    key = self.make_keys()
    self.assertEqual(key.shape, ())
    self.assertEqual(key[None].shape, (1,))
    self.assertRaisesRegex(IndexError, 'Too many indices.*', lambda: key[0])

  def test_key_array_indexing_nd(self):
    keys = self.make_keys(2, 3)
    self.assertEqual(keys.shape, (2, 3))
    self.assertEqual(keys[0, 0].shape, ())
    self.assertEqual(keys[0, 1].shape, ())
    self.assertEqual(keys[0].shape, (3,))
    self.assertEqual(keys[1, :].shape, (3,))
    self.assertEqual(keys[:, 1].shape, (2,))
    self.assertEqual(keys[None].shape, (1, 2, 3))
    self.assertEqual(keys[None, None].shape, (1, 1, 2, 3))
    self.assertEqual(keys[None, :, None].shape, (1, 2, 1, 3))
    self.assertEqual(keys[None, None, None, 0, None, None, None, 1].shape,
                     (1,) * 6)
    self.assertEqual(keys[..., 1:, None].shape, (2, 2, 1))
    self.assertEqual(keys[None, 0, ..., 1, None].shape, (1, 1))
    self.assertRaisesRegex(IndexError, 'Too many indices.*',
                           lambda: keys[0, 1, 2])
    self.assertRaisesRegex(IndexError, 'Too many indices.*',
                           lambda: keys[0, 1, None, 2])

  def test_not_hashable(self):
    key = self.make_keys()
    with self.assertRaisesRegex(TypeError, "unhashable type"):
      hash(key)

  def test_array_impl_attributes(self):
    # Test a number of ArrayImpl attributes
    key = self.make_keys(10)

    self.assertEqual(key.is_fully_addressable, key._base_array.is_fully_addressable)
    self.assertEqual(key.is_fully_replicated, key._base_array.is_fully_replicated)
    self.assertEqual(key.device(), key._base_array.device())
    self.assertEqual(key.devices(), key._base_array.devices())
    self.assertEqual(key.on_device_size_in_bytes, key._base_array.on_device_size_in_bytes)
    self.assertEqual(key.unsafe_buffer_pointer, key._base_array.unsafe_buffer_pointer)
    self.assertArraysEqual(key.addressable_data(0)._base_array,
                           key._base_array.addressable_data(0))
    self.assertLen(key.addressable_shards, len(key._base_array.addressable_shards))
    self.assertLen(key.global_shards, len(key._base_array.global_shards))

  def test_delete(self):
    key = self.make_keys(10)

    self.assertFalse(key.is_deleted())
    key.delete()
    self.assertTrue(key.is_deleted())
    self.assertTrue(key._base_array.is_deleted())

  def test_async(self):
    key = self.make_keys(10)

    self.assertArraysEqual(key, key.block_until_ready())
    self.assertIsNone(key.copy_to_host_async())

  # TODO(frostig,mattjj): more polymorphic primitives tests
|
|
|
|
# Short aliases for the stock threefry PRNG callables that the "doubled"
# implementation below builds on.
threefry_seed = prng_internal.threefry_seed
threefry_split = prng_internal.threefry_split
threefry_random_bits = prng_internal.threefry_random_bits
threefry_fold_in = prng_internal.threefry_fold_in


def _double_threefry_seed(seed):
  """Seed a doubled-threefry key: two stacked threefry keys, shape (2, 2)."""
  # Derive the second seed by XOR so the two constituent keys differ.
  if hasattr(seed, 'dtype'):
    cast = seed.dtype.type
  else:
    cast = type(seed)
  first, second = seed, seed ^ cast(3)
  return jnp.vstack([threefry_seed(first),
                     threefry_seed(second)])


def _double_threefry_split(key, shape):
  """Split each constituent key; the pair axis lands after the split axes."""
  splitter = vmap(threefry_split, (0, None), len(shape))
  return splitter(key, shape)


def _double_threefry_random_bits(key, bit_width, shape):
  """Generate bits from the first constituent key only."""
  bits_first = threefry_random_bits(key[0], bit_width, shape)
  bits_second = threefry_random_bits(key[1], bit_width, shape)
  del bits_second
  # TODO(frostig): Currently this behaves like normal threefry, to
  # avoid a few probabilistic test failures. Ideally we might want to
  # test different generation behavior here (e.g. `bits0 ^ bits1`).
  return bits_first


def _double_threefry_fold_in(key, data):
  """Fold `data` into both constituent keys independently."""
  folded = [threefry_fold_in(key[0], data),
            threefry_fold_in(key[1], data)]
  return jnp.vstack(folded)


# A toy custom PRNG implementation used to exercise the custom-PRNG plumbing:
# each key is a stack of two ordinary threefry keys.
double_threefry_prng_impl = prng_internal.PRNGImpl(
    key_shape=(2, 2),
    seed=_double_threefry_seed,
    split=_double_threefry_split,
    random_bits=_double_threefry_random_bits,
    fold_in=_double_threefry_fold_in,
    tag='fry2')
|
|
|
|
@jtu.with_config(jax_default_prng_impl='threefry2x32')
class LaxRandomWithCustomPRNGTest(LaxRandomTest):
  # Re-runs LaxRandomTest with keys backed by the toy doubled-threefry
  # implementation, plus shape checks specific to custom key shapes.

  def make_key(self, seed):
    return prng_internal.seed_with_impl(double_threefry_prng_impl, seed)

  def test_split_shape(self):
    # Typed keys are scalars, so splitting yields a flat vector of keys.
    key = self.make_key(73)
    keys = random.split(key, 10)
    self.assertEqual(keys.shape, (10,))

  def test_vmap_fold_in_shape(self):
    # broadcast with scalar
    keys = random.split(self.make_key(73), 2)
    msgs = jnp.arange(3)
    out = vmap(lambda i: random.fold_in(keys[0], i))(msgs)
    self.assertEqual(out.shape, (3,))
    out = vmap(lambda k: random.fold_in(k, msgs[0]))(keys)
    self.assertEqual(out.shape, (2,))
    out = vmap(random.fold_in, in_axes=(None, 0))(keys[0], msgs)
    self.assertEqual(out.shape, (3,))
    out = vmap(random.fold_in, in_axes=(0, None))(keys, msgs[0])
    self.assertEqual(out.shape, (2,))

    # vmap all
    msgs = jnp.arange(2)
    out = vmap(random.fold_in)(keys, msgs)
    self.assertEqual(out.shape, (2,))

    # nested vmap
    keys = random.split(self.make_key(73), 2 * 3).reshape((2, 3))
    msgs = jnp.arange(2 * 3).reshape((2, 3))
    out = vmap(vmap(random.fold_in), in_axes=(0, 1))(keys, msgs.T)
    self.assertEqual(out.shape, (2, 3))
    out = vmap(vmap(random.fold_in), in_axes=(1, 0))(keys, msgs.T)
    self.assertEqual(out.shape, (3, 2))

  def test_vmap_split_mapped_key(self):
    # vmapped split must agree elementwise with a Python-loop split.
    key = self.make_key(73)
    mapped_keys = random.split(key, num=3)
    forloop_keys = [random.split(k) for k in mapped_keys]
    vmapped_keys = vmap(random.split)(mapped_keys)
    self.assertEqual(vmapped_keys.shape, (3, 2))
    for fk, vk in zip(forloop_keys, vmapped_keys):
      self.assertArraysEqual(fk.unsafe_raw_array(),
                             vk.unsafe_raw_array())

  def test_cannot_add(self):
    # Arithmetic on typed keys is rejected by type promotion.
    key = self.make_key(73)
    self.assertRaisesRegex(
        ValueError, r'dtype=key<.*> is not a valid dtype for JAX type promotion.',
        lambda: key + 47)

  def test_grad_of_prng_key(self):
    key = self.make_key(73)
    with self.assertRaisesRegex(TypeError, 'grad requires real- or complex-valued inputs'):
      jax.grad(lambda x: 1.)(key)
    out = jax.grad(lambda x: 1., allow_int=True)(key)
    self.assertArraysEqual(out, np.zeros(key.shape, jax.dtypes.float0))
|
|
|
|
|
|
@jtu.with_config(jax_default_prng_impl='rbg')
class LaxRandomWithRBGPRNGTest(LaxRandomTest):
  # Re-runs LaxRandomTest with RngBitGenerator-backed keys. Shape
  # expectations differ because rbg keys may be raw (non-scalar) arrays.

  def make_key(self, seed):
    return random.rbg_key(seed)

  def test_split_shape(self):
    # rbg keys may carry a trailing key shape, hence `*key.shape`.
    key = self.make_key(73)
    keys = random.split(key, 10)
    self.assertEqual(keys.shape, (10, *key.shape))

  def test_vmap_fold_in_shape(self):
    # broadcast with scalar
    keys = random.split(self.make_key(73), 2)
    msgs = jnp.arange(3)

    out = vmap(lambda i: random.fold_in(keys[0], i))(msgs)
    self.assertEqual(out.shape, (3, *keys[0].shape))
    out = vmap(random.fold_in, in_axes=(None, 0))(keys[0], msgs)
    self.assertEqual(out.shape, (3, *keys[0].shape))

    out = vmap(lambda k: random.fold_in(k, msgs[0]))(keys)
    self.assertEqual(out.shape, keys.shape)
    out = vmap(random.fold_in, in_axes=(0, None))(keys, msgs[0])
    self.assertEqual(out.shape, keys.shape)

  def test_vmap_split_not_mapped_key(self):
    # Splitting an unmapped key under vmap yields identical splits per lane.
    key = self.make_key(73)
    single_split_key = random.split(key)
    vmapped_keys = vmap(lambda _: random.split(key))(jnp.zeros(3,))
    self.assertEqual(vmapped_keys.shape, (3, 2, *key.shape))
    for vk in vmapped_keys:
      self.assertArraysEqual(_prng_key_as_array(vk),
                             _prng_key_as_array(single_split_key))

  def test_vmap_split_mapped_key(self):
    # vmapped split must agree elementwise with a Python-loop split.
    key = self.make_key(73)
    mapped_keys = random.split(key, num=3)
    forloop_keys = [random.split(k) for k in mapped_keys]
    vmapped_keys = vmap(random.split)(mapped_keys)
    self.assertEqual(vmapped_keys.shape, (3, 2, *key.shape))
    for fk, vk in zip(forloop_keys, vmapped_keys):
      self.assertArraysEqual(_prng_key_as_array(fk),
                             _prng_key_as_array(vk))

  def test_vmap_random_bits(self):
    # Bit generation under vmap must match per-key sequential generation.
    rand_fun = lambda key: random.randint(key, (), 0, 100)
    key = self.make_key(73)
    mapped_keys = random.split(key, num=3)
    forloop_rand_nums = [rand_fun(k) for k in mapped_keys]
    rand_nums = vmap(rand_fun)(mapped_keys)
    self.assertEqual(rand_nums.shape, (3,))
    self.assertArraysEqual(rand_nums, jnp.array(forloop_rand_nums))

  def test_cannot_add(self):
    key = self.make_key(73)
    if not jnp.issubdtype(key.dtype, dtypes.prng_key):
      raise SkipTest('relies on typed key arrays')
    self.assertRaisesRegex(
        ValueError, r'dtype=key<.*> is not a valid dtype for JAX type promotion.',
        lambda: key + 47)

  def test_grad_of_prng_key(self):
    key = self.make_key(73)
    with self.assertRaisesRegex(TypeError, 'grad requires real- or complex-valued inputs'):
      jax.grad(lambda x: 1.)(key)
    out = jax.grad(lambda x: 1., allow_int=True)(key)
    self.assertArraysEqual(out, np.zeros(key.shape, jax.dtypes.float0))

  def test_random_split_doesnt_device_put_during_tracing(self):
    return  # this test doesn't apply to the RBG PRNG

  def test_randint_out_of_range(self):
    # TODO(mattjj): enable this test if/when RngBitGenerator supports it
    raise SkipTest('8-bit types not supported with RBG PRNG')
|
|
|
|
|
|
@jtu.with_config(jax_default_prng_impl='unsafe_rbg')
class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
  # Same suite as the RBG tests, but keys use the 'unsafe_rbg'
  # implementation; only key construction differs.

  def make_key(self, seed):
    return random.unsafe_rbg_key(seed)
|
|
|
|
|
|
def _sampler_unimplemented_with_custom_prng(*args, **kwargs):
  """Stand-in test body that unconditionally skips.

  Patched over sampler tests that only work with the default RNG
  implementation (see the loop below that installs it on subclasses).
  """
  del args, kwargs  # accepted for signature compatibility; unused
  raise SkipTest('sampler only implemented for default RNG')
|
|
|
|
# Poisson samplers are only implemented for the default (threefry) RNG, so
# replace all Poisson tests on the custom/RBG subclasses with a skip stub.
for test_prefix in [
  'testPoisson',
  'testPoissonBatched',
  'testPoissonShape',
  'testPoissonZeros',
]:
  for attr in dir(LaxRandomTest):
    if attr.startswith(test_prefix):
      setattr(LaxRandomWithCustomPRNGTest, attr,
              _sampler_unimplemented_with_custom_prng)
      setattr(LaxRandomWithRBGPRNGTest, attr,
              _sampler_unimplemented_with_custom_prng)
      setattr(LaxRandomWithUnsafeRBGPRNGTest, attr,
              _sampler_unimplemented_with_custom_prng)
|
|
|
|
|
|
class JnpWithKeyArrayTest(jtu.JaxTestCase):
  # Checks that jnp array-manipulation functions work on typed key arrays,
  # comparing shapes against same-shaped float arrays and values against
  # the same operation applied to the raw uint32 backing arrays.

  def check_shape(self, func, *args):
    # `func` applied to keys must give the same shape as applied to
    # ordinary arrays of matching shape.
    like = lambda keys: jnp.ones(keys.shape)
    out_key = func(*args)
    self.assertIsInstance(out_key, random.KeyArray)
    out_like_key = func(*tree_util.tree_map(like, args))
    self.assertIsInstance(out_like_key, jax.Array)
    self.assertEqual(out_key.shape, out_like_key.shape)

  def check_against_reference(self, key_func, arr_func, *key_args):
    # `key_func` on keys must match `arr_func` on the raw uint32 data,
    # both eagerly and under jit.
    out_arr = arr_func(*tree_util.tree_map(lambda x: x.unsafe_raw_array(), key_args))
    self.assertIsInstance(out_arr, jax.Array)

    out_key = key_func(*key_args)
    self.assertIsInstance(out_key, random.KeyArray)
    self.assertArraysEqual(out_key.unsafe_raw_array(), out_arr)

    out_key = jax.jit(key_func)(*key_args)
    self.assertIsInstance(out_key, random.KeyArray)
    self.assertArraysEqual(out_key.unsafe_raw_array(), out_arr)

  @parameterized.parameters([
    [(2, 3), 'shape', (2, 3)],
    [(2, 3), 'size', 6],
    [(2, 3), 'ndim', 2]
  ])
  def test_properties(self, shape, prop, expected):
    get_prop = lambda x: getattr(x, prop)
    key = random.split(random.key(0), math.prod(shape)).reshape(shape)
    self.assertEqual(get_prop(key), expected)
    self.assertEqual(jax.jit(get_prop)(key), expected)

  def test_reshape(self):
    key = random.key(123)
    keys = random.split(key, 4)

    newshape = (2, 2)
    key_func = partial(jnp.reshape, newshape=newshape)
    # Raw reference keeps the trailing per-key dimensions.
    arr_func = partial(jnp.reshape, newshape=(*newshape, *key.impl.key_shape))

    self.check_shape(key_func, keys)
    self.check_against_reference(key_func, arr_func, keys)

  def test_tile(self):
    key = random.key(123)

    reps = 3
    key_func = partial(jnp.tile, reps=reps)
    arr_func = lambda x: jnp.tile(x[None], reps=(reps, *(1 for _ in key.impl.key_shape)))

    self.check_shape(key_func, key)
    self.check_against_reference(key_func, arr_func, key)

  def test_concatenate(self):
    key = random.key(123)
    args = [random.split(k, 2) for k in random.split(key, 3)]

    key_func = arr_func = partial(jnp.concatenate, axis=0)

    self.check_shape(key_func, args)
    self.check_against_reference(key_func, arr_func, args)

  def test_broadcast_to(self):
    key = random.key(123)

    shape = (3,)
    key_func = partial(jnp.broadcast_to, shape=shape)
    arr_func = partial(jnp.broadcast_to, shape=(*shape, *key.impl.key_shape))

    self.check_shape(key_func, key)
    self.check_against_reference(key_func, arr_func, key)

  def test_expand_dims(self):
    key = random.key(123)
    keys = random.split(key, 6).reshape(2, 3)

    key_func = arr_func = partial(jnp.expand_dims, axis=1)

    self.check_shape(key_func, keys)
    self.check_against_reference(key_func, arr_func, keys)
|
|
|
|
def test_broadcast_arrays(self):
|
|
key = random.key(123)
|
|
keys = random.split(key, 3)
|
|
|
|
key_func = arr_func = lambda *args: jnp.broadcast_arrays(*args)[0]
|
|
|
|
self.check_shape(key_func, key, keys)
|
|
self.check_against_reference(key_func, arr_func, key, keys)
|
|
|
|
def test_append(self):
|
|
key = random.key(123)
|
|
keys = random.split(key, 4)
|
|
|
|
key_func = jnp.append
|
|
arr_func = lambda keys, key: jnp.append(keys, key[None], axis=0)
|
|
|
|
self.check_shape(key_func, keys, key)
|
|
self.check_against_reference(key_func, arr_func, keys, key)
|
|
|
|
def test_ravel(self):
|
|
key = random.key(123)
|
|
keys = random.split(key, 4).reshape(2, 2)
|
|
|
|
key_func = jnp.ravel
|
|
arr_func = partial(jnp.reshape, newshape=(4, *key.impl.key_shape))
|
|
|
|
self.check_shape(key_func, keys)
|
|
self.check_against_reference(key_func, arr_func, keys)
|
|
|
|
def test_stack(self):
|
|
key = random.key(123)
|
|
keys = random.split(key, 2)
|
|
|
|
key_func = arr_func = partial(jnp.stack, axis=0)
|
|
|
|
self.check_shape(key_func, keys)
|
|
self.check_against_reference(key_func, arr_func, keys)
|
|
|
|
def test_array(self):
|
|
key = random.key(123)
|
|
self.assertArraysEqual(key, jnp.array(key))
|
|
self.assertArraysEqual(key, jnp.asarray(key))
|
|
self.assertArraysEqual(key, jax.jit(jnp.array)(key))
|
|
self.assertArraysEqual(key, jax.jit(jnp.asarray)(key))
|
|
|
|
def test_array_user_dtype(self):
|
|
key = random.key(123)
|
|
self.assertArraysEqual(key, jnp.array(key, dtype=key.dtype))
|
|
self.assertArraysEqual(key, jnp.asarray(key, dtype=key.dtype))
|
|
|
|
@parameterized.parameters([
|
|
(0,),
|
|
(slice(1),),
|
|
(np.array([0, 2]),),
|
|
(np.array([False, True, True]),)
|
|
])
|
|
def test_getitem(self, idx):
|
|
key = random.key(123)
|
|
keys = random.split(key, 3)
|
|
|
|
key_func = arr_func = lambda x: x[idx]
|
|
|
|
self.check_shape(key_func, keys)
|
|
self.check_against_reference(key_func, arr_func, keys)
|
|
|
|
@parameterized.parameters([
|
|
(0,),
|
|
(slice(1),),
|
|
(np.array([0, 2]),),
|
|
(np.array([False, True, True]),)
|
|
])
|
|
def test_gather(self, idx):
|
|
key = random.key(123)
|
|
keys = random.split(key, 3)
|
|
|
|
key_func = arr_func = lambda x: x.at[idx].get()
|
|
|
|
self.check_shape(key_func, keys)
|
|
self.check_against_reference(key_func, arr_func, keys)
|
|
|
|
def test_equality(self):
|
|
key = random.key(123)
|
|
key2 = random.key(456)
|
|
|
|
self.assertTrue(key == key)
|
|
self.assertFalse(key == key2)
|
|
|
|
self.assertTrue(key != key2)
|
|
self.assertFalse(key != key)
|
|
|
|
size = 5
|
|
idx = slice(2, 4)
|
|
key_arr = random.split(key, size).at[idx].set(key)
|
|
expected = jnp.zeros(size, dtype=bool).at[idx].set(True)
|
|
|
|
self.assertArraysEqual(key == key_arr, expected)
|
|
self.assertArraysEqual(key != key_arr, ~expected)
|
|
|
|
@parameterized.parameters([
|
|
(0,),
|
|
(slice(1),),
|
|
(np.array([0, 2]),),
|
|
(np.array([False, True, True]),)
|
|
])
|
|
def test_scatter(self, idx):
|
|
key = random.key(123)
|
|
keys = random.split(key, 3)
|
|
|
|
key_func = arr_func = lambda x, y: x.at[idx].set(y)
|
|
|
|
self.check_shape(key_func, keys, key)
|
|
self.check_against_reference(key_func, arr_func, keys, key)
|
|
|
|
def test_errors(self):
|
|
key = random.key(123)
|
|
with self.assertRaisesRegex(ValueError, "dtype=key<fry> is not a valid dtype"):
|
|
jnp.add(key, 1)
|
|
with self.assertRaisesRegex(ValueError, "dtype=key<fry> is not a valid dtype"):
|
|
key + 1
|
|
with self.assertRaisesRegex(TypeError, "add does not accept dtype key<fry>"):
|
|
jnp.add(key, key)
|
|
with self.assertRaisesRegex(TypeError, "add does not accept dtype key<fry>"):
|
|
key + key
|
|
with self.assertRaisesRegex(TypeError, "neg does not accept dtype key<fry>"):
|
|
jnp.negative(key)
|
|
with self.assertRaisesRegex(TypeError, "neg does not accept dtype key<fry>"):
|
|
-key
|
|
with self.assertRaisesRegex(ValueError, "Cannot call convert_element_type on dtype key<fry>"):
|
|
lax.convert_element_type(key, int)
|
|
|
|
def test_eval_shape(self):
|
|
key = random.key(1701)
|
|
shapedtype = jax.ShapeDtypeStruct(key.shape, key.dtype)
|
|
out = jax.eval_shape(lambda x: x, shapedtype)
|
|
self.assertEqual(out, shapedtype)
|
|
|
|
def test_result_type(self):
|
|
key = random.key(123456)
|
|
self.assertEqual(jnp.result_type(key), key.dtype)
|
|
|
|
@parameterized.parameters([
|
|
(jnp.empty_like, ()),
|
|
(jnp.zeros_like, ()),
|
|
(jnp.ones_like, ()),
|
|
(jnp.full_like, (100,)),
|
|
])
|
|
def test_full_like(self, func, args):
|
|
keys = random.split(random.key(789543))
|
|
|
|
key_func = arr_func = lambda x: func(x, *args)
|
|
self.check_shape(key_func, keys)
|
|
self.check_against_reference(key_func, arr_func, keys)
|
|
|
|
def test_full_like_with_key_fillvalue(self):
|
|
keys = random.split(random.key(789543))
|
|
fill_value = random.key(42)
|
|
|
|
self.check_shape(jnp.full_like, keys, fill_value)
|
|
self.check_against_reference(jnp.full_like, jnp.full_like, keys, fill_value)
|
|
|
|
@parameterized.parameters([
|
|
(jnp.empty, {}),
|
|
(jnp.zeros, {}),
|
|
(jnp.ones, {}),
|
|
(jnp.full, {'fill_value': 100}),
|
|
])
|
|
def test_full(self, func, kwds):
|
|
keys = random.split(random.key(789543))
|
|
|
|
key_func = arr_func = lambda x: func(x.shape, dtype=x.dtype, **kwds)
|
|
self.check_shape(key_func, keys)
|
|
self.check_against_reference(key_func, arr_func, keys)
|
|
|
|
def test_full_with_key_fillvalue(self):
|
|
keys = random.split(random.key(789543))
|
|
fill_value = random.key(42)
|
|
func = lambda x, val: jnp.full(x.shape, val, dtype=x.dtype)
|
|
|
|
self.check_shape(func, keys, fill_value)
|
|
self.check_against_reference(func, func, keys, fill_value)
|
|
|
|
|
|
if __name__ == "__main__":
  # Run the suite with JAX's custom test loader.
  absltest.main(testLoader=jtu.JaxTestLoader())