consistently seed keys indirectly by test class method in LaxRandomTest

This commit is contained in:
Roy Frostig 2023-07-01 21:41:11 -07:00
parent 556c1123cf
commit ff70255af9

View File

@ -1289,7 +1289,7 @@ class LaxRandomTest(jtu.JaxTestCase):
# Singular covariance matrix https://github.com/google/jax/discussions/13293
mu = jnp.zeros((2,))
sigma = jnp.ones((2, 2))
key = random.PRNGKey(0)
key = self.seed_prng(0)
result = random.multivariate_normal(key, mean=mu, cov=sigma, shape=(10,), method=method)
self.assertAllClose(result[:, 0], result[:, 1], atol=1e-3, rtol=1e-3)
@ -1531,7 +1531,8 @@ class LaxRandomTest(jtu.JaxTestCase):
def test_large_prng(self):
# https://github.com/google/jax/issues/11010
def f():
return random.uniform(random.PRNGKey(3), (308000000, 128), dtype=jnp.bfloat16)
return random.uniform(
self.seed_prng(3), (308000000, 128), dtype=jnp.bfloat16)
# just lower, don't run, takes too long
jax.jit(f).lower()
@ -1545,7 +1546,7 @@ class LaxRandomTest(jtu.JaxTestCase):
logits_shape.insert(axis % (len(logits_shape_base) + 1), 10)
assert logits_shape[axis] == 10
logits = jnp.ones(logits_shape)
samples = random.categorical(random.PRNGKey(0), logits=logits,
samples = random.categorical(self.seed_prng(0), logits=logits,
axis=axis, shape=shape)
self.assertEqual(samples.shape, shape)
@ -1555,7 +1556,8 @@ class LaxRandomTest(jtu.JaxTestCase):
def testChisquare(self, df, dtype):
key = self.seed_prng(0)
rand = lambda key, df: random.chisquare(key, df, shape=(10000, ), dtype=dtype)
def rand(key, df):
return random.chisquare(key, df, shape=(10000,), dtype=dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key, df)