mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
parent
b50d77cc98
commit
4b03ebf4f5
@ -542,7 +542,7 @@ def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
|
||||
)
|
||||
)
|
||||
)
|
||||
bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width),), (1, 0))
|
||||
bits = lax.reshape(bits, ((max_count * 32 // bit_width),), (1, 0))
|
||||
bits = lax.convert_element_type(bits, dtype)[:size]
|
||||
return lax.reshape(bits, shape)
|
||||
|
||||
|
@ -1454,6 +1454,14 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
self.assertGreater((r == 0).sum(), 0)
|
||||
self.assertGreater((r == 255).sum(), 0)
|
||||
|
||||
def test_large_prng(self):
|
||||
# https://github.com/google/jax/issues/11010
|
||||
def f():
|
||||
return jax.random.uniform(jax.random.PRNGKey(3), (308000000, 128), dtype=jnp.bfloat16)
|
||||
|
||||
# just lower, don't run, takes too long
|
||||
jax.jit(f).lower()
|
||||
|
||||
|
||||
threefry_seed = jax._src.prng.threefry_seed
|
||||
threefry_split = jax._src.prng.threefry_split
|
||||
|
Loading…
x
Reference in New Issue
Block a user