mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
consistently seed keys indirectly by test class method in LaxRandomTest
This commit is contained in:
parent
556c1123cf
commit
ff70255af9
@ -1289,7 +1289,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
# Singular covariance matrix https://github.com/google/jax/discussions/13293
|
||||
mu = jnp.zeros((2,))
|
||||
sigma = jnp.ones((2, 2))
|
||||
key = random.PRNGKey(0)
|
||||
key = self.seed_prng(0)
|
||||
result = random.multivariate_normal(key, mean=mu, cov=sigma, shape=(10,), method=method)
|
||||
self.assertAllClose(result[:, 0], result[:, 1], atol=1e-3, rtol=1e-3)
|
||||
|
||||
@ -1531,7 +1531,8 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
def test_large_prng(self):
|
||||
# https://github.com/google/jax/issues/11010
|
||||
def f():
|
||||
return random.uniform(random.PRNGKey(3), (308000000, 128), dtype=jnp.bfloat16)
|
||||
return random.uniform(
|
||||
self.seed_prng(3), (308000000, 128), dtype=jnp.bfloat16)
|
||||
|
||||
# just lower, don't run, takes too long
|
||||
jax.jit(f).lower()
|
||||
@ -1545,7 +1546,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
logits_shape.insert(axis % (len(logits_shape_base) + 1), 10)
|
||||
assert logits_shape[axis] == 10
|
||||
logits = jnp.ones(logits_shape)
|
||||
samples = random.categorical(random.PRNGKey(0), logits=logits,
|
||||
samples = random.categorical(self.seed_prng(0), logits=logits,
|
||||
axis=axis, shape=shape)
|
||||
self.assertEqual(samples.shape, shape)
|
||||
|
||||
@ -1555,7 +1556,8 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
def testChisquare(self, df, dtype):
|
||||
key = self.seed_prng(0)
|
||||
|
||||
rand = lambda key, df: random.chisquare(key, df, shape=(10000, ), dtype=dtype)
|
||||
def rand(key, df):
|
||||
return random.chisquare(key, df, shape=(10000,), dtype=dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
uncompiled_samples = rand(key, df)
|
||||
|
Loading…
x
Reference in New Issue
Block a user