mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #16607 from froystig:random-test-double-threefry
PiperOrigin-RevId: 545799083
This commit is contained in:
commit
7c7051a4cc
@ -2040,21 +2040,25 @@ threefry_fold_in = prng_internal.threefry_fold_in
|
||||
|
||||
def _double_threefry_seed(seed):
|
||||
int_t = seed.dtype.type if hasattr(seed, 'dtype') else type(seed)
|
||||
s1, s2 = seed ^ int_t(1), seed ^ int_t(3)
|
||||
s1, s2 = seed, seed ^ int_t(3)
|
||||
return jnp.vstack([threefry_seed(s1),
|
||||
threefry_seed(s2)])
|
||||
|
||||
def _double_threefry_split(key, num):
|
||||
split0 = threefry_split(key[0], num)
|
||||
split1 = threefry_split(key[1], num)
|
||||
merge = jnp.vstack([jnp.expand_dims(split0.T, axis=0),
|
||||
jnp.expand_dims(split1.T, axis=0)])
|
||||
return merge.transpose((2, 0, 1))
|
||||
merge = jnp.vstack([jnp.expand_dims(split0, axis=0),
|
||||
jnp.expand_dims(split1, axis=0)])
|
||||
return merge.transpose((1, 0, 2))
|
||||
|
||||
def _double_threefry_random_bits(key, bit_width, shape):
|
||||
bits0 = threefry_random_bits(key[0], bit_width, shape)
|
||||
bits1 = threefry_random_bits(key[1], bit_width, shape)
|
||||
return bits0 * bits1
|
||||
del bits1
|
||||
# TODO(frostig): Currently this behaves like normal threefry, to
|
||||
# avoid a few probabilistic test failures. Ideally we might want to
|
||||
# test different generation behavior here (e.g. `bits0 ^ bits1`).
|
||||
return bits0
|
||||
|
||||
def _double_threefry_fold_in(key, data):
|
||||
return jnp.vstack([threefry_fold_in(key[0], data),
|
||||
@ -2068,8 +2072,7 @@ double_threefry_prng_impl = prng.PRNGImpl(
|
||||
fold_in=_double_threefry_fold_in,
|
||||
tag='fry2')
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng,
|
||||
'custom PRNG tests require config.jax_enable_custom_prng')
|
||||
@jtu.with_config(jax_default_prng_impl='threefry2x32')
|
||||
class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
||||
def seed_prng(self, seed):
|
||||
return prng.seed_with_impl(double_threefry_prng_impl, seed)
|
||||
|
Loading…
x
Reference in New Issue
Block a user