remove limit on size of random arrays

This commit is contained in:
Matthew Johnson 2020-09-23 19:37:34 -07:00
parent c875ab3ec9
commit c42d736e34
2 changed files with 15 additions and 5 deletions

View File

@ -308,12 +308,17 @@ def _random_bits(key, bit_width, shape):
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
size = prod(shape)
max_count = int(np.ceil(bit_width * size / 32))
if max_count >= jnp.iinfo(np.uint32).max:
# TODO(mattjj): just split the key here
raise TypeError("requesting more random bits than a single call provides.")
counts = lax.iota(np.uint32, max_count)
bits = threefry_2x32(key, counts)
nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
if not nblocks:
bits = threefry_2x32(key, lax.iota(np.uint32, rem))
else:
*subkeys, last_key = split(key, nblocks + 1)
blocks = [threefry_2x32(k, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
for k in subkeys]
last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
bits = lax.concatenate(blocks + [last], 0)
dtype = _UINT_DTYPES[bit_width]
if bit_width == 64:
bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]

View File

@ -857,6 +857,11 @@ class LaxRandomTest(jtu.JaxTestCase):
with self.assertRaises(TypeError):
random.choice(key, 5, 2, replace=True)
def test_eval_shape_big_random_array(self):
def f():
return random.normal(random.PRNGKey(0), (int(1e10),))
api.eval_shape(f) # doesn't error
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())