Merge pull request #4398 from google:lift-randomness-limit

PiperOrigin-RevId: 333433816
This commit is contained in:
jax authors 2020-09-23 20:51:27 -07:00
commit c7e0ef4075
2 changed files with 16 additions and 5 deletions

View File

@ -308,12 +308,17 @@ def _random_bits(key, bit_width, shape):
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
size = prod(shape)
max_count = int(np.ceil(bit_width * size / 32))
if max_count >= jnp.iinfo(np.uint32).max:
# TODO(mattjj): just split the key here
raise TypeError("requesting more random bits than a single call provides.")
counts = lax.iota(np.uint32, max_count)
bits = threefry_2x32(key, counts)
nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
if not nblocks:
bits = threefry_2x32(key, lax.iota(np.uint32, rem))
else:
*subkeys, last_key = split(key, nblocks + 1)
blocks = [threefry_2x32(k, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
for k in subkeys]
last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
bits = lax.concatenate(blocks + [last], 0)
dtype = _UINT_DTYPES[bit_width]
if bit_width == 64:
bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]

View File

@ -857,6 +857,12 @@ class LaxRandomTest(jtu.JaxTestCase):
with self.assertRaises(TypeError):
random.choice(key, 5, 2, replace=True)
def test_eval_shape_big_random_array(self):
def f(x):
return random.normal(random.PRNGKey(x), (int(1e12),))
with core.skipping_checks(): # check_jaxpr will materialize array
api.eval_shape(f, 0) # doesn't error
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())