mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
remove limit on size of random arrays
This commit is contained in:
parent
c875ab3ec9
commit
c42d736e34
@ -308,12 +308,17 @@ def _random_bits(key, bit_width, shape):
|
||||
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
|
||||
size = prod(shape)
|
||||
max_count = int(np.ceil(bit_width * size / 32))
|
||||
if max_count >= jnp.iinfo(np.uint32).max:
|
||||
# TODO(mattjj): just split the key here
|
||||
raise TypeError("requesting more random bits than a single call provides.")
|
||||
|
||||
counts = lax.iota(np.uint32, max_count)
|
||||
bits = threefry_2x32(key, counts)
|
||||
nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
|
||||
if not nblocks:
|
||||
bits = threefry_2x32(key, lax.iota(np.uint32, rem))
|
||||
else:
|
||||
*subkeys, last_key = split(key, nblocks + 1)
|
||||
blocks = [threefry_2x32(k, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
|
||||
for k in subkeys]
|
||||
last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
|
||||
bits = lax.concatenate(blocks + [last], 0)
|
||||
|
||||
dtype = _UINT_DTYPES[bit_width]
|
||||
if bit_width == 64:
|
||||
bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
|
||||
|
@ -857,6 +857,11 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
random.choice(key, 5, 2, replace=True)
|
||||
|
||||
def test_eval_shape_big_random_array(self):
|
||||
def f():
|
||||
return random.normal(random.PRNGKey(0), (int(1e10),))
|
||||
api.eval_shape(f) # doesn't error
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user