From c42d736e347d059290ab20c013635d62a1ee6c45 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 23 Sep 2020 19:37:34 -0700 Subject: [PATCH] remove limit on size of random arrays --- jax/random.py | 15 ++++++++++----- tests/random_test.py | 5 +++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/jax/random.py b/jax/random.py index 0bbb012df..3823c8c11 100644 --- a/jax/random.py +++ b/jax/random.py @@ -308,12 +308,17 @@ def _random_bits(key, bit_width, shape): raise TypeError("requires 8-, 16-, 32- or 64-bit field width.") size = prod(shape) max_count = int(np.ceil(bit_width * size / 32)) - if max_count >= jnp.iinfo(np.uint32).max: - # TODO(mattjj): just split the key here - raise TypeError("requesting more random bits than a single call provides.") - counts = lax.iota(np.uint32, max_count) - bits = threefry_2x32(key, counts) + nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max) + if not nblocks: + bits = threefry_2x32(key, lax.iota(np.uint32, rem)) + else: + *subkeys, last_key = split(key, nblocks + 1) + blocks = [threefry_2x32(k, lax.iota(np.uint32, jnp.iinfo(np.uint32).max)) + for k in subkeys] + last = threefry_2x32(last_key, lax.iota(np.uint32, rem)) + bits = lax.concatenate(blocks + [last], 0) + dtype = _UINT_DTYPES[bit_width] if bit_width == 64: bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)] diff --git a/tests/random_test.py b/tests/random_test.py index 90f347677..6ac454dbe 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -857,6 +857,11 @@ class LaxRandomTest(jtu.JaxTestCase): with self.assertRaises(TypeError): random.choice(key, 5, 2, replace=True) + def test_eval_shape_big_random_array(self): + def f(): + return random.normal(random.PRNGKey(0), (int(1e10),)) + api.eval_shape(f) # doesn't error + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())