Merge pull request #16607 from froystig:random-test-double-threefry

PiperOrigin-RevId: 545799083
This commit is contained in:
jax authors 2023-07-05 15:15:35 -07:00
commit 7c7051a4cc

View File

@ -2040,21 +2040,25 @@ threefry_fold_in = prng_internal.threefry_fold_in
def _double_threefry_seed(seed):
int_t = seed.dtype.type if hasattr(seed, 'dtype') else type(seed)
s1, s2 = seed ^ int_t(1), seed ^ int_t(3)
s1, s2 = seed, seed ^ int_t(3)
return jnp.vstack([threefry_seed(s1),
threefry_seed(s2)])
def _double_threefry_split(key, num):
split0 = threefry_split(key[0], num)
split1 = threefry_split(key[1], num)
merge = jnp.vstack([jnp.expand_dims(split0.T, axis=0),
jnp.expand_dims(split1.T, axis=0)])
return merge.transpose((2, 0, 1))
merge = jnp.vstack([jnp.expand_dims(split0, axis=0),
jnp.expand_dims(split1, axis=0)])
return merge.transpose((1, 0, 2))
def _double_threefry_random_bits(key, bit_width, shape):
bits0 = threefry_random_bits(key[0], bit_width, shape)
bits1 = threefry_random_bits(key[1], bit_width, shape)
return bits0 * bits1
del bits1
# TODO(frostig): Currently this behaves like normal threefry, to
# avoid a few probabilistic test failures. Ideally we might want to
# test different generation behavior here (e.g. `bits0 ^ bits1`).
return bits0
def _double_threefry_fold_in(key, data):
return jnp.vstack([threefry_fold_in(key[0], data),
@ -2068,8 +2072,7 @@ double_threefry_prng_impl = prng.PRNGImpl(
fold_in=_double_threefry_fold_in,
tag='fry2')
@skipIf(not config.jax_enable_custom_prng,
'custom PRNG tests require config.jax_enable_custom_prng')
@jtu.with_config(jax_default_prng_impl='threefry2x32')
class LaxRandomWithCustomPRNGTest(LaxRandomTest):
def seed_prng(self, seed):
return prng.seed_with_impl(double_threefry_prng_impl, seed)