Merge pull request #16605 from froystig:random-test-with-new-and-old-rng

PiperOrigin-RevId: 544815010
This commit is contained in:
jax authors 2023-06-30 21:19:00 -07:00
commit 404e3061b6

View File

@ -54,13 +54,16 @@ uint_dtypes = jtu.dtypes.all_unsigned
def _prng_key_as_array(key):
# TODO(frostig): remove once we upgrade to always enable_custom_prng
return key.unsafe_raw_array() if config.jax_enable_custom_prng else key
# TODO(frostig): remove some day when we deprecate "raw" key arrays
if isinstance(key, jax.random.PRNGKeyArray):
return key.unsafe_raw_array()
else:
return key
def _maybe_unwrap(key):
# TODO(frostig): remove once we upgrade to always enable_custom_prng
# TODO(frostig): remove some day when we deprecate "raw" key arrays
unwrap = prng_internal.random_unwrap
return unwrap(key) if config.jax_enable_custom_prng else key
return unwrap(key) if isinstance(key, jax.random.PRNGKeyArray) else key
PRNG_IMPLS = [('threefry2x32', prng.threefry_prng_impl),
@ -199,8 +202,20 @@ _RANDOM_VALUES_CASES = [
]
KEY_CTORS = [random.key, random.PRNGKey]
class PrngTest(jtu.JaxTestCase):
def check_key_has_impl(self, key, impl):
if isinstance(key, random.PRNGKeyArray):
self.assertIs(key.impl, impl)
else:
self.assertEqual(key.dtype, jnp.dtype('uint32'))
self.assertEqual(key.shape, impl.key_shape)
def raw_key(self, *args, **kwargs):
return _prng_key_as_array(random.key(*args, **kwargs))
def testThreefry2x32(self):
# We test the hash by comparing to known values provided in the test code of
# the original reference implementation of Threefry. For the values, see
@ -252,15 +267,23 @@ class PrngTest(jtu.JaxTestCase):
xla.apply_primitive = apply_primitive
@skipIf(config.jax_threefry_partitionable, 'changed random bit values')
def testRngRandomBits(self):
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def testRngRandomBits(self, make_key):
# Test specific outputs to ensure consistent random values between JAX versions.
# TODO(frostig): remove once we always enable_custom_prng
def random_bits(key, *args):
def random_bits(key, width, shape):
# TODO(frostig): Use random.bits, as in:
#
# def random_bits(key, width, shape):
# dtype = jnp.dtype(f'uint{width}')
# return jax.random.bits(key, shape, dtype)
#
# Doing so doesn't work in width 64 at present due to
# normalization in random.bits.
key, _ = jax_random._check_prng_key(key)
return jax_random._random_bits(key, *args)
return jax_random._random_bits(key, width, shape)
key = random.PRNGKey(1701)
key = make_key(1701)
bits8 = random_bits(key, 8, (3,))
expected8 = np.array([216, 115, 43], dtype=np.uint8)
@ -283,19 +306,19 @@ class PrngTest(jtu.JaxTestCase):
expected64 = np.array([676898860, 3164047411, 4010691890], dtype=np.uint32)
self.assertArraysEqual(bits64, expected64)
@jtu.sample_product(prng_name=[name for name, _ in PRNG_IMPLS])
def testRngRandomBitsShapeDtype(self, prng_name):
@jtu.sample_product(prng_name=[name for name, _ in PRNG_IMPLS],
make_key=KEY_CTORS)
def testRngRandomBitsShapeDtype(self, prng_name, make_key):
# Like testRngRandomBits, but only meant to exercise random_bits
# on every PRNG implementation. Instead of values, only checks
# that shapes/dtypes are as expected.
# TODO(frostig): remove once we always enable_custom_prng
def random_bits(key, *args):
key, _ = jax_random._check_prng_key(key)
return jax_random._random_bits(key, *args)
def random_bits(key, width, shape):
dtype = jnp.dtype(f'uint{width}')
return jax.random.bits(key, shape, dtype)
with jax.default_prng_impl(prng_name):
key = random.PRNGKey(1701)
key = make_key(1701)
bits8 = random_bits(key, 8, (3,))
self.assertEqual(bits8.shape, (3,))
@ -316,27 +339,27 @@ class PrngTest(jtu.JaxTestCase):
self.assertEqual(bits64.dtype, expected_dtype)
@skipIf(config.jax_threefry_partitionable, 'changed random bit values')
def testRngRandomBitsViewProperty(self):
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def testRngRandomBitsViewProperty(self, make_key):
# TODO: add 64-bit if it ever supports this property.
# TODO: will this property hold across endian-ness?
# TODO(frostig): remove once we always enable_custom_prng
def random_bits(key, *args):
key, _ = jax_random._check_prng_key(key)
return jax_random._random_bits(key, *args)
def random_bits(key, width, shape):
dtype = jnp.dtype(f'uint{width}')
return jax.random.bits(key, shape, dtype)
N = 10
key = random.PRNGKey(1701)
key = make_key(1701)
nbits = [8, 16, 32]
rand_bits = [random_bits(key, n, (N * 64 // n,)) for n in nbits]
rand_bits_32 = np.array([np.array(r).view(np.uint32) for r in rand_bits])
assert np.all(rand_bits_32 == rand_bits_32[0])
@jtu.sample_product(case=_RANDOM_VALUES_CASES)
@jtu.sample_product(case=_RANDOM_VALUES_CASES, make_key=KEY_CTORS)
@skipIf(config.jax_threefry_partitionable, 'changed random bit values')
@jtu.skip_on_devices("tpu") # TPU precision causes issues.
def testRandomDistributionValues(self, case):
def testRandomDistributionValues(self, case, make_key):
"""
Tests values output by various distributions. This will catch any unintentional
changes to the implementations that could result in different random sequences.
@ -351,7 +374,7 @@ class PrngTest(jtu.JaxTestCase):
self.skipTest("test only valid when jax_enable_x64=True")
with jax.default_prng_impl(case.prng_impl):
func = getattr(random, case.name)
key = random.PRNGKey(case._seed())
key = make_key(case._seed())
if case.dtype:
actual = func(key, **case.params, shape=case.shape, dtype=case.dtype)
else:
@ -359,9 +382,10 @@ class PrngTest(jtu.JaxTestCase):
self.assertAllClose(actual, case.expected, atol=case.atol, rtol=case.rtol)
@skipIf(config.jax_threefry_partitionable, 'changed random bit values')
def testPRNGValues(self):
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def testPRNGValues(self, make_key):
# Test to ensure consistent random values between JAX versions
k = random.PRNGKey(0)
k = make_key(0)
self.assertEqual(random.randint(k, (3, 3), 0, 8).dtype,
dtypes.canonicalize_dtype(jnp.int_))
@ -388,17 +412,19 @@ class PrngTest(jtu.JaxTestCase):
_prng_key_as_array(random.fold_in(k, 4)),
np.array([2285895361, 433833334], dtype='uint32'))
def test_random_bits_error(self):
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def test_random_bits_error(self, make_key):
msg = 'dtype argument .* must be an unsigned int dtype'
with self.assertRaisesRegex(ValueError, msg):
random.bits(random.PRNGKey(0), (3, 4), np.dtype('int8'))
random.bits(make_key(0), (3, 4), np.dtype('int8'))
with self.assertRaisesRegex(ValueError, msg):
random.bits(random.PRNGKey(0), (3, 4), np.dtype('float16'))
random.bits(make_key(0), (3, 4), np.dtype('float16'))
@skipIf(not config.jax_threefry_partitionable, 'enable after upgrade')
def test_threefry_split_fold_in_symmetry(self):
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def test_threefry_split_fold_in_symmetry(self, make_key):
with jax.default_prng_impl('threefry2x32'):
key = random.PRNGKey(72)
key = make_key(72)
f1, f2, f3 = [random.fold_in(key, i) for i in range(3)]
s1, s2, s3 = random.split(key, 3)
f1, f2, f3 = map(_prng_key_as_array, [f1, f2, f3])
@ -408,10 +434,11 @@ class PrngTest(jtu.JaxTestCase):
self.assertArraysEqual(f3, s3)
@skipIf(not config.jax_threefry_partitionable, 'enable after upgrade')
def test_threefry_split_vmapped_fold_in_symmetry(self):
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def test_threefry_split_vmapped_fold_in_symmetry(self, make_key):
# See https://github.com/google/jax/issues/7708
with jax.default_prng_impl('threefry2x32'):
key = random.PRNGKey(72)
key = make_key(72)
f1, f2, f3 = vmap(lambda k, _: random.fold_in(k, lax.axis_index('batch')),
in_axes=(None, 0), axis_name='batch')(key, jnp.ones(3))
s1, s2, s3 = random.split(key, 3)
@ -421,36 +448,39 @@ class PrngTest(jtu.JaxTestCase):
self.assertArraysEqual(f2, s2)
self.assertArraysEqual(f3, s3)
@parameterized.parameters([
{"seed": 0, "typ": int, "jit": True, "key": [0, 0]},
{"seed": 0, "typ": int, "jit": False, "key": [0, 0]},
{"seed": 1, "typ": np.int32, "jit": True, "key": [0, 1]},
{"seed": 1, "typ": np.int32, "jit": False, "key": [0, 1]},
{"seed": 2, "typ": np.uint32, "jit": True, "key": [0, 2]},
{"seed": 2, "typ": np.uint32, "jit": False, "key": [0, 2]},
{"seed": 3, "typ": np.int64, "jit": True, "key": [0, 3]},
{"seed": 3, "typ": np.int64, "jit": False, "key": [0, 3]},
{"seed": -1, "typ": int, "jit": True, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
{"seed": -1, "typ": int, "jit": False, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
{"seed": -2, "typ": np.int32, "jit": True, "key": [0, 4294967294]},
{"seed": -2, "typ": np.int32, "jit": False, "key": [0, 4294967294]},
{"seed": -3, "typ": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
{"seed": -3, "typ": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
{"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": True, "key": [0, 2147483747]},
{"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": False, "key": [0, 2147483747]},
{"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": True, "key": [0, 2147483748]},
{"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": False, "key": [0, 2147483748]},
{"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": True, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
{"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": False, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
{"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
{"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
@parameterized.parameters([params
for d in [
{"seed": 0, "typ": int, "jit": True, "key": [0, 0]},
{"seed": 0, "typ": int, "jit": False, "key": [0, 0]},
{"seed": 1, "typ": np.int32, "jit": True, "key": [0, 1]},
{"seed": 1, "typ": np.int32, "jit": False, "key": [0, 1]},
{"seed": 2, "typ": np.uint32, "jit": True, "key": [0, 2]},
{"seed": 2, "typ": np.uint32, "jit": False, "key": [0, 2]},
{"seed": 3, "typ": np.int64, "jit": True, "key": [0, 3]},
{"seed": 3, "typ": np.int64, "jit": False, "key": [0, 3]},
{"seed": -1, "typ": int, "jit": True, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
{"seed": -1, "typ": int, "jit": False, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
{"seed": -2, "typ": np.int32, "jit": True, "key": [0, 4294967294]},
{"seed": -2, "typ": np.int32, "jit": False, "key": [0, 4294967294]},
{"seed": -3, "typ": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
{"seed": -3, "typ": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
{"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": True, "key": [0, 2147483747]},
{"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": False, "key": [0, 2147483747]},
{"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": True, "key": [0, 2147483748]},
{"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": False, "key": [0, 2147483748]},
{"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": True, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
{"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": False, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
{"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
{"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
]
for params in [dict(**d, make_key=ctor) for ctor in KEY_CTORS]
])
def test_prng_seeds_and_keys(self, seed, typ, jit, key):
def test_prng_seeds_and_keys(self, seed, typ, jit, key, make_key):
seed = typ(seed)
if jit:
maker = lambda k: _prng_key_as_array(jax.jit(random.PRNGKey)(k))
maker = lambda k: _prng_key_as_array(jax.jit(make_key)(k))
else:
maker = lambda k: _prng_key_as_array(random.PRNGKey(k))
maker = lambda k: _prng_key_as_array(make_key(k))
if (jit and typ is int and not config.x64_enabled and
(seed < np.iinfo('int32').min or seed > np.iinfo('int32').max)):
# We expect an error to be raised.
@ -474,109 +504,58 @@ class PrngTest(jtu.JaxTestCase):
expected = jnp.array(key, dtype=jnp.uint32)
self.assertArraysEqual(actual, expected)
def test_default_prng_selection(self):
if not config.jax_enable_custom_prng:
self.skipTest("test requires config.jax_enable_custom_prng")
for name, impl in PRNG_IMPLS:
with jax.default_prng_impl(name):
self.assertIs(random.default_prng_impl(), impl)
key = random.PRNGKey(42)
self.assertIs(key.impl, impl)
k1, k2 = random.split(key, 2)
self.assertIs(k1.impl, impl)
self.assertIs(k2.impl, impl)
def test_default_prng_selection_without_custom_prng_mode(self):
if config.jax_enable_custom_prng:
self.skipTest("test requires that config.jax_enable_custom_prng is False")
for name, impl in PRNG_IMPLS:
with jax.default_prng_impl(name):
self.assertIs(random.default_prng_impl(), impl)
key = random.PRNGKey(42)
self.assertEqual(key.shape, impl.key_shape)
k1, k2 = random.split(key, 2)
self.assertEqual(k1.shape, impl.key_shape)
self.assertEqual(k2.shape, impl.key_shape)
@parameterized.parameters([
{'make_key': ctor, 'name': name, 'impl': impl}
for ctor in KEY_CTORS
for name, impl in PRNG_IMPLS])
def test_default_prng_selection(self, make_key, name, impl):
with jax.default_prng_impl(name):
self.assertIs(random.default_prng_impl(), impl)
key = make_key(42)
self.check_key_has_impl(key, impl)
k1, k2 = random.split(key, 2)
self.check_key_has_impl(k1, impl)
self.check_key_has_impl(k2, impl)
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
def test_explicit_threefry2x32_key(self):
if not config.jax_enable_custom_prng:
self.skipTest("test requires config.jax_enable_custom_prng")
key = random.threefry2x32_key(42)
self.assertIs(key.impl, prng.threefry_prng_impl)
self.check_key_has_impl(random.threefry2x32_key(42),
prng.threefry_prng_impl)
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
def test_explicit_rbg_key(self):
if not config.jax_enable_custom_prng:
self.skipTest("test requires config.jax_enable_custom_prng")
key = random.rbg_key(42)
self.assertIs(key.impl, prng.rbg_prng_impl)
self.check_key_has_impl(random.rbg_key(42),
prng.rbg_prng_impl)
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
def test_explicit_unsafe_rbg_key(self):
if not config.jax_enable_custom_prng:
self.skipTest("test requires config.jax_enable_custom_prng")
key = random.unsafe_rbg_key(42)
self.assertIs(key.impl, prng.unsafe_rbg_prng_impl)
self.check_key_has_impl(random.unsafe_rbg_key(42),
prng.unsafe_rbg_prng_impl)
def test_key_construction_with_explicit_impl_name(self):
def check_key_has_impl(key, impl):
if isinstance(key, random.PRNGKeyArray):
self.assertIs(key.impl, impl)
else:
self.assertEqual(key.dtype, jnp.dtype('uint32'))
self.assertEqual(key.shape, impl.key_shape)
key = random.key(42, impl='threefry2x32')
check_key_has_impl(key, prng.threefry_prng_impl)
self.check_key_has_impl(key, prng.threefry_prng_impl)
key = random.key(42, impl='rbg')
check_key_has_impl(key, prng.rbg_prng_impl)
self.check_key_has_impl(key, prng.rbg_prng_impl)
key = random.key(42, impl='unsafe_rbg')
check_key_has_impl(key, prng.unsafe_rbg_prng_impl)
self.check_key_has_impl(key, prng.unsafe_rbg_prng_impl)
key = random.PRNGKey(42, impl='threefry2x32')
check_key_has_impl(key, prng.threefry_prng_impl)
self.check_key_has_impl(key, prng.threefry_prng_impl)
key = random.PRNGKey(42, impl='rbg')
check_key_has_impl(key, prng.rbg_prng_impl)
self.check_key_has_impl(key, prng.rbg_prng_impl)
key = random.PRNGKey(42, impl='unsafe_rbg')
check_key_has_impl(key, prng.unsafe_rbg_prng_impl)
self.check_key_has_impl(key, prng.unsafe_rbg_prng_impl)
def test_key_array_indexing_0d(self):
if not config.jax_enable_custom_prng:
self.skipTest("test requires config.jax_enable_custom_prng")
key = random.PRNGKey(1701)
self.assertEqual(key.shape, ())
self.assertEqual(key[None].shape, (1,))
self.assertRaisesRegex(IndexError, 'Too many indices.*', lambda: key[0])
def test_key_array_indexing_nd(self):
if not config.jax_enable_custom_prng:
self.skipTest("test requires config.jax_enable_custom_prng")
keys = vmap(vmap(random.PRNGKey))(jnp.arange(6).reshape((2, 3)))
self.assertEqual(keys.shape, (2, 3))
self.assertEqual(keys[0, 0].shape, ())
self.assertEqual(keys[0, 1].shape, ())
self.assertEqual(keys[0].shape, (3,))
self.assertEqual(keys[1, :].shape, (3,))
self.assertEqual(keys[:, 1].shape, (2,))
self.assertEqual(keys[None].shape, (1, 2, 3))
self.assertEqual(keys[None, None].shape, (1, 1, 2, 3))
self.assertEqual(keys[None, :, None].shape, (1, 2, 1, 3))
self.assertEqual(keys[None, None, None, 0, None, None, None, 1].shape,
(1,) * 6)
self.assertEqual(keys[..., 1:, None].shape, (2, 2, 1))
self.assertEqual(keys[None, 0, ..., 1, None].shape, (1, 1))
self.assertRaisesRegex(IndexError, 'Too many indices.*',
lambda: keys[0, 1, 2])
self.assertRaisesRegex(IndexError, 'Too many indices.*',
lambda: keys[0, 1, None, 2])
def test_isinstance(self):
if not config.jax_enable_custom_prng:
self.skipTest("test requires config.jax_enable_custom_prng")
key = random.PRNGKey(0)
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def test_isinstance(self, make_key):
key = make_key(0)
self.assertIsInstance(key, jax.Array)
def test_key_output_vjp(self):
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def test_key_output_vjp(self, make_key):
# See https://github.com/google/jax/issues/14856
def f(seed): return random.PRNGKey(seed)
def f(seed): return make_key(seed)
jax.vjp(f, 1) # doesn't crash
@ -602,6 +581,8 @@ class LaxRandomTest(jtu.JaxTestCase):
# conservative bound on statistical fail prob by Kolmo CDF
# bfloat16 quantization creates much lower p-values in large distributions
fail_prob = 0.003 if samples.dtype == jnp.bfloat16 else 0.01
# TODO(frostig): This reads enable_custom_prng as a proxy for
# whether RBG keys may be involved, but that's no longer exact.
if config.jax_enable_custom_prng and samples.dtype == jnp.bfloat16:
return
self.assertGreater(scipy.stats.kstest(samples, cdf).pvalue, fail_prob)
@ -1649,6 +1630,15 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertAllClose(samples.mean(), 1 / p, rtol=0.02, check_dtypes=False)
self.assertAllClose(samples.var(), (1 - p) / (p * p) , rtol=0.05, check_dtypes=False)
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def test_copy(self, make_key):
key = make_key(8459302)
self.assertArraysEqual(key, key.copy())
self.assertArraysEqual(key, copy.copy(key))
self.assertArraysEqual(key, copy.deepcopy(key))
self.assertArraysEqual(key, jax.jit(lambda k: k.copy())(key))
class KeyArrayTest(jtu.JaxTestCase):
# Key arrays involve:
# * a Python key array type, backed by an underlying uint32 "base" array,
@ -1664,6 +1654,15 @@ class KeyArrayTest(jtu.JaxTestCase):
# might also be a more general test of opaque element types. If
# so, add a corresponding test to to CustomElementTypesTest as well.
def test_construction(self):
key = random.key(42)
self.assertIsInstance(key, random.PRNGKeyArray)
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
def test_construction_upgrade_flag(self):
key = random.PRNGKey(42)
self.assertIsInstance(key, random.PRNGKeyArray)
def make_keys(self, *shape, seed=28):
seeds = seed + jnp.arange(math.prod(shape), dtype=jnp.uint32)
return jax.vmap(random.key)(seeds).reshape(shape)
@ -1754,8 +1753,7 @@ class KeyArrayTest(jtu.JaxTestCase):
self.assertEqual(out.dtype, np.dtype('uint32'))
self.assertEqual(out.shape[:2], (3, 4))
# TODO(frostig): simplify when we always enable_custom_prng
if not (config.jax_enable_custom_prng and use_internal):
if not use_internal:
return
x = jnp.arange(12, dtype=np.dtype('uint32')).reshape(3, 4)
@ -1972,6 +1970,32 @@ class KeyArrayTest(jtu.JaxTestCase):
self.assertIsInstance(key_dot, random.PRNGKeyArray)
self.assertArraysEqual(random.key_data(key_dot), np.uint32(0))
def test_key_array_indexing_0d(self):
key = self.make_keys()
self.assertEqual(key.shape, ())
self.assertEqual(key[None].shape, (1,))
self.assertRaisesRegex(IndexError, 'Too many indices.*', lambda: key[0])
def test_key_array_indexing_nd(self):
keys = self.make_keys(2, 3)
self.assertEqual(keys.shape, (2, 3))
self.assertEqual(keys[0, 0].shape, ())
self.assertEqual(keys[0, 1].shape, ())
self.assertEqual(keys[0].shape, (3,))
self.assertEqual(keys[1, :].shape, (3,))
self.assertEqual(keys[:, 1].shape, (2,))
self.assertEqual(keys[None].shape, (1, 2, 3))
self.assertEqual(keys[None, None].shape, (1, 1, 2, 3))
self.assertEqual(keys[None, :, None].shape, (1, 2, 1, 3))
self.assertEqual(keys[None, None, None, 0, None, None, None, 1].shape,
(1,) * 6)
self.assertEqual(keys[..., 1:, None].shape, (2, 2, 1))
self.assertEqual(keys[None, 0, ..., 1, None].shape, (1, 1))
self.assertRaisesRegex(IndexError, 'Too many indices.*',
lambda: keys[0, 1, 2])
self.assertRaisesRegex(IndexError, 'Too many indices.*',
lambda: keys[0, 1, None, 2])
def test_not_hashable(self):
key = self.make_keys()
with self.assertRaisesRegex(TypeError, "unhashable type"):
@ -2107,11 +2131,10 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
self.assertArraysEqual(out, np.zeros(key.shape, jax.dtypes.float0))
# TODO(frostig): remove `with_config` we always enable_custom_prng
@jtu.with_config(jax_default_prng_impl='rbg')
class LaxRandomWithRBGPRNGTest(LaxRandomTest):
def seed_prng(self, seed):
return random.rbg_key(seed)
return random.key(seed, impl='rbg')
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key arrays')
def test_split_shape(self):
@ -2162,7 +2185,6 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
@skipIf(np.__version__ == "1.21.0",
"https://github.com/numpy/numpy/issues/19305")
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key arrays')
def test_grad_of_prng_key(self):
key = self.seed_prng(73)
with self.assertRaisesRegex(TypeError, 'grad requires real- or complex-valued inputs'):
@ -2177,13 +2199,6 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
# TODO(mattjj): enable this test if/when RngBitGenerator supports it
raise SkipTest('8-bit types not supported with RBG PRNG')
def test_copy(self):
key = random.PRNGKey(8459302)
self.assertArraysEqual(key, key.copy())
self.assertArraysEqual(key, copy.copy(key))
self.assertArraysEqual(key, copy.deepcopy(key))
self.assertArraysEqual(key, jax.jit(lambda k: k.copy())(key))
# TODO(frostig): remove `with_config` we always enable_custom_prng
@jtu.with_config(jax_default_prng_impl='unsafe_rbg')