mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
rename seed_prng
test method to make_key
This commit is contained in:
parent
ff70255af9
commit
ce9c2d650a
@ -609,7 +609,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
'Expected vs. actual frequencies:\n'
|
||||
f'{expected_freq}\n{actual_freq}')
|
||||
|
||||
def seed_prng(self, seed):
|
||||
def make_key(self, seed):
|
||||
return random.threefry2x32_key(seed)
|
||||
|
||||
@jtu.sample_product(dtype=jtu.dtypes.floating)
|
||||
@ -622,7 +622,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(dtype=float_dtypes)
|
||||
def testRngUniform(self, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = lambda key: random.uniform(key, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -638,7 +638,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
lo = 5
|
||||
hi = 10
|
||||
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = lambda key: random.randint(key, (10000,), lo, hi, dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -651,7 +651,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(dtype=float_dtypes)
|
||||
def testNormal(self, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = lambda key: random.normal(key, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -664,13 +664,13 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
def testNormalBfloat16(self):
|
||||
# Passing bfloat16 as dtype string.
|
||||
# https://github.com/google/jax/issues/6813
|
||||
res_bfloat16_str = random.normal(self.seed_prng(0), dtype='bfloat16')
|
||||
res_bfloat16 = random.normal(self.seed_prng(0), dtype=jnp.bfloat16)
|
||||
res_bfloat16_str = random.normal(self.make_key(0), dtype='bfloat16')
|
||||
res_bfloat16 = random.normal(self.make_key(0), dtype=jnp.bfloat16)
|
||||
self.assertAllClose(res_bfloat16, res_bfloat16_str)
|
||||
|
||||
@jtu.sample_product(dtype=complex_dtypes)
|
||||
def testNormalComplex(self, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = lambda key: random.normal(key, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -684,7 +684,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(dtype=float_dtypes)
|
||||
def testTruncatedNormal(self, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = lambda key: random.truncated_normal(key, -0.3, 0.3, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -700,7 +700,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(dtype=jtu.dtypes.floating + jtu.dtypes.integer)
|
||||
def testShuffle(self, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
x = np.arange(100).astype(dtype)
|
||||
rand = lambda key: random.shuffle(key, x)
|
||||
crand = jax.jit(rand)
|
||||
@ -734,7 +734,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
np_choice = np.random.default_rng(0).choice
|
||||
p_dtype = dtypes.to_inexact_dtype(dtype)
|
||||
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
is_range = type(input_range_or_shape) is int
|
||||
x = (input_range_or_shape if is_range else
|
||||
self.rng().permutation(np.arange(math.prod(
|
||||
@ -774,7 +774,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
independent=[True, False],
|
||||
)
|
||||
def testPermutation(self, dtype, range_or_shape, axis, independent):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
is_range = type(range_or_shape) is int
|
||||
x = (range_or_shape if is_range else
|
||||
self.rng().permutation(np.arange(
|
||||
@ -803,7 +803,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
'x' if is_range else None)(key, x))
|
||||
|
||||
def testPermutationErrors(self):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
with self.assertRaises(ValueError):
|
||||
random.permutation(key, 10, axis=3)
|
||||
with self.assertRaises(TypeError):
|
||||
@ -816,7 +816,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def testBernoulli(self, p, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
p = np.array(p, dtype=dtype)
|
||||
rand = lambda key, p: random.bernoulli(key, p, (10000,))
|
||||
crand = jax.jit(rand)
|
||||
@ -840,7 +840,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def testCategorical(self, p, axis, dtype, sample_shape):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
p = np.array(p, dtype=dtype)
|
||||
logits = np.log(p) - 42 # test unnormalized
|
||||
out_shape = tuple(np.delete(logits.shape, axis))
|
||||
@ -867,7 +867,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
self._CheckChiSquared(samples, pmf=pmf)
|
||||
|
||||
def testBernoulliShape(self):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
with jax.numpy_rank_promotion('allow'):
|
||||
x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2))
|
||||
assert x.shape == (3, 2)
|
||||
@ -880,7 +880,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
def testBeta(self, a, b, dtype):
|
||||
if not config.x64_enabled:
|
||||
raise SkipTest("skip test except on X64")
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = lambda key, a, b: random.beta(key, a, b, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -892,7 +892,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
def testBetaSmallParameters(self, dtype=np.float32):
|
||||
# Regression test for beta version of https://github.com/google/jax/issues/9896
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
a, b = 0.0001, 0.0002
|
||||
samples = random.beta(key, a, b, shape=(100,), dtype=dtype)
|
||||
|
||||
@ -907,7 +907,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(dtype=float_dtypes)
|
||||
def testCauchy(self, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = lambda key: random.cauchy(key, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -923,7 +923,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
)
|
||||
@jtu.skip_on_devices("tpu") # TODO(mattjj): slow compilation times
|
||||
def testDirichlet(self, alpha, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = lambda key, alpha: random.dirichlet(key, alpha, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -938,7 +938,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
def testDirichletSmallAlpha(self, dtype=np.float32):
|
||||
# Regression test for https://github.com/google/jax/issues/9896
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
alpha = 0.0001 * jnp.ones(3)
|
||||
samples = random.dirichlet(key, alpha, shape=(100,), dtype=dtype)
|
||||
|
||||
@ -953,7 +953,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(dtype=float_dtypes)
|
||||
def testExponential(self, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = lambda key: random.exponential(key, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -968,7 +968,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def testGammaVsLogGamma(self, a, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand_gamma = lambda key, a: random.gamma(key, a, (10000,), dtype)
|
||||
rand_loggamma = lambda key, a: random.loggamma(key, a, (10000,), dtype)
|
||||
crand_loggamma = jax.jit(rand_loggamma)
|
||||
@ -981,7 +981,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def testGamma(self, a, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -992,7 +992,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gamma(a).cdf)
|
||||
|
||||
def testGammaShape(self):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
x = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2))
|
||||
assert x.shape == (3, 2)
|
||||
|
||||
@ -1001,7 +1001,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
alpha=[1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4],
|
||||
)
|
||||
def testGammaGrad(self, log_space, alpha):
|
||||
rng = self.seed_prng(0)
|
||||
rng = self.make_key(0)
|
||||
alphas = np.full((100,), alpha)
|
||||
z = random.gamma(rng, alphas)
|
||||
if log_space:
|
||||
@ -1025,7 +1025,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
def testGammaGradType(self):
|
||||
# Regression test for https://github.com/google/jax/issues/2130
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
a = jnp.array(1., dtype=jnp.float32)
|
||||
b = jnp.array(3., dtype=jnp.float32)
|
||||
f = lambda x, y: random.gamma(key=key, a=x, dtype=jnp.float32) / y
|
||||
@ -1037,7 +1037,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
dtype=[np.int16, np.int32, np.int64],
|
||||
)
|
||||
def testPoisson(self, lam, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -1052,38 +1052,38 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(samples.var(), lam, rtol=0.03, check_dtypes=False)
|
||||
|
||||
def testPoissonBatched(self):
|
||||
key = self.seed_prng(1)
|
||||
key = self.make_key(1)
|
||||
lam = jnp.concatenate([2 * jnp.ones(10000), 20 * jnp.ones(10000)])
|
||||
samples = random.poisson(key, lam, shape=(20000,))
|
||||
self._CheckChiSquared(samples[:10000], scipy.stats.poisson(2.0).pmf)
|
||||
self._CheckChiSquared(samples[10000:], scipy.stats.poisson(20.0).pmf)
|
||||
|
||||
def testPoissonWithoutShape(self):
|
||||
key = self.seed_prng(1)
|
||||
key = self.make_key(1)
|
||||
lam = 2 * jnp.ones(10000)
|
||||
samples = random.poisson(key, lam)
|
||||
self._CheckChiSquared(samples, scipy.stats.poisson(2.0).pmf)
|
||||
|
||||
def testPoissonShape(self):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
x = random.poisson(key, np.array([2.0, 20.0]), shape=(3, 2))
|
||||
assert x.shape == (3, 2)
|
||||
|
||||
def testPoissonZeros(self):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
lam = jnp.concatenate([jnp.zeros(10), 20 * jnp.ones(10)])
|
||||
samples = random.poisson(key, lam, shape=(2, 20))
|
||||
self.assertArraysEqual(samples[:, :10], jnp.zeros_like(samples[:, :10]))
|
||||
|
||||
def testPoissonCornerCases(self):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
lam = jnp.array([-1, 0, jnp.nan])
|
||||
samples = random.poisson(key, lam, shape=(3,))
|
||||
self.assertArraysEqual(samples, jnp.array([-1, 0, -1]), check_dtypes=False)
|
||||
|
||||
@jtu.sample_product(dtype=jtu.dtypes.floating)
|
||||
def testGumbel(self, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = lambda key: random.gumbel(key, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -1095,7 +1095,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(dtype=float_dtypes)
|
||||
def testLaplace(self, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = lambda key: random.laplace(key, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -1107,7 +1107,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(dtype=float_dtypes)
|
||||
def testLogistic(self, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = lambda key: random.logistic(key, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -1124,7 +1124,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
)
|
||||
@jax.default_matmul_precision("float32")
|
||||
def testOrthogonal(self, n, shape, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
q = random.orthogonal(key, n, shape, dtype)
|
||||
self.assertEqual(q.shape, (*shape, n, n))
|
||||
self.assertEqual(q.dtype, dtype)
|
||||
@ -1140,7 +1140,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def testGeneralizedNormal(self, p, shape, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = lambda key, p: random.generalized_normal(key, p, shape, dtype)
|
||||
crand = jax.jit(rand)
|
||||
uncompiled_samples = rand(key, p)
|
||||
@ -1157,7 +1157,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def testBall(self, d, p, shape, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = lambda key, p: random.ball(key, d, p, shape, dtype)
|
||||
crand = jax.jit(rand)
|
||||
uncompiled_samples = rand(key, p)
|
||||
@ -1174,7 +1174,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
def testPareto(self, b, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = lambda key, b: random.pareto(key, b, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -1185,7 +1185,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.pareto(b).cdf)
|
||||
|
||||
def testParetoShape(self):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
with jax.numpy_rank_promotion('allow'):
|
||||
x = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2))
|
||||
assert x.shape == (3, 2)
|
||||
@ -1196,7 +1196,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
)
|
||||
@jtu.skip_on_devices("cpu", "tpu") # TODO(phawkins): slow compilation times
|
||||
def testT(self, df, dtype):
|
||||
key = self.seed_prng(1)
|
||||
key = self.make_key(1)
|
||||
rand = lambda key, df: random.t(key, df, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -1217,7 +1217,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
cov_factor = r.randn(dim, dim)
|
||||
cov = np.dot(cov_factor, cov_factor.T) + dim * np.eye(dim)
|
||||
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = partial(random.multivariate_normal, mean=mean, cov=cov,
|
||||
shape=(10000,), method=method)
|
||||
crand = jax.jit(rand)
|
||||
@ -1246,7 +1246,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
def testMultivariateNormalShapes(self, dim, mean_batch_size, cov_batch_size,
|
||||
shape, method):
|
||||
r = self.rng()
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
eff_batch_size = mean_batch_size \
|
||||
if len(mean_batch_size) > len(cov_batch_size) else cov_batch_size
|
||||
mean = r.randn(*(mean_batch_size + (dim,)))
|
||||
@ -1269,7 +1269,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
out_np = self.rng().multivariate_normal(mean, cov, N)
|
||||
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
with jax.numpy_rank_promotion('allow'):
|
||||
out_jnp = random.multivariate_normal(key, mean=mean, cov=cov, shape=(N,))
|
||||
|
||||
@ -1289,7 +1289,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
# Singular covariance matrix https://github.com/google/jax/discussions/13293
|
||||
mu = jnp.zeros((2,))
|
||||
sigma = jnp.ones((2, 2))
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
result = random.multivariate_normal(key, mean=mu, cov=sigma, shape=(10,), method=method)
|
||||
self.assertAllClose(result[:, 0], result[:, 1], atol=1e-3, rtol=1e-3)
|
||||
|
||||
@ -1300,16 +1300,16 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
self.assertFalse(np.any(np.isnan(result)))
|
||||
|
||||
def testIssue222(self):
|
||||
x = random.randint(self.seed_prng(10003), (), 0, 0)
|
||||
x = random.randint(self.make_key(10003), (), 0, 0)
|
||||
assert x == 0
|
||||
|
||||
def testFoldIn(self):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
keys = [_prng_key_as_array(random.fold_in(key, i)) for i in range(10)]
|
||||
assert np.unique(keys, axis=0).shape[0] == 10
|
||||
|
||||
def testFoldInBig(self):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
seeds = [2 ** 32 - 2, 2 ** 32 - 1]
|
||||
keys = [_prng_key_as_array(random.fold_in(key, seed)) for seed in seeds]
|
||||
assert np.unique(keys, axis=0).shape[0] == 2
|
||||
@ -1320,7 +1320,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def feature_map(n, d, sigma=1.0, seed=123):
|
||||
key = self.seed_prng(seed)
|
||||
key = self.make_key(seed)
|
||||
W = random.normal(key, (d, n)) / sigma
|
||||
w = random.normal(key, (d, )) / sigma
|
||||
b = 2 * jnp.pi * random.uniform(key, (d, ))
|
||||
@ -1332,24 +1332,24 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
lambda: feature_map(5, 3))
|
||||
|
||||
def testIssue756(self):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
w = random.normal(key, ())
|
||||
self.assertEqual(w.dtype, dtypes.canonicalize_dtype(jnp.float_))
|
||||
|
||||
def testIssue1789(self):
|
||||
def f(x):
|
||||
return random.gamma(self.seed_prng(0), x)
|
||||
return random.gamma(self.make_key(0), x)
|
||||
|
||||
grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2))
|
||||
|
||||
def testDtypeErrorMessage(self):
|
||||
with self.assertRaisesRegex(ValueError, r"dtype argument to.*"):
|
||||
random.normal(self.seed_prng(0), (), dtype=jnp.int32)
|
||||
random.normal(self.make_key(0), (), dtype=jnp.int32)
|
||||
|
||||
def testRandomBroadcast(self):
|
||||
"""Issue 4033"""
|
||||
# test for broadcast issue in https://github.com/google/jax/issues/4033
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
shape = (10, 2)
|
||||
with jax.numpy_rank_promotion('allow'):
|
||||
x1 = random.uniform(key, shape, minval=jnp.zeros(2), maxval=jnp.ones(2))
|
||||
@ -1359,7 +1359,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
def testMaxwellSample(self):
|
||||
num_samples = 10**5
|
||||
rng = self.seed_prng(0)
|
||||
rng = self.make_key(0)
|
||||
|
||||
rand = lambda x: random.maxwell(x, (num_samples, ))
|
||||
crand = jax.jit(rand)
|
||||
@ -1382,7 +1382,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
('test2', 2.0, 3.0))
|
||||
def testWeibullSample(self, concentration, scale):
|
||||
num_samples = 10**5
|
||||
rng = self.seed_prng(0)
|
||||
rng = self.make_key(0)
|
||||
|
||||
rand = lambda x: random.weibull_min(x, scale, concentration, (num_samples,))
|
||||
crand = jax.jit(rand)
|
||||
@ -1406,7 +1406,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
('test2', 2.0, 3.0))
|
||||
def testDoublesidedMaxwellSample(self, loc, scale):
|
||||
num_samples = 10**4
|
||||
rng = self.seed_prng(0)
|
||||
rng = self.make_key(0)
|
||||
|
||||
rand = lambda key: random.double_sided_maxwell(
|
||||
rng, loc, scale, (num_samples,))
|
||||
@ -1443,7 +1443,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
samples, lambda x: double_sided_maxwell_cdf(x, loc, scale))
|
||||
|
||||
def testRadamacher(self):
|
||||
rng = self.seed_prng(0)
|
||||
rng = self.make_key(0)
|
||||
num_samples = 10**5
|
||||
|
||||
rand = lambda x: random.rademacher(x, (num_samples,))
|
||||
@ -1463,7 +1463,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
counts[1] / num_samples, 0.5, rtol=1e-02, atol=1e-02)
|
||||
|
||||
def testChoiceShapeIsNotSequenceError(self):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
with self.assertRaises(TypeError):
|
||||
random.choice(key, 5, 2, replace=False)
|
||||
with self.assertRaises(TypeError):
|
||||
@ -1471,7 +1471,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
def test_eval_shape_big_random_array(self):
|
||||
def f(x):
|
||||
return random.normal(self.seed_prng(x), (int(1e12),))
|
||||
return random.normal(self.make_key(x), (int(1e12),))
|
||||
with jax.enable_checks(False): # check_jaxpr will materialize array
|
||||
jax.eval_shape(f, 0) # doesn't error
|
||||
|
||||
@ -1486,18 +1486,18 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
self.skipTest("Expected failure: Python int too large.")
|
||||
type_ = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type_]
|
||||
args_maker = lambda: [type_(seed)]
|
||||
f = lambda s: _maybe_unwrap(self.seed_prng(s))
|
||||
f = lambda s: _maybe_unwrap(self.make_key(s))
|
||||
self._CompileAndCheck(f, args_maker)
|
||||
|
||||
def test_prng_errors(self):
|
||||
seed = np.iinfo(np.int64).max + 1
|
||||
with self.assertRaises(OverflowError):
|
||||
self.seed_prng(seed)
|
||||
self.make_key(seed)
|
||||
with self.assertRaises(OverflowError):
|
||||
jax.jit(self.seed_prng)(seed)
|
||||
jax.jit(self.make_key)(seed)
|
||||
|
||||
def test_random_split_doesnt_device_put_during_tracing(self):
|
||||
key = self.seed_prng(1).block_until_ready()
|
||||
key = self.make_key(1).block_until_ready()
|
||||
with jtu.count_device_put() as count:
|
||||
jax.jit(random.split)(key)
|
||||
self.assertLessEqual(count[0], 1) # 1 for the argument device_put
|
||||
@ -1506,7 +1506,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
def test_randint_bounds(self, dtype):
|
||||
min = np.iinfo(dtype).min
|
||||
max = np.iinfo(dtype).max
|
||||
key = self.seed_prng(1701)
|
||||
key = self.make_key(1701)
|
||||
shape = (10,)
|
||||
if np.iinfo(dtype).bits < np.iinfo(dtypes.canonicalize_dtype(int)).bits:
|
||||
expected = random.randint(key, shape, min, max, dtype)
|
||||
@ -1515,7 +1515,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
self.assertRaises(OverflowError, random.randint, key, shape, min - 12345, max + 12345, dtype)
|
||||
|
||||
def test_randint_out_of_range(self):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
|
||||
r = random.randint(key, (10,), 255, 256, np.uint8)
|
||||
self.assertAllClose(r, jnp.full_like(r, 255))
|
||||
@ -1532,7 +1532,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
# https://github.com/google/jax/issues/11010
|
||||
def f():
|
||||
return random.uniform(
|
||||
self.seed_prng(3), (308000000, 128), dtype=jnp.bfloat16)
|
||||
self.make_key(3), (308000000, 128), dtype=jnp.bfloat16)
|
||||
|
||||
# just lower, don't run, takes too long
|
||||
jax.jit(f).lower()
|
||||
@ -1546,7 +1546,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
logits_shape.insert(axis % (len(logits_shape_base) + 1), 10)
|
||||
assert logits_shape[axis] == 10
|
||||
logits = jnp.ones(logits_shape)
|
||||
samples = random.categorical(self.seed_prng(0), logits=logits,
|
||||
samples = random.categorical(self.make_key(0), logits=logits,
|
||||
axis=axis, shape=shape)
|
||||
self.assertEqual(samples.shape, shape)
|
||||
|
||||
@ -1554,7 +1554,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
df = [0.2, 1., 10., 100.],
|
||||
dtype=jtu.dtypes.floating)
|
||||
def testChisquare(self, df, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
|
||||
def rand(key, df):
|
||||
return random.chisquare(key, df, shape=(10000,), dtype=dtype)
|
||||
@ -1571,7 +1571,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
dfden = [1. ,2., 10., 100.],
|
||||
dtype=jtu.dtypes.floating)
|
||||
def testF(self, dfnum, dfden, dtype):
|
||||
key = self.seed_prng(1)
|
||||
key = self.make_key(1)
|
||||
rand = lambda key: random.f(key, dfnum, dfden, shape = (10000, ), dtype = dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -1585,7 +1585,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
scale= [0.2, 1., 2., 10. ,100.],
|
||||
dtype=jtu.dtypes.floating)
|
||||
def testRayleigh(self, scale, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = lambda key: random.rayleigh(key, scale, shape = (10000, ), dtype = dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -1599,7 +1599,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
mean= [0.2, 1., 2., 10. ,100.],
|
||||
dtype=jtu.dtypes.floating)
|
||||
def testWald(self, mean, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.make_key(0)
|
||||
rand = lambda key: random.wald(key, mean, shape=(10000, ), dtype=dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -1613,7 +1613,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
p= [0.2, 0.3, 0.4, 0.5 ,0.6],
|
||||
dtype= [np.int16, np.int32, np.int64])
|
||||
def testGeometric(self, p, dtype):
|
||||
key = self.seed_prng(1)
|
||||
key = self.make_key(1)
|
||||
rand = lambda key: random.geometric(key, p, shape=(10000, ), dtype=dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -2069,17 +2069,17 @@ double_threefry_prng_impl = prng.PRNGImpl(
|
||||
|
||||
@jtu.with_config(jax_default_prng_impl='threefry2x32')
|
||||
class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
||||
def seed_prng(self, seed):
|
||||
def make_key(self, seed):
|
||||
return prng.seed_with_impl(double_threefry_prng_impl, seed)
|
||||
|
||||
def test_split_shape(self):
|
||||
key = self.seed_prng(73)
|
||||
key = self.make_key(73)
|
||||
keys = random.split(key, 10)
|
||||
self.assertEqual(keys.shape, (10,))
|
||||
|
||||
def test_vmap_fold_in_shape(self):
|
||||
# broadcast with scalar
|
||||
keys = random.split(self.seed_prng(73), 2)
|
||||
keys = random.split(self.make_key(73), 2)
|
||||
msgs = jnp.arange(3)
|
||||
out = vmap(lambda i: random.fold_in(keys[0], i))(msgs)
|
||||
self.assertEqual(out.shape, (3,))
|
||||
@ -2096,7 +2096,7 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
||||
self.assertEqual(out.shape, (2,))
|
||||
|
||||
# nested vmap
|
||||
keys = random.split(self.seed_prng(73), 2 * 3).reshape((2, 3))
|
||||
keys = random.split(self.make_key(73), 2 * 3).reshape((2, 3))
|
||||
msgs = jnp.arange(2 * 3).reshape((2, 3))
|
||||
out = vmap(vmap(random.fold_in), in_axes=(0, 1))(keys, msgs.T)
|
||||
self.assertEqual(out.shape, (2, 3))
|
||||
@ -2104,7 +2104,7 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
||||
self.assertEqual(out.shape, (3, 2))
|
||||
|
||||
def test_vmap_split_mapped_key(self):
|
||||
key = self.seed_prng(73)
|
||||
key = self.make_key(73)
|
||||
mapped_keys = random.split(key, num=3)
|
||||
forloop_keys = [random.split(k) for k in mapped_keys]
|
||||
vmapped_keys = vmap(random.split)(mapped_keys)
|
||||
@ -2114,7 +2114,7 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
||||
vk.unsafe_raw_array())
|
||||
|
||||
def test_cannot_add(self):
|
||||
key = self.seed_prng(73)
|
||||
key = self.make_key(73)
|
||||
self.assertRaisesRegex(
|
||||
ValueError, r'dtype=key<.*> is not a valid dtype for JAX type promotion.',
|
||||
lambda: key + 47)
|
||||
@ -2122,7 +2122,7 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
||||
@skipIf(np.__version__ == "1.21.0",
|
||||
"https://github.com/numpy/numpy/issues/19305")
|
||||
def test_grad_of_prng_key(self):
|
||||
key = self.seed_prng(73)
|
||||
key = self.make_key(73)
|
||||
with self.assertRaisesRegex(TypeError, 'grad requires real- or complex-valued inputs'):
|
||||
jax.grad(lambda x: 1.)(key)
|
||||
out = jax.grad(lambda x: 1., allow_int=True)(key)
|
||||
@ -2131,17 +2131,17 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
||||
|
||||
@jtu.with_config(jax_default_prng_impl='rbg')
|
||||
class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
def seed_prng(self, seed):
|
||||
return random.PRNGKey(seed, impl='rbg')
|
||||
def make_key(self, seed):
|
||||
return random.rbg_key(seed)
|
||||
|
||||
def test_split_shape(self):
|
||||
key = self.seed_prng(73)
|
||||
key = self.make_key(73)
|
||||
keys = random.split(key, 10)
|
||||
self.assertEqual(keys.shape, (10, *key.shape))
|
||||
|
||||
def test_vmap_fold_in_shape(self):
|
||||
# broadcast with scalar
|
||||
keys = random.split(self.seed_prng(73), 2)
|
||||
keys = random.split(self.make_key(73), 2)
|
||||
msgs = jnp.arange(3)
|
||||
|
||||
out = vmap(lambda i: random.fold_in(keys[0], i))(msgs)
|
||||
@ -2155,7 +2155,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
self.assertEqual(out.shape, keys.shape)
|
||||
|
||||
def test_vmap_split_not_mapped_key(self):
|
||||
key = self.seed_prng(73)
|
||||
key = self.make_key(73)
|
||||
single_split_key = random.split(key)
|
||||
vmapped_keys = vmap(lambda _: random.split(key))(jnp.zeros(3,))
|
||||
self.assertEqual(vmapped_keys.shape, (3, 2, *key.shape))
|
||||
@ -2164,7 +2164,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
_prng_key_as_array(single_split_key))
|
||||
|
||||
def test_vmap_split_mapped_key(self):
|
||||
key = self.seed_prng(73)
|
||||
key = self.make_key(73)
|
||||
mapped_keys = random.split(key, num=3)
|
||||
forloop_keys = [random.split(k) for k in mapped_keys]
|
||||
vmapped_keys = vmap(random.split)(mapped_keys)
|
||||
@ -2175,7 +2175,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
|
||||
def test_vmap_random_bits(self):
|
||||
rand_fun = lambda key: random.randint(key, (), 0, 100)
|
||||
key = self.seed_prng(73)
|
||||
key = self.make_key(73)
|
||||
mapped_keys = random.split(key, num=3)
|
||||
forloop_rand_nums = [rand_fun(k) for k in mapped_keys]
|
||||
rand_nums = vmap(rand_fun)(mapped_keys)
|
||||
@ -2183,7 +2183,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
self.assertArraysEqual(rand_nums, jnp.array(forloop_rand_nums))
|
||||
|
||||
def test_cannot_add(self):
|
||||
key = self.seed_prng(73)
|
||||
key = self.make_key(73)
|
||||
if not isinstance(key, random.PRNGKeyArray):
|
||||
raise SkipTest('relies on typed key arrays')
|
||||
self.assertRaisesRegex(
|
||||
@ -2193,7 +2193,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
@skipIf(np.__version__ == "1.21.0",
|
||||
"https://github.com/numpy/numpy/issues/19305")
|
||||
def test_grad_of_prng_key(self):
|
||||
key = self.seed_prng(73)
|
||||
key = self.make_key(73)
|
||||
with self.assertRaisesRegex(TypeError, 'grad requires real- or complex-valued inputs'):
|
||||
jax.grad(lambda x: 1.)(key)
|
||||
out = jax.grad(lambda x: 1., allow_int=True)(key)
|
||||
@ -2209,7 +2209,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
|
||||
@jtu.with_config(jax_default_prng_impl='unsafe_rbg')
|
||||
class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
|
||||
def seed_prng(self, seed):
|
||||
def make_key(self, seed):
|
||||
return random.unsafe_rbg_key(seed)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user