mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Merge pull request #4398 from google:lift-randomness-limit
PiperOrigin-RevId: 333433816
This commit is contained in:
commit
c7e0ef4075
@ -308,12 +308,17 @@ def _random_bits(key, bit_width, shape):
|
||||
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
|
||||
size = prod(shape)
|
||||
max_count = int(np.ceil(bit_width * size / 32))
|
||||
if max_count >= jnp.iinfo(np.uint32).max:
|
||||
# TODO(mattjj): just split the key here
|
||||
raise TypeError("requesting more random bits than a single call provides.")
|
||||
|
||||
counts = lax.iota(np.uint32, max_count)
|
||||
bits = threefry_2x32(key, counts)
|
||||
nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
|
||||
if not nblocks:
|
||||
bits = threefry_2x32(key, lax.iota(np.uint32, rem))
|
||||
else:
|
||||
*subkeys, last_key = split(key, nblocks + 1)
|
||||
blocks = [threefry_2x32(k, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
|
||||
for k in subkeys]
|
||||
last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
|
||||
bits = lax.concatenate(blocks + [last], 0)
|
||||
|
||||
dtype = _UINT_DTYPES[bit_width]
|
||||
if bit_width == 64:
|
||||
bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
|
||||
|
@ -857,6 +857,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
random.choice(key, 5, 2, replace=True)
|
||||
|
||||
def test_eval_shape_big_random_array(self):
|
||||
def f(x):
|
||||
return random.normal(random.PRNGKey(x), (int(1e12),))
|
||||
with core.skipping_checks(): # check_jaxpr will materialize array
|
||||
api.eval_shape(f, 0) # doesn't error
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user