mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
define a loop-free untrue batching rule for rng_bit_generator
This commit is contained in:
parent
f0afc1b43d
commit
29edfd8925
@ -2046,16 +2046,18 @@ def map(f, xs):
|
||||
return ys
|
||||
|
||||
def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm):
|
||||
"""Calls RBG in a loop and stacks the results."""
|
||||
key, = batched_args
|
||||
keys, = batched_args
|
||||
bd, = batch_dims
|
||||
if bd is batching.not_mapped:
|
||||
return lax.rng_bit_generator_p.bind(key, shape=shape, dtype=dtype,
|
||||
return lax.rng_bit_generator_p.bind(keys, shape=shape, dtype=dtype,
|
||||
algorithm=algorithm), (None, None)
|
||||
key = batching.moveaxis(key, bd, 0)
|
||||
map_body = lambda k: lax.rng_bit_generator_p.bind(k, shape=shape, dtype=dtype, algorithm=algorithm)
|
||||
stacked_keys, stacked_bits = map(map_body, key)
|
||||
return (stacked_keys, stacked_bits), (0, 0)
|
||||
keys = batching.moveaxis(keys, bd, 0)
|
||||
batch_size = keys.shape[0]
|
||||
key = keys[0]
|
||||
new_key, bits = lax.rng_bit_generator_p.bind(key, shape=(batch_size, *shape),
|
||||
dtype=dtype, algorithm=algorithm)
|
||||
new_keys = jax.lax.dynamic_update_index_in_dim(keys, new_key, 0, axis=0)
|
||||
return (new_keys, bits), (0, 0)
|
||||
|
||||
batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule # type: ignore
|
||||
|
||||
|
@ -1233,7 +1233,7 @@ def _gamma_impl(key, a, *, log_space, use_vmap=False):
|
||||
keys = keys.flatten()
|
||||
alphas = a.flatten()
|
||||
|
||||
if use_vmap:
|
||||
if use_vmap and _key_impl(key) is prng.threefry_prng_impl:
|
||||
samples = vmap(partial(_gamma_one, log_space=log_space))(keys, alphas)
|
||||
else:
|
||||
samples = lax.map(
|
||||
|
@ -784,6 +784,9 @@ jax_test(
|
||||
"notsan", # Times out
|
||||
],
|
||||
},
|
||||
backend_variant_args = {
|
||||
"gpu": ["--jax_num_generated_cases=40"],
|
||||
},
|
||||
shard_count = {
|
||||
"cpu": 40,
|
||||
"gpu": 30,
|
||||
|
@ -2652,6 +2652,24 @@ class LaxTest(jtu.JaxTestCase):
|
||||
new_key, _ = lax.rng_bit_generator(key, (0,))
|
||||
self.assertAllClose(key, new_key)
|
||||
|
||||
def test_rng_bit_generator_vmap(self):
|
||||
def f(key):
|
||||
return lax.rng_bit_generator(key, shape=(5, 7))
|
||||
|
||||
keys = np.arange(3 * 4).reshape((3, 4)).astype(np.uint32)
|
||||
out_keys, bits = jax.vmap(f)(keys)
|
||||
self.assertEqual(out_keys.shape, (3, 4))
|
||||
self.assertEqual(bits.shape, (3, 5, 7))
|
||||
|
||||
def test_rng_bit_generator_vmap_vmap(self):
|
||||
def f(key):
|
||||
return lax.rng_bit_generator(key, shape=(5, 7))
|
||||
|
||||
keys = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.uint32)
|
||||
out_keys, bits = jax.vmap(jax.vmap(f))(keys)
|
||||
self.assertEqual(out_keys.shape, (2, 3, 4))
|
||||
self.assertEqual(bits.shape, (2, 3, 5, 7))
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=lax_test_util.all_dtypes + lax_test_util.python_scalar_types,
|
||||
weak_type=[True, False],
|
||||
|
@ -1348,6 +1348,7 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
||||
out = vmap(vmap(random.fold_in), in_axes=(1, 0))(keys(), msgs.T)
|
||||
self.assertEqual(out.shape, (3, 2))
|
||||
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
def test_vmap_split_mapped_key(self):
|
||||
key = self.make_key(73)
|
||||
mapped_keys = random.split(key, num=3)
|
||||
@ -1408,24 +1409,57 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
self.assertArraysEqual(random.key_data(vk),
|
||||
random.key_data(single_split_key))
|
||||
|
||||
def test_vmap_split_mapped_key(self):
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
def test_vmap_split_mapped_key_shape(self):
|
||||
key = self.make_key(73)
|
||||
mapped_keys = random.split(key, num=3)
|
||||
forloop_keys = [random.split(k) for k in mapped_keys]
|
||||
vmapped_keys = vmap(random.split)(mapped_keys)
|
||||
self.assertEqual(vmapped_keys.shape, (3, 2, *key.shape))
|
||||
for fk, vk in zip(forloop_keys, vmapped_keys):
|
||||
self.assertArraysEqual(random.key_data(fk),
|
||||
random.key_data(vk))
|
||||
|
||||
def test_vmap_random_bits(self):
|
||||
rand_fun = lambda key: random.randint(key, (), 0, 100)
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
def test_vmap_split_mapped_key_values(self):
|
||||
key = self.make_key(73)
|
||||
mapped_keys = random.split(key, num=3)
|
||||
vmapped_keys = vmap(random.split)(mapped_keys)
|
||||
ref_keys = [random.split(k) for k in mapped_keys]
|
||||
for rk, vk in zip(ref_keys, vmapped_keys):
|
||||
self.assertArraysEqual(random.key_data(rk),
|
||||
random.key_data(vk))
|
||||
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
def test_vmap_random_bits_shape(self):
|
||||
rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100)
|
||||
key = self.make_key(73)
|
||||
mapped_keys = random.split(key, num=3)
|
||||
forloop_rand_nums = [rand_fun(k) for k in mapped_keys]
|
||||
rand_nums = vmap(rand_fun)(mapped_keys)
|
||||
self.assertEqual(rand_nums.shape, (3,))
|
||||
self.assertArraysEqual(rand_nums, jnp.array(forloop_rand_nums))
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
def test_vmap_random_bits_value(self):
|
||||
rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100)
|
||||
key = self.make_key(73)
|
||||
mapped_keys = random.split(key, num=3)
|
||||
rand_nums = vmap(rand_fun)(mapped_keys)
|
||||
ref_nums = rand_fun(mapped_keys[0], shape=(3,))
|
||||
self.assertArraysEqual(rand_nums, ref_nums)
|
||||
|
||||
def test_vmap_random_bits_distribution(self):
|
||||
dtype = jnp.float32
|
||||
keys = lambda: jax.random.split(self.make_key(0), 10)
|
||||
|
||||
def rand(key):
|
||||
nums = jax.vmap(lambda key: random.uniform(key, (1000,), dtype))(key)
|
||||
return nums.flatten()
|
||||
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(keys())
|
||||
compiled_samples = crand(keys())
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)
|
||||
|
||||
def test_cannot_add(self):
|
||||
key = self.make_key(73)
|
||||
@ -1455,6 +1489,15 @@ class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
|
||||
def make_key(self, seed):
|
||||
return random.PRNGKey(seed, impl="unsafe_rbg")
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
def test_vmap_split_mapped_key_values(self):
|
||||
key = self.make_key(73)
|
||||
mapped_keys = random.split(key, num=3)
|
||||
vmapped_keys = vmap(random.split)(mapped_keys)
|
||||
ref_keys = random.split(mapped_keys[0], (3, 2))
|
||||
self.assertArraysEqual(random.key_data(vmapped_keys),
|
||||
random.key_data(ref_keys))
|
||||
|
||||
def _sampler_unimplemented_with_custom_prng(*args, **kwargs):
|
||||
raise SkipTest('sampler only implemented for default RNG')
|
||||
|
Loading…
x
Reference in New Issue
Block a user