Fix overflow of large prng computation

Fixes: #11010
This commit is contained in:
George Necula 2022-06-20 10:48:15 +02:00
parent b50d77cc98
commit 4b03ebf4f5
2 changed files with 9 additions and 1 deletions

View File

@ -542,7 +542,7 @@ def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
)
)
)
bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width),), (1, 0))
bits = lax.reshape(bits, ((max_count * 32 // bit_width),), (1, 0))
bits = lax.convert_element_type(bits, dtype)[:size]
return lax.reshape(bits, shape)

View File

@ -1454,6 +1454,14 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertGreater((r == 0).sum(), 0)
self.assertGreater((r == 255).sum(), 0)
def test_large_prng(self):
# https://github.com/google/jax/issues/11010
def f():
return jax.random.uniform(jax.random.PRNGKey(3), (308000000, 128), dtype=jnp.bfloat16)
# just lower, don't run, takes too long
jax.jit(f).lower()
threefry_seed = jax._src.prng.threefry_seed
threefry_split = jax._src.prng.threefry_split