mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
remove limit on size of random arrays
This commit is contained in:
parent
c875ab3ec9
commit
c42d736e34
@ -308,12 +308,17 @@ def _random_bits(key, bit_width, shape):
|
|||||||
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
|
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
|
||||||
size = prod(shape)
|
size = prod(shape)
|
||||||
max_count = int(np.ceil(bit_width * size / 32))
|
max_count = int(np.ceil(bit_width * size / 32))
|
||||||
if max_count >= jnp.iinfo(np.uint32).max:
|
|
||||||
# TODO(mattjj): just split the key here
|
|
||||||
raise TypeError("requesting more random bits than a single call provides.")
|
|
||||||
|
|
||||||
counts = lax.iota(np.uint32, max_count)
|
nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
|
||||||
bits = threefry_2x32(key, counts)
|
if not nblocks:
|
||||||
|
bits = threefry_2x32(key, lax.iota(np.uint32, rem))
|
||||||
|
else:
|
||||||
|
*subkeys, last_key = split(key, nblocks + 1)
|
||||||
|
blocks = [threefry_2x32(k, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
|
||||||
|
for k in subkeys]
|
||||||
|
last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
|
||||||
|
bits = lax.concatenate(blocks + [last], 0)
|
||||||
|
|
||||||
dtype = _UINT_DTYPES[bit_width]
|
dtype = _UINT_DTYPES[bit_width]
|
||||||
if bit_width == 64:
|
if bit_width == 64:
|
||||||
bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
|
bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
|
||||||
|
@ -857,6 +857,11 @@ class LaxRandomTest(jtu.JaxTestCase):
|
|||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
random.choice(key, 5, 2, replace=True)
|
random.choice(key, 5, 2, replace=True)
|
||||||
|
|
||||||
|
def test_eval_shape_big_random_array(self):
|
||||||
|
def f():
|
||||||
|
return random.normal(random.PRNGKey(0), (int(1e10),))
|
||||||
|
api.eval_shape(f) # doesn't error
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user