mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
increase random test coverage over RNG key constructors and representations
This is an incremental change to our random tests that primarily: * Increases test coverage of both key constructors (`random.key` and `random.PRNGKey`), often by parameterizing tests over both. * Increases test coverage of both key representations (typed key arrays and `uint32` arrays). * Removes a handful of guards on `config.jax_enable_custom_prng`, either replacing them with `isinstance` checks for typed keys or removing them altogether if possible. * Makes a handful of other individual test improvements and fixes, and leaves comments for more.
This commit is contained in:
parent
7b0596b90f
commit
f8dee51d9a
@ -54,13 +54,16 @@ uint_dtypes = jtu.dtypes.all_unsigned
|
||||
|
||||
|
||||
def _prng_key_as_array(key):
|
||||
# TODO(frostig): remove once we upgrade to always enable_custom_prng
|
||||
return key.unsafe_raw_array() if config.jax_enable_custom_prng else key
|
||||
# TODO(frostig): remove some day when we deprecate "raw" key arrays
|
||||
if isinstance(key, jax.random.PRNGKeyArray):
|
||||
return key.unsafe_raw_array()
|
||||
else:
|
||||
return key
|
||||
|
||||
def _maybe_unwrap(key):
|
||||
# TODO(frostig): remove once we upgrade to always enable_custom_prng
|
||||
# TODO(frostig): remove some day when we deprecate "raw" key arrays
|
||||
unwrap = prng_internal.random_unwrap
|
||||
return unwrap(key) if config.jax_enable_custom_prng else key
|
||||
return unwrap(key) if isinstance(key, jax.random.PRNGKeyArray) else key
|
||||
|
||||
|
||||
PRNG_IMPLS = [('threefry2x32', prng.threefry_prng_impl),
|
||||
@ -199,8 +202,20 @@ _RANDOM_VALUES_CASES = [
|
||||
]
|
||||
|
||||
|
||||
KEY_CTORS = [random.key, random.PRNGKey]
|
||||
|
||||
class PrngTest(jtu.JaxTestCase):
|
||||
|
||||
def check_key_has_impl(self, key, impl):
|
||||
if isinstance(key, random.PRNGKeyArray):
|
||||
self.assertIs(key.impl, impl)
|
||||
else:
|
||||
self.assertEqual(key.dtype, jnp.dtype('uint32'))
|
||||
self.assertEqual(key.shape, impl.key_shape)
|
||||
|
||||
def raw_key(self, *args, **kwargs):
|
||||
return _prng_key_as_array(random.key(*args, **kwargs))
|
||||
|
||||
def testThreefry2x32(self):
|
||||
# We test the hash by comparing to known values provided in the test code of
|
||||
# the original reference implementation of Threefry. For the values, see
|
||||
@ -252,15 +267,23 @@ class PrngTest(jtu.JaxTestCase):
|
||||
xla.apply_primitive = apply_primitive
|
||||
|
||||
@skipIf(config.jax_threefry_partitionable, 'changed random bit values')
|
||||
def testRngRandomBits(self):
|
||||
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
||||
def testRngRandomBits(self, make_key):
|
||||
# Test specific outputs to ensure consistent random values between JAX versions.
|
||||
|
||||
# TODO(frostig): remove once we always enable_custom_prng
|
||||
def random_bits(key, *args):
|
||||
def random_bits(key, width, shape):
|
||||
# TODO(frostig): Use random.bits, as in:
|
||||
#
|
||||
# def random_bits(key, width, shape):
|
||||
# dtype = jnp.dtype(f'uint{width}')
|
||||
# return jax.random.bits(key, shape, dtype)
|
||||
#
|
||||
# Doing so doesn't work in width 64 at present due to
|
||||
# normalization in random.bits.
|
||||
key, _ = jax_random._check_prng_key(key)
|
||||
return jax_random._random_bits(key, *args)
|
||||
return jax_random._random_bits(key, width, shape)
|
||||
|
||||
key = random.PRNGKey(1701)
|
||||
key = make_key(1701)
|
||||
|
||||
bits8 = random_bits(key, 8, (3,))
|
||||
expected8 = np.array([216, 115, 43], dtype=np.uint8)
|
||||
@ -283,19 +306,19 @@ class PrngTest(jtu.JaxTestCase):
|
||||
expected64 = np.array([676898860, 3164047411, 4010691890], dtype=np.uint32)
|
||||
self.assertArraysEqual(bits64, expected64)
|
||||
|
||||
@jtu.sample_product(prng_name=[name for name, _ in PRNG_IMPLS])
|
||||
def testRngRandomBitsShapeDtype(self, prng_name):
|
||||
@jtu.sample_product(prng_name=[name for name, _ in PRNG_IMPLS],
|
||||
make_key=KEY_CTORS)
|
||||
def testRngRandomBitsShapeDtype(self, prng_name, make_key):
|
||||
# Like testRngRandomBits, but only meant to exercise random_bits
|
||||
# on every PRNG implementation. Instead of values, only checks
|
||||
# that shapes/dtypes are as expected.
|
||||
|
||||
# TODO(frostig): remove once we always enable_custom_prng
|
||||
def random_bits(key, *args):
|
||||
key, _ = jax_random._check_prng_key(key)
|
||||
return jax_random._random_bits(key, *args)
|
||||
def random_bits(key, width, shape):
|
||||
dtype = jnp.dtype(f'uint{width}')
|
||||
return jax.random.bits(key, shape, dtype)
|
||||
|
||||
with jax.default_prng_impl(prng_name):
|
||||
key = random.PRNGKey(1701)
|
||||
key = make_key(1701)
|
||||
|
||||
bits8 = random_bits(key, 8, (3,))
|
||||
self.assertEqual(bits8.shape, (3,))
|
||||
@ -316,27 +339,27 @@ class PrngTest(jtu.JaxTestCase):
|
||||
self.assertEqual(bits64.dtype, expected_dtype)
|
||||
|
||||
@skipIf(config.jax_threefry_partitionable, 'changed random bit values')
|
||||
def testRngRandomBitsViewProperty(self):
|
||||
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
||||
def testRngRandomBitsViewProperty(self, make_key):
|
||||
# TODO: add 64-bit if it ever supports this property.
|
||||
# TODO: will this property hold across endian-ness?
|
||||
|
||||
# TODO(frostig): remove once we always enable_custom_prng
|
||||
def random_bits(key, *args):
|
||||
key, _ = jax_random._check_prng_key(key)
|
||||
return jax_random._random_bits(key, *args)
|
||||
def random_bits(key, width, shape):
|
||||
dtype = jnp.dtype(f'uint{width}')
|
||||
return jax.random.bits(key, shape, dtype)
|
||||
|
||||
N = 10
|
||||
key = random.PRNGKey(1701)
|
||||
key = make_key(1701)
|
||||
nbits = [8, 16, 32]
|
||||
rand_bits = [random_bits(key, n, (N * 64 // n,)) for n in nbits]
|
||||
rand_bits_32 = np.array([np.array(r).view(np.uint32) for r in rand_bits])
|
||||
assert np.all(rand_bits_32 == rand_bits_32[0])
|
||||
|
||||
|
||||
@jtu.sample_product(case=_RANDOM_VALUES_CASES)
|
||||
@jtu.sample_product(case=_RANDOM_VALUES_CASES, make_key=KEY_CTORS)
|
||||
@skipIf(config.jax_threefry_partitionable, 'changed random bit values')
|
||||
@jtu.skip_on_devices("tpu") # TPU precision causes issues.
|
||||
def testRandomDistributionValues(self, case):
|
||||
def testRandomDistributionValues(self, case, make_key):
|
||||
"""
|
||||
Tests values output by various distributions. This will catch any unintentional
|
||||
changes to the implementations that could result in different random sequences.
|
||||
@ -351,7 +374,7 @@ class PrngTest(jtu.JaxTestCase):
|
||||
self.skipTest("test only valid when jax_enable_x64=True")
|
||||
with jax.default_prng_impl(case.prng_impl):
|
||||
func = getattr(random, case.name)
|
||||
key = random.PRNGKey(case._seed())
|
||||
key = make_key(case._seed())
|
||||
if case.dtype:
|
||||
actual = func(key, **case.params, shape=case.shape, dtype=case.dtype)
|
||||
else:
|
||||
@ -359,9 +382,10 @@ class PrngTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(actual, case.expected, atol=case.atol, rtol=case.rtol)
|
||||
|
||||
@skipIf(config.jax_threefry_partitionable, 'changed random bit values')
|
||||
def testPRNGValues(self):
|
||||
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
||||
def testPRNGValues(self, make_key):
|
||||
# Test to ensure consistent random values between JAX versions
|
||||
k = random.PRNGKey(0)
|
||||
k = make_key(0)
|
||||
|
||||
self.assertEqual(random.randint(k, (3, 3), 0, 8).dtype,
|
||||
dtypes.canonicalize_dtype(jnp.int_))
|
||||
@ -388,17 +412,19 @@ class PrngTest(jtu.JaxTestCase):
|
||||
_prng_key_as_array(random.fold_in(k, 4)),
|
||||
np.array([2285895361, 433833334], dtype='uint32'))
|
||||
|
||||
def test_random_bits_error(self):
|
||||
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
||||
def test_random_bits_error(self, make_key):
|
||||
msg = 'dtype argument .* must be an unsigned int dtype'
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
random.bits(random.PRNGKey(0), (3, 4), np.dtype('int8'))
|
||||
random.bits(make_key(0), (3, 4), np.dtype('int8'))
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
random.bits(random.PRNGKey(0), (3, 4), np.dtype('float16'))
|
||||
random.bits(make_key(0), (3, 4), np.dtype('float16'))
|
||||
|
||||
@skipIf(not config.jax_threefry_partitionable, 'enable after upgrade')
|
||||
def test_threefry_split_fold_in_symmetry(self):
|
||||
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
||||
def test_threefry_split_fold_in_symmetry(self, make_key):
|
||||
with jax.default_prng_impl('threefry2x32'):
|
||||
key = random.PRNGKey(72)
|
||||
key = make_key(72)
|
||||
f1, f2, f3 = [random.fold_in(key, i) for i in range(3)]
|
||||
s1, s2, s3 = random.split(key, 3)
|
||||
f1, f2, f3 = map(_prng_key_as_array, [f1, f2, f3])
|
||||
@ -408,10 +434,11 @@ class PrngTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(f3, s3)
|
||||
|
||||
@skipIf(not config.jax_threefry_partitionable, 'enable after upgrade')
|
||||
def test_threefry_split_vmapped_fold_in_symmetry(self):
|
||||
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
||||
def test_threefry_split_vmapped_fold_in_symmetry(self, make_key):
|
||||
# See https://github.com/google/jax/issues/7708
|
||||
with jax.default_prng_impl('threefry2x32'):
|
||||
key = random.PRNGKey(72)
|
||||
key = make_key(72)
|
||||
f1, f2, f3 = vmap(lambda k, _: random.fold_in(k, lax.axis_index('batch')),
|
||||
in_axes=(None, 0), axis_name='batch')(key, jnp.ones(3))
|
||||
s1, s2, s3 = random.split(key, 3)
|
||||
@ -421,36 +448,39 @@ class PrngTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(f2, s2)
|
||||
self.assertArraysEqual(f3, s3)
|
||||
|
||||
@parameterized.parameters([
|
||||
{"seed": 0, "typ": int, "jit": True, "key": [0, 0]},
|
||||
{"seed": 0, "typ": int, "jit": False, "key": [0, 0]},
|
||||
{"seed": 1, "typ": np.int32, "jit": True, "key": [0, 1]},
|
||||
{"seed": 1, "typ": np.int32, "jit": False, "key": [0, 1]},
|
||||
{"seed": 2, "typ": np.uint32, "jit": True, "key": [0, 2]},
|
||||
{"seed": 2, "typ": np.uint32, "jit": False, "key": [0, 2]},
|
||||
{"seed": 3, "typ": np.int64, "jit": True, "key": [0, 3]},
|
||||
{"seed": 3, "typ": np.int64, "jit": False, "key": [0, 3]},
|
||||
{"seed": -1, "typ": int, "jit": True, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
|
||||
{"seed": -1, "typ": int, "jit": False, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
|
||||
{"seed": -2, "typ": np.int32, "jit": True, "key": [0, 4294967294]},
|
||||
{"seed": -2, "typ": np.int32, "jit": False, "key": [0, 4294967294]},
|
||||
{"seed": -3, "typ": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
|
||||
{"seed": -3, "typ": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
|
||||
{"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": True, "key": [0, 2147483747]},
|
||||
{"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": False, "key": [0, 2147483747]},
|
||||
{"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": True, "key": [0, 2147483748]},
|
||||
{"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": False, "key": [0, 2147483748]},
|
||||
{"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": True, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
|
||||
{"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": False, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
|
||||
{"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
|
||||
{"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
|
||||
@parameterized.parameters([params
|
||||
for d in [
|
||||
{"seed": 0, "typ": int, "jit": True, "key": [0, 0]},
|
||||
{"seed": 0, "typ": int, "jit": False, "key": [0, 0]},
|
||||
{"seed": 1, "typ": np.int32, "jit": True, "key": [0, 1]},
|
||||
{"seed": 1, "typ": np.int32, "jit": False, "key": [0, 1]},
|
||||
{"seed": 2, "typ": np.uint32, "jit": True, "key": [0, 2]},
|
||||
{"seed": 2, "typ": np.uint32, "jit": False, "key": [0, 2]},
|
||||
{"seed": 3, "typ": np.int64, "jit": True, "key": [0, 3]},
|
||||
{"seed": 3, "typ": np.int64, "jit": False, "key": [0, 3]},
|
||||
{"seed": -1, "typ": int, "jit": True, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
|
||||
{"seed": -1, "typ": int, "jit": False, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
|
||||
{"seed": -2, "typ": np.int32, "jit": True, "key": [0, 4294967294]},
|
||||
{"seed": -2, "typ": np.int32, "jit": False, "key": [0, 4294967294]},
|
||||
{"seed": -3, "typ": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
|
||||
{"seed": -3, "typ": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
|
||||
{"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": True, "key": [0, 2147483747]},
|
||||
{"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": False, "key": [0, 2147483747]},
|
||||
{"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": True, "key": [0, 2147483748]},
|
||||
{"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": False, "key": [0, 2147483748]},
|
||||
{"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": True, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
|
||||
{"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": False, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
|
||||
{"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
|
||||
{"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
|
||||
]
|
||||
for params in [dict(**d, make_key=ctor) for ctor in KEY_CTORS]
|
||||
])
|
||||
def test_prng_seeds_and_keys(self, seed, typ, jit, key):
|
||||
def test_prng_seeds_and_keys(self, seed, typ, jit, key, make_key):
|
||||
seed = typ(seed)
|
||||
if jit:
|
||||
maker = lambda k: _prng_key_as_array(jax.jit(random.PRNGKey)(k))
|
||||
maker = lambda k: _prng_key_as_array(jax.jit(make_key)(k))
|
||||
else:
|
||||
maker = lambda k: _prng_key_as_array(random.PRNGKey(k))
|
||||
maker = lambda k: _prng_key_as_array(make_key(k))
|
||||
if (jit and typ is int and not config.x64_enabled and
|
||||
(seed < np.iinfo('int32').min or seed > np.iinfo('int32').max)):
|
||||
# We expect an error to be raised.
|
||||
@ -474,109 +504,58 @@ class PrngTest(jtu.JaxTestCase):
|
||||
expected = jnp.array(key, dtype=jnp.uint32)
|
||||
self.assertArraysEqual(actual, expected)
|
||||
|
||||
def test_default_prng_selection(self):
|
||||
if not config.jax_enable_custom_prng:
|
||||
self.skipTest("test requires config.jax_enable_custom_prng")
|
||||
for name, impl in PRNG_IMPLS:
|
||||
with jax.default_prng_impl(name):
|
||||
self.assertIs(random.default_prng_impl(), impl)
|
||||
key = random.PRNGKey(42)
|
||||
self.assertIs(key.impl, impl)
|
||||
k1, k2 = random.split(key, 2)
|
||||
self.assertIs(k1.impl, impl)
|
||||
self.assertIs(k2.impl, impl)
|
||||
|
||||
def test_default_prng_selection_without_custom_prng_mode(self):
|
||||
if config.jax_enable_custom_prng:
|
||||
self.skipTest("test requires that config.jax_enable_custom_prng is False")
|
||||
for name, impl in PRNG_IMPLS:
|
||||
with jax.default_prng_impl(name):
|
||||
self.assertIs(random.default_prng_impl(), impl)
|
||||
key = random.PRNGKey(42)
|
||||
self.assertEqual(key.shape, impl.key_shape)
|
||||
k1, k2 = random.split(key, 2)
|
||||
self.assertEqual(k1.shape, impl.key_shape)
|
||||
self.assertEqual(k2.shape, impl.key_shape)
|
||||
@parameterized.parameters([
|
||||
{'make_key': ctor, 'name': name, 'impl': impl}
|
||||
for ctor in KEY_CTORS
|
||||
for name, impl in PRNG_IMPLS])
|
||||
def test_default_prng_selection(self, make_key, name, impl):
|
||||
with jax.default_prng_impl(name):
|
||||
self.assertIs(random.default_prng_impl(), impl)
|
||||
key = make_key(42)
|
||||
self.check_key_has_impl(key, impl)
|
||||
k1, k2 = random.split(key, 2)
|
||||
self.check_key_has_impl(k1, impl)
|
||||
self.check_key_has_impl(k2, impl)
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
||||
def test_explicit_threefry2x32_key(self):
|
||||
if not config.jax_enable_custom_prng:
|
||||
self.skipTest("test requires config.jax_enable_custom_prng")
|
||||
key = random.threefry2x32_key(42)
|
||||
self.assertIs(key.impl, prng.threefry_prng_impl)
|
||||
self.check_key_has_impl(random.threefry2x32_key(42),
|
||||
prng.threefry_prng_impl)
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
||||
def test_explicit_rbg_key(self):
|
||||
if not config.jax_enable_custom_prng:
|
||||
self.skipTest("test requires config.jax_enable_custom_prng")
|
||||
key = random.rbg_key(42)
|
||||
self.assertIs(key.impl, prng.rbg_prng_impl)
|
||||
self.check_key_has_impl(random.rbg_key(42),
|
||||
prng.rbg_prng_impl)
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
||||
def test_explicit_unsafe_rbg_key(self):
|
||||
if not config.jax_enable_custom_prng:
|
||||
self.skipTest("test requires config.jax_enable_custom_prng")
|
||||
key = random.unsafe_rbg_key(42)
|
||||
self.assertIs(key.impl, prng.unsafe_rbg_prng_impl)
|
||||
self.check_key_has_impl(random.unsafe_rbg_key(42),
|
||||
prng.unsafe_rbg_prng_impl)
|
||||
|
||||
def test_key_construction_with_explicit_impl_name(self):
|
||||
def check_key_has_impl(key, impl):
|
||||
if isinstance(key, random.PRNGKeyArray):
|
||||
self.assertIs(key.impl, impl)
|
||||
else:
|
||||
self.assertEqual(key.dtype, jnp.dtype('uint32'))
|
||||
self.assertEqual(key.shape, impl.key_shape)
|
||||
|
||||
key = random.key(42, impl='threefry2x32')
|
||||
check_key_has_impl(key, prng.threefry_prng_impl)
|
||||
self.check_key_has_impl(key, prng.threefry_prng_impl)
|
||||
key = random.key(42, impl='rbg')
|
||||
check_key_has_impl(key, prng.rbg_prng_impl)
|
||||
self.check_key_has_impl(key, prng.rbg_prng_impl)
|
||||
key = random.key(42, impl='unsafe_rbg')
|
||||
check_key_has_impl(key, prng.unsafe_rbg_prng_impl)
|
||||
self.check_key_has_impl(key, prng.unsafe_rbg_prng_impl)
|
||||
|
||||
key = random.PRNGKey(42, impl='threefry2x32')
|
||||
check_key_has_impl(key, prng.threefry_prng_impl)
|
||||
self.check_key_has_impl(key, prng.threefry_prng_impl)
|
||||
key = random.PRNGKey(42, impl='rbg')
|
||||
check_key_has_impl(key, prng.rbg_prng_impl)
|
||||
self.check_key_has_impl(key, prng.rbg_prng_impl)
|
||||
key = random.PRNGKey(42, impl='unsafe_rbg')
|
||||
check_key_has_impl(key, prng.unsafe_rbg_prng_impl)
|
||||
self.check_key_has_impl(key, prng.unsafe_rbg_prng_impl)
|
||||
|
||||
def test_key_array_indexing_0d(self):
|
||||
if not config.jax_enable_custom_prng:
|
||||
self.skipTest("test requires config.jax_enable_custom_prng")
|
||||
key = random.PRNGKey(1701)
|
||||
self.assertEqual(key.shape, ())
|
||||
self.assertEqual(key[None].shape, (1,))
|
||||
self.assertRaisesRegex(IndexError, 'Too many indices.*', lambda: key[0])
|
||||
|
||||
def test_key_array_indexing_nd(self):
|
||||
if not config.jax_enable_custom_prng:
|
||||
self.skipTest("test requires config.jax_enable_custom_prng")
|
||||
keys = vmap(vmap(random.PRNGKey))(jnp.arange(6).reshape((2, 3)))
|
||||
self.assertEqual(keys.shape, (2, 3))
|
||||
self.assertEqual(keys[0, 0].shape, ())
|
||||
self.assertEqual(keys[0, 1].shape, ())
|
||||
self.assertEqual(keys[0].shape, (3,))
|
||||
self.assertEqual(keys[1, :].shape, (3,))
|
||||
self.assertEqual(keys[:, 1].shape, (2,))
|
||||
self.assertEqual(keys[None].shape, (1, 2, 3))
|
||||
self.assertEqual(keys[None, None].shape, (1, 1, 2, 3))
|
||||
self.assertEqual(keys[None, :, None].shape, (1, 2, 1, 3))
|
||||
self.assertEqual(keys[None, None, None, 0, None, None, None, 1].shape,
|
||||
(1,) * 6)
|
||||
self.assertEqual(keys[..., 1:, None].shape, (2, 2, 1))
|
||||
self.assertEqual(keys[None, 0, ..., 1, None].shape, (1, 1))
|
||||
self.assertRaisesRegex(IndexError, 'Too many indices.*',
|
||||
lambda: keys[0, 1, 2])
|
||||
self.assertRaisesRegex(IndexError, 'Too many indices.*',
|
||||
lambda: keys[0, 1, None, 2])
|
||||
|
||||
def test_isinstance(self):
|
||||
if not config.jax_enable_custom_prng:
|
||||
self.skipTest("test requires config.jax_enable_custom_prng")
|
||||
key = random.PRNGKey(0)
|
||||
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
||||
def test_isinstance(self, make_key):
|
||||
key = make_key(0)
|
||||
self.assertIsInstance(key, jax.Array)
|
||||
|
||||
def test_key_output_vjp(self):
|
||||
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
||||
def test_key_output_vjp(self, make_key):
|
||||
# See https://github.com/google/jax/issues/14856
|
||||
def f(seed): return random.PRNGKey(seed)
|
||||
def f(seed): return make_key(seed)
|
||||
jax.vjp(f, 1) # doesn't crash
|
||||
|
||||
|
||||
@ -602,6 +581,8 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
# conservative bound on statistical fail prob by Kolmo CDF
|
||||
# bfloat16 quantization creates much lower p-values in large distributions
|
||||
fail_prob = 0.003 if samples.dtype == jnp.bfloat16 else 0.01
|
||||
# TODO(frostig): This reads enable_custom_prng as a proxy for
|
||||
# whether RBG keys may be involved, but that's no longer exact.
|
||||
if config.jax_enable_custom_prng and samples.dtype == jnp.bfloat16:
|
||||
return
|
||||
self.assertGreater(scipy.stats.kstest(samples, cdf).pvalue, fail_prob)
|
||||
@ -1649,6 +1630,15 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(samples.mean(), 1 / p, rtol=0.02, check_dtypes=False)
|
||||
self.assertAllClose(samples.var(), (1 - p) / (p * p) , rtol=0.05, check_dtypes=False)
|
||||
|
||||
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
||||
def test_copy(self, make_key):
|
||||
key = make_key(8459302)
|
||||
self.assertArraysEqual(key, key.copy())
|
||||
self.assertArraysEqual(key, copy.copy(key))
|
||||
self.assertArraysEqual(key, copy.deepcopy(key))
|
||||
self.assertArraysEqual(key, jax.jit(lambda k: k.copy())(key))
|
||||
|
||||
|
||||
class KeyArrayTest(jtu.JaxTestCase):
|
||||
# Key arrays involve:
|
||||
# * a Python key array type, backed by an underlying uint32 "base" array,
|
||||
@ -1664,6 +1654,15 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
# might also be a more general test of opaque element types. If
|
||||
# so, add a corresponding test to to CustomElementTypesTest as well.
|
||||
|
||||
def test_construction(self):
|
||||
key = random.key(42)
|
||||
self.assertIsInstance(key, random.PRNGKeyArray)
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
||||
def test_construction_upgrade_flag(self):
|
||||
key = random.PRNGKey(42)
|
||||
self.assertIsInstance(key, random.PRNGKeyArray)
|
||||
|
||||
def make_keys(self, *shape, seed=28):
|
||||
seeds = seed + jnp.arange(math.prod(shape), dtype=jnp.uint32)
|
||||
return jax.vmap(random.key)(seeds).reshape(shape)
|
||||
@ -1754,8 +1753,7 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out.dtype, np.dtype('uint32'))
|
||||
self.assertEqual(out.shape[:2], (3, 4))
|
||||
|
||||
# TODO(frostig): simplify when we always enable_custom_prng
|
||||
if not (config.jax_enable_custom_prng and use_internal):
|
||||
if not use_internal:
|
||||
return
|
||||
|
||||
x = jnp.arange(12, dtype=np.dtype('uint32')).reshape(3, 4)
|
||||
@ -1972,6 +1970,32 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(key_dot, random.PRNGKeyArray)
|
||||
self.assertArraysEqual(random.key_data(key_dot), np.uint32(0))
|
||||
|
||||
def test_key_array_indexing_0d(self):
|
||||
key = self.make_keys()
|
||||
self.assertEqual(key.shape, ())
|
||||
self.assertEqual(key[None].shape, (1,))
|
||||
self.assertRaisesRegex(IndexError, 'Too many indices.*', lambda: key[0])
|
||||
|
||||
def test_key_array_indexing_nd(self):
|
||||
keys = self.make_keys(2, 3)
|
||||
self.assertEqual(keys.shape, (2, 3))
|
||||
self.assertEqual(keys[0, 0].shape, ())
|
||||
self.assertEqual(keys[0, 1].shape, ())
|
||||
self.assertEqual(keys[0].shape, (3,))
|
||||
self.assertEqual(keys[1, :].shape, (3,))
|
||||
self.assertEqual(keys[:, 1].shape, (2,))
|
||||
self.assertEqual(keys[None].shape, (1, 2, 3))
|
||||
self.assertEqual(keys[None, None].shape, (1, 1, 2, 3))
|
||||
self.assertEqual(keys[None, :, None].shape, (1, 2, 1, 3))
|
||||
self.assertEqual(keys[None, None, None, 0, None, None, None, 1].shape,
|
||||
(1,) * 6)
|
||||
self.assertEqual(keys[..., 1:, None].shape, (2, 2, 1))
|
||||
self.assertEqual(keys[None, 0, ..., 1, None].shape, (1, 1))
|
||||
self.assertRaisesRegex(IndexError, 'Too many indices.*',
|
||||
lambda: keys[0, 1, 2])
|
||||
self.assertRaisesRegex(IndexError, 'Too many indices.*',
|
||||
lambda: keys[0, 1, None, 2])
|
||||
|
||||
def test_not_hashable(self):
|
||||
key = self.make_keys()
|
||||
with self.assertRaisesRegex(TypeError, "unhashable type"):
|
||||
@ -2107,11 +2131,10 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
||||
self.assertArraysEqual(out, np.zeros(key.shape, jax.dtypes.float0))
|
||||
|
||||
|
||||
# TODO(frostig): remove `with_config` we always enable_custom_prng
|
||||
@jtu.with_config(jax_default_prng_impl='rbg')
|
||||
class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
def seed_prng(self, seed):
|
||||
return random.rbg_key(seed)
|
||||
return random.key(seed, impl='rbg')
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key arrays')
|
||||
def test_split_shape(self):
|
||||
@ -2162,7 +2185,6 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
|
||||
@skipIf(np.__version__ == "1.21.0",
|
||||
"https://github.com/numpy/numpy/issues/19305")
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key arrays')
|
||||
def test_grad_of_prng_key(self):
|
||||
key = self.seed_prng(73)
|
||||
with self.assertRaisesRegex(TypeError, 'grad requires real- or complex-valued inputs'):
|
||||
@ -2177,13 +2199,6 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
# TODO(mattjj): enable this test if/when RngBitGenerator supports it
|
||||
raise SkipTest('8-bit types not supported with RBG PRNG')
|
||||
|
||||
def test_copy(self):
|
||||
key = random.PRNGKey(8459302)
|
||||
self.assertArraysEqual(key, key.copy())
|
||||
self.assertArraysEqual(key, copy.copy(key))
|
||||
self.assertArraysEqual(key, copy.deepcopy(key))
|
||||
self.assertArraysEqual(key, jax.jit(lambda k: k.copy())(key))
|
||||
|
||||
|
||||
# TODO(frostig): remove `with_config` we always enable_custom_prng
|
||||
@jtu.with_config(jax_default_prng_impl='unsafe_rbg')
|
||||
|
Loading…
x
Reference in New Issue
Block a user