define a loop-free untrue batching rule for rng_bit_generator

This commit is contained in:
Roy Frostig 2024-03-05 20:09:14 -08:00
parent f0afc1b43d
commit 29edfd8925
5 changed files with 83 additions and 17 deletions

View File

@ -2046,16 +2046,18 @@ def map(f, xs):
return ys
def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm):
"""Calls RBG in a loop and stacks the results."""
key, = batched_args
keys, = batched_args
bd, = batch_dims
if bd is batching.not_mapped:
return lax.rng_bit_generator_p.bind(key, shape=shape, dtype=dtype,
return lax.rng_bit_generator_p.bind(keys, shape=shape, dtype=dtype,
algorithm=algorithm), (None, None)
key = batching.moveaxis(key, bd, 0)
map_body = lambda k: lax.rng_bit_generator_p.bind(k, shape=shape, dtype=dtype, algorithm=algorithm)
stacked_keys, stacked_bits = map(map_body, key)
return (stacked_keys, stacked_bits), (0, 0)
keys = batching.moveaxis(keys, bd, 0)
batch_size = keys.shape[0]
key = keys[0]
new_key, bits = lax.rng_bit_generator_p.bind(key, shape=(batch_size, *shape),
dtype=dtype, algorithm=algorithm)
new_keys = jax.lax.dynamic_update_index_in_dim(keys, new_key, 0, axis=0)
return (new_keys, bits), (0, 0)
batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule # type: ignore

View File

@ -1233,7 +1233,7 @@ def _gamma_impl(key, a, *, log_space, use_vmap=False):
keys = keys.flatten()
alphas = a.flatten()
if use_vmap:
if use_vmap and _key_impl(key) is prng.threefry_prng_impl:
samples = vmap(partial(_gamma_one, log_space=log_space))(keys, alphas)
else:
samples = lax.map(

View File

@ -784,6 +784,9 @@ jax_test(
"notsan", # Times out
],
},
backend_variant_args = {
"gpu": ["--jax_num_generated_cases=40"],
},
shard_count = {
"cpu": 40,
"gpu": 30,

View File

@ -2652,6 +2652,24 @@ class LaxTest(jtu.JaxTestCase):
new_key, _ = lax.rng_bit_generator(key, (0,))
self.assertAllClose(key, new_key)
def test_rng_bit_generator_vmap(self):
def f(key):
return lax.rng_bit_generator(key, shape=(5, 7))
keys = np.arange(3 * 4).reshape((3, 4)).astype(np.uint32)
out_keys, bits = jax.vmap(f)(keys)
self.assertEqual(out_keys.shape, (3, 4))
self.assertEqual(bits.shape, (3, 5, 7))
def test_rng_bit_generator_vmap_vmap(self):
def f(key):
return lax.rng_bit_generator(key, shape=(5, 7))
keys = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.uint32)
out_keys, bits = jax.vmap(jax.vmap(f))(keys)
self.assertEqual(out_keys.shape, (2, 3, 4))
self.assertEqual(bits.shape, (2, 3, 5, 7))
@jtu.sample_product(
dtype=lax_test_util.all_dtypes + lax_test_util.python_scalar_types,
weak_type=[True, False],

View File

@ -1348,6 +1348,7 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
out = vmap(vmap(random.fold_in), in_axes=(1, 0))(keys(), msgs.T)
self.assertEqual(out.shape, (3, 2))
@jax.enable_key_reuse_checks(False)
def test_vmap_split_mapped_key(self):
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
@ -1408,24 +1409,57 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
self.assertArraysEqual(random.key_data(vk),
random.key_data(single_split_key))
def test_vmap_split_mapped_key(self):
@jax.enable_key_reuse_checks(False)
def test_vmap_split_mapped_key_shape(self):
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
forloop_keys = [random.split(k) for k in mapped_keys]
vmapped_keys = vmap(random.split)(mapped_keys)
self.assertEqual(vmapped_keys.shape, (3, 2, *key.shape))
for fk, vk in zip(forloop_keys, vmapped_keys):
self.assertArraysEqual(random.key_data(fk),
random.key_data(vk))
def test_vmap_random_bits(self):
rand_fun = lambda key: random.randint(key, (), 0, 100)
@jax.enable_key_reuse_checks(False)
def test_vmap_split_mapped_key_values(self):
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
vmapped_keys = vmap(random.split)(mapped_keys)
ref_keys = [random.split(k) for k in mapped_keys]
for rk, vk in zip(ref_keys, vmapped_keys):
self.assertArraysEqual(random.key_data(rk),
random.key_data(vk))
@jax.enable_key_reuse_checks(False)
def test_vmap_random_bits_shape(self):
rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100)
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
forloop_rand_nums = [rand_fun(k) for k in mapped_keys]
rand_nums = vmap(rand_fun)(mapped_keys)
self.assertEqual(rand_nums.shape, (3,))
self.assertArraysEqual(rand_nums, jnp.array(forloop_rand_nums))
@jtu.skip_on_devices("tpu")
@jax.enable_key_reuse_checks(False)
def test_vmap_random_bits_value(self):
rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100)
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
rand_nums = vmap(rand_fun)(mapped_keys)
ref_nums = rand_fun(mapped_keys[0], shape=(3,))
self.assertArraysEqual(rand_nums, ref_nums)
def test_vmap_random_bits_distribution(self):
dtype = jnp.float32
keys = lambda: jax.random.split(self.make_key(0), 10)
def rand(key):
nums = jax.vmap(lambda key: random.uniform(key, (1000,), dtype))(key)
return nums.flatten()
crand = jax.jit(rand)
uncompiled_samples = rand(keys())
compiled_samples = crand(keys())
for samples in [uncompiled_samples, compiled_samples]:
self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)
def test_cannot_add(self):
key = self.make_key(73)
@ -1455,6 +1489,15 @@ class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
def make_key(self, seed):
return random.PRNGKey(seed, impl="unsafe_rbg")
@jtu.skip_on_devices("tpu")
@jax.enable_key_reuse_checks(False)
def test_vmap_split_mapped_key_values(self):
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
vmapped_keys = vmap(random.split)(mapped_keys)
ref_keys = random.split(mapped_keys[0], (3, 2))
self.assertArraysEqual(random.key_data(vmapped_keys),
random.key_data(ref_keys))
def _sampler_unimplemented_with_custom_prng(*args, **kwargs):
raise SkipTest('sampler only implemented for default RNG')