Merge pull request #16635 from froystig:random-test-cleanups2

PiperOrigin-RevId: 545838671
This commit is contained in:
jax authors 2023-07-05 18:29:09 -07:00
commit e2478685b9

View File

@ -213,9 +213,6 @@ class PrngTest(jtu.JaxTestCase):
self.assertEqual(key.dtype, jnp.dtype('uint32'))
self.assertEqual(key.shape, impl.key_shape)
def raw_key(self, *args, **kwargs):
return _prng_key_as_array(random.key(*args, **kwargs))
def testThreefry2x32(self):
# We test the hash by comparing to known values provided in the test code of
# the original reference implementation of Threefry. For the values, see
@ -532,20 +529,12 @@ class PrngTest(jtu.JaxTestCase):
self.check_key_has_impl(random.unsafe_rbg_key(42),
prng.unsafe_rbg_prng_impl)
def test_key_construction_with_explicit_impl_name(self):
key = random.key(42, impl='threefry2x32')
self.check_key_has_impl(key, prng.threefry_prng_impl)
key = random.key(42, impl='rbg')
self.check_key_has_impl(key, prng.rbg_prng_impl)
key = random.key(42, impl='unsafe_rbg')
self.check_key_has_impl(key, prng.unsafe_rbg_prng_impl)
key = random.PRNGKey(42, impl='threefry2x32')
self.check_key_has_impl(key, prng.threefry_prng_impl)
key = random.PRNGKey(42, impl='rbg')
self.check_key_has_impl(key, prng.rbg_prng_impl)
key = random.PRNGKey(42, impl='unsafe_rbg')
self.check_key_has_impl(key, prng.unsafe_rbg_prng_impl)
@parameterized.parameters([{'make_key': ctor, 'name': name, 'impl': impl}
for ctor in KEY_CTORS
for name, impl in PRNG_IMPLS])
def test_key_construction_with_explicit_impl_name(self, make_key, name, impl):
key = make_key(42, impl=name)
self.check_key_has_impl(key, impl)
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def test_isinstance(self, make_key):
@ -560,10 +549,14 @@ class PrngTest(jtu.JaxTestCase):
class ThreefryPrngTest(jtu.JaxTestCase):
def test_seed_no_implicit_transfers(self):
@parameterized.parameters([{'make_key': ctor} for ctor in [
random.threefry2x32_key,
partial(random.PRNGKey, impl='threefry2x32'),
partial(random.key, impl='threefry2x32')]])
def test_seed_no_implicit_transfers(self, make_key):
# See https://github.com/google/jax/issues/15613
with jax.transfer_guard('disallow'):
random.threefry2x32_key(jax.device_put(42)) # doesn't crash
make_key(jax.device_put(42)) # doesn't crash
class LaxRandomTest(jtu.JaxTestCase):
@ -616,7 +609,7 @@ class LaxRandomTest(jtu.JaxTestCase):
'Expected vs. actual frequencies:\n'
f'{expected_freq}\n{actual_freq}')
def seed_prng(self, seed):
def make_key(self, seed):
return random.threefry2x32_key(seed)
@jtu.sample_product(dtype=jtu.dtypes.floating)
@ -629,7 +622,7 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.sample_product(dtype=float_dtypes)
def testRngUniform(self, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key: random.uniform(key, (10000,), dtype)
crand = jax.jit(rand)
@ -645,7 +638,7 @@ class LaxRandomTest(jtu.JaxTestCase):
lo = 5
hi = 10
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key: random.randint(key, (10000,), lo, hi, dtype)
crand = jax.jit(rand)
@ -658,7 +651,7 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.sample_product(dtype=float_dtypes)
def testNormal(self, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key: random.normal(key, (10000,), dtype)
crand = jax.jit(rand)
@ -671,13 +664,13 @@ class LaxRandomTest(jtu.JaxTestCase):
def testNormalBfloat16(self):
# Passing bfloat16 as dtype string.
# https://github.com/google/jax/issues/6813
res_bfloat16_str = random.normal(self.seed_prng(0), dtype='bfloat16')
res_bfloat16 = random.normal(self.seed_prng(0), dtype=jnp.bfloat16)
res_bfloat16_str = random.normal(self.make_key(0), dtype='bfloat16')
res_bfloat16 = random.normal(self.make_key(0), dtype=jnp.bfloat16)
self.assertAllClose(res_bfloat16, res_bfloat16_str)
@jtu.sample_product(dtype=complex_dtypes)
def testNormalComplex(self, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key: random.normal(key, (10000,), dtype)
crand = jax.jit(rand)
@ -691,7 +684,7 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.sample_product(dtype=float_dtypes)
def testTruncatedNormal(self, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key: random.truncated_normal(key, -0.3, 0.3, (10000,), dtype)
crand = jax.jit(rand)
@ -707,7 +700,7 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.sample_product(dtype=jtu.dtypes.floating + jtu.dtypes.integer)
def testShuffle(self, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
x = np.arange(100).astype(dtype)
rand = lambda key: random.shuffle(key, x)
crand = jax.jit(rand)
@ -741,7 +734,7 @@ class LaxRandomTest(jtu.JaxTestCase):
np_choice = np.random.default_rng(0).choice
p_dtype = dtypes.to_inexact_dtype(dtype)
key = self.seed_prng(0)
key = self.make_key(0)
is_range = type(input_range_or_shape) is int
x = (input_range_or_shape if is_range else
self.rng().permutation(np.arange(math.prod(
@ -781,7 +774,7 @@ class LaxRandomTest(jtu.JaxTestCase):
independent=[True, False],
)
def testPermutation(self, dtype, range_or_shape, axis, independent):
key = self.seed_prng(0)
key = self.make_key(0)
is_range = type(range_or_shape) is int
x = (range_or_shape if is_range else
self.rng().permutation(np.arange(
@ -810,7 +803,7 @@ class LaxRandomTest(jtu.JaxTestCase):
'x' if is_range else None)(key, x))
def testPermutationErrors(self):
key = self.seed_prng(0)
key = self.make_key(0)
with self.assertRaises(ValueError):
random.permutation(key, 10, axis=3)
with self.assertRaises(TypeError):
@ -823,7 +816,7 @@ class LaxRandomTest(jtu.JaxTestCase):
dtype=jtu.dtypes.floating,
)
def testBernoulli(self, p, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
p = np.array(p, dtype=dtype)
rand = lambda key, p: random.bernoulli(key, p, (10000,))
crand = jax.jit(rand)
@ -847,7 +840,7 @@ class LaxRandomTest(jtu.JaxTestCase):
dtype=jtu.dtypes.floating,
)
def testCategorical(self, p, axis, dtype, sample_shape):
key = self.seed_prng(0)
key = self.make_key(0)
p = np.array(p, dtype=dtype)
logits = np.log(p) - 42 # test unnormalized
out_shape = tuple(np.delete(logits.shape, axis))
@ -874,7 +867,7 @@ class LaxRandomTest(jtu.JaxTestCase):
self._CheckChiSquared(samples, pmf=pmf)
def testBernoulliShape(self):
key = self.seed_prng(0)
key = self.make_key(0)
with jax.numpy_rank_promotion('allow'):
x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)
@ -887,7 +880,7 @@ class LaxRandomTest(jtu.JaxTestCase):
def testBeta(self, a, b, dtype):
if not config.x64_enabled:
raise SkipTest("skip test except on X64")
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key, a, b: random.beta(key, a, b, (10000,), dtype)
crand = jax.jit(rand)
@ -899,7 +892,7 @@ class LaxRandomTest(jtu.JaxTestCase):
def testBetaSmallParameters(self, dtype=np.float32):
# Regression test for beta version of https://github.com/google/jax/issues/9896
key = self.seed_prng(0)
key = self.make_key(0)
a, b = 0.0001, 0.0002
samples = random.beta(key, a, b, shape=(100,), dtype=dtype)
@ -914,7 +907,7 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.sample_product(dtype=float_dtypes)
def testCauchy(self, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key: random.cauchy(key, (10000,), dtype)
crand = jax.jit(rand)
@ -930,7 +923,7 @@ class LaxRandomTest(jtu.JaxTestCase):
)
@jtu.skip_on_devices("tpu") # TODO(mattjj): slow compilation times
def testDirichlet(self, alpha, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key, alpha: random.dirichlet(key, alpha, (10000,), dtype)
crand = jax.jit(rand)
@ -945,7 +938,7 @@ class LaxRandomTest(jtu.JaxTestCase):
def testDirichletSmallAlpha(self, dtype=np.float32):
# Regression test for https://github.com/google/jax/issues/9896
key = self.seed_prng(0)
key = self.make_key(0)
alpha = 0.0001 * jnp.ones(3)
samples = random.dirichlet(key, alpha, shape=(100,), dtype=dtype)
@ -960,7 +953,7 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.sample_product(dtype=float_dtypes)
def testExponential(self, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key: random.exponential(key, (10000,), dtype)
crand = jax.jit(rand)
@ -975,7 +968,7 @@ class LaxRandomTest(jtu.JaxTestCase):
dtype=jtu.dtypes.floating,
)
def testGammaVsLogGamma(self, a, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
rand_gamma = lambda key, a: random.gamma(key, a, (10000,), dtype)
rand_loggamma = lambda key, a: random.loggamma(key, a, (10000,), dtype)
crand_loggamma = jax.jit(rand_loggamma)
@ -988,7 +981,7 @@ class LaxRandomTest(jtu.JaxTestCase):
dtype=jtu.dtypes.floating,
)
def testGamma(self, a, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
crand = jax.jit(rand)
@ -999,7 +992,7 @@ class LaxRandomTest(jtu.JaxTestCase):
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gamma(a).cdf)
def testGammaShape(self):
key = self.seed_prng(0)
key = self.make_key(0)
x = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)
@ -1008,7 +1001,7 @@ class LaxRandomTest(jtu.JaxTestCase):
alpha=[1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4],
)
def testGammaGrad(self, log_space, alpha):
rng = self.seed_prng(0)
rng = self.make_key(0)
alphas = np.full((100,), alpha)
z = random.gamma(rng, alphas)
if log_space:
@ -1032,7 +1025,7 @@ class LaxRandomTest(jtu.JaxTestCase):
def testGammaGradType(self):
# Regression test for https://github.com/google/jax/issues/2130
key = self.seed_prng(0)
key = self.make_key(0)
a = jnp.array(1., dtype=jnp.float32)
b = jnp.array(3., dtype=jnp.float32)
f = lambda x, y: random.gamma(key=key, a=x, dtype=jnp.float32) / y
@ -1044,7 +1037,7 @@ class LaxRandomTest(jtu.JaxTestCase):
dtype=[np.int16, np.int32, np.int64],
)
def testPoisson(self, lam, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype)
crand = jax.jit(rand)
@ -1059,38 +1052,38 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertAllClose(samples.var(), lam, rtol=0.03, check_dtypes=False)
def testPoissonBatched(self):
key = self.seed_prng(1)
key = self.make_key(1)
lam = jnp.concatenate([2 * jnp.ones(10000), 20 * jnp.ones(10000)])
samples = random.poisson(key, lam, shape=(20000,))
self._CheckChiSquared(samples[:10000], scipy.stats.poisson(2.0).pmf)
self._CheckChiSquared(samples[10000:], scipy.stats.poisson(20.0).pmf)
def testPoissonWithoutShape(self):
key = self.seed_prng(1)
key = self.make_key(1)
lam = 2 * jnp.ones(10000)
samples = random.poisson(key, lam)
self._CheckChiSquared(samples, scipy.stats.poisson(2.0).pmf)
def testPoissonShape(self):
key = self.seed_prng(0)
key = self.make_key(0)
x = random.poisson(key, np.array([2.0, 20.0]), shape=(3, 2))
assert x.shape == (3, 2)
def testPoissonZeros(self):
key = self.seed_prng(0)
key = self.make_key(0)
lam = jnp.concatenate([jnp.zeros(10), 20 * jnp.ones(10)])
samples = random.poisson(key, lam, shape=(2, 20))
self.assertArraysEqual(samples[:, :10], jnp.zeros_like(samples[:, :10]))
def testPoissonCornerCases(self):
key = self.seed_prng(0)
key = self.make_key(0)
lam = jnp.array([-1, 0, jnp.nan])
samples = random.poisson(key, lam, shape=(3,))
self.assertArraysEqual(samples, jnp.array([-1, 0, -1]), check_dtypes=False)
@jtu.sample_product(dtype=jtu.dtypes.floating)
def testGumbel(self, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key: random.gumbel(key, (10000,), dtype)
crand = jax.jit(rand)
@ -1102,7 +1095,7 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.sample_product(dtype=float_dtypes)
def testLaplace(self, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key: random.laplace(key, (10000,), dtype)
crand = jax.jit(rand)
@ -1114,7 +1107,7 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.sample_product(dtype=float_dtypes)
def testLogistic(self, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key: random.logistic(key, (10000,), dtype)
crand = jax.jit(rand)
@ -1131,7 +1124,7 @@ class LaxRandomTest(jtu.JaxTestCase):
)
@jax.default_matmul_precision("float32")
def testOrthogonal(self, n, shape, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
q = random.orthogonal(key, n, shape, dtype)
self.assertEqual(q.shape, (*shape, n, n))
self.assertEqual(q.dtype, dtype)
@ -1147,7 +1140,7 @@ class LaxRandomTest(jtu.JaxTestCase):
dtype=jtu.dtypes.floating,
)
def testGeneralizedNormal(self, p, shape, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key, p: random.generalized_normal(key, p, shape, dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key, p)
@ -1164,7 +1157,7 @@ class LaxRandomTest(jtu.JaxTestCase):
dtype=jtu.dtypes.floating,
)
def testBall(self, d, p, shape, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key, p: random.ball(key, d, p, shape, dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key, p)
@ -1181,7 +1174,7 @@ class LaxRandomTest(jtu.JaxTestCase):
dtype=jtu.dtypes.floating,
)
def testPareto(self, b, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key, b: random.pareto(key, b, (10000,), dtype)
crand = jax.jit(rand)
@ -1192,7 +1185,7 @@ class LaxRandomTest(jtu.JaxTestCase):
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.pareto(b).cdf)
def testParetoShape(self):
key = self.seed_prng(0)
key = self.make_key(0)
with jax.numpy_rank_promotion('allow'):
x = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)
@ -1203,7 +1196,7 @@ class LaxRandomTest(jtu.JaxTestCase):
)
@jtu.skip_on_devices("cpu", "tpu") # TODO(phawkins): slow compilation times
def testT(self, df, dtype):
key = self.seed_prng(1)
key = self.make_key(1)
rand = lambda key, df: random.t(key, df, (10000,), dtype)
crand = jax.jit(rand)
@ -1224,7 +1217,7 @@ class LaxRandomTest(jtu.JaxTestCase):
cov_factor = r.randn(dim, dim)
cov = np.dot(cov_factor, cov_factor.T) + dim * np.eye(dim)
key = self.seed_prng(0)
key = self.make_key(0)
rand = partial(random.multivariate_normal, mean=mean, cov=cov,
shape=(10000,), method=method)
crand = jax.jit(rand)
@ -1253,7 +1246,7 @@ class LaxRandomTest(jtu.JaxTestCase):
def testMultivariateNormalShapes(self, dim, mean_batch_size, cov_batch_size,
shape, method):
r = self.rng()
key = self.seed_prng(0)
key = self.make_key(0)
eff_batch_size = mean_batch_size \
if len(mean_batch_size) > len(cov_batch_size) else cov_batch_size
mean = r.randn(*(mean_batch_size + (dim,)))
@ -1276,7 +1269,7 @@ class LaxRandomTest(jtu.JaxTestCase):
out_np = self.rng().multivariate_normal(mean, cov, N)
key = self.seed_prng(0)
key = self.make_key(0)
with jax.numpy_rank_promotion('allow'):
out_jnp = random.multivariate_normal(key, mean=mean, cov=cov, shape=(N,))
@ -1296,7 +1289,7 @@ class LaxRandomTest(jtu.JaxTestCase):
# Singular covariance matrix https://github.com/google/jax/discussions/13293
mu = jnp.zeros((2,))
sigma = jnp.ones((2, 2))
key = random.PRNGKey(0)
key = self.make_key(0)
result = random.multivariate_normal(key, mean=mu, cov=sigma, shape=(10,), method=method)
self.assertAllClose(result[:, 0], result[:, 1], atol=1e-3, rtol=1e-3)
@ -1307,16 +1300,16 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertFalse(np.any(np.isnan(result)))
def testIssue222(self):
x = random.randint(self.seed_prng(10003), (), 0, 0)
x = random.randint(self.make_key(10003), (), 0, 0)
assert x == 0
def testFoldIn(self):
key = self.seed_prng(0)
key = self.make_key(0)
keys = [_prng_key_as_array(random.fold_in(key, i)) for i in range(10)]
assert np.unique(keys, axis=0).shape[0] == 10
def testFoldInBig(self):
key = self.seed_prng(0)
key = self.make_key(0)
seeds = [2 ** 32 - 2, 2 ** 32 - 1]
keys = [_prng_key_as_array(random.fold_in(key, seed)) for seed in seeds]
assert np.unique(keys, axis=0).shape[0] == 2
@ -1327,7 +1320,7 @@ class LaxRandomTest(jtu.JaxTestCase):
@jax.jit
def feature_map(n, d, sigma=1.0, seed=123):
key = self.seed_prng(seed)
key = self.make_key(seed)
W = random.normal(key, (d, n)) / sigma
w = random.normal(key, (d, )) / sigma
b = 2 * jnp.pi * random.uniform(key, (d, ))
@ -1339,24 +1332,24 @@ class LaxRandomTest(jtu.JaxTestCase):
lambda: feature_map(5, 3))
def testIssue756(self):
key = self.seed_prng(0)
key = self.make_key(0)
w = random.normal(key, ())
self.assertEqual(w.dtype, dtypes.canonicalize_dtype(jnp.float_))
def testIssue1789(self):
def f(x):
return random.gamma(self.seed_prng(0), x)
return random.gamma(self.make_key(0), x)
grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2))
def testDtypeErrorMessage(self):
with self.assertRaisesRegex(ValueError, r"dtype argument to.*"):
random.normal(self.seed_prng(0), (), dtype=jnp.int32)
random.normal(self.make_key(0), (), dtype=jnp.int32)
def testRandomBroadcast(self):
"""Issue 4033"""
# test for broadcast issue in https://github.com/google/jax/issues/4033
key = self.seed_prng(0)
key = self.make_key(0)
shape = (10, 2)
with jax.numpy_rank_promotion('allow'):
x1 = random.uniform(key, shape, minval=jnp.zeros(2), maxval=jnp.ones(2))
@ -1366,7 +1359,7 @@ class LaxRandomTest(jtu.JaxTestCase):
def testMaxwellSample(self):
num_samples = 10**5
rng = self.seed_prng(0)
rng = self.make_key(0)
rand = lambda x: random.maxwell(x, (num_samples, ))
crand = jax.jit(rand)
@ -1389,7 +1382,7 @@ class LaxRandomTest(jtu.JaxTestCase):
('test2', 2.0, 3.0))
def testWeibullSample(self, concentration, scale):
num_samples = 10**5
rng = self.seed_prng(0)
rng = self.make_key(0)
rand = lambda x: random.weibull_min(x, scale, concentration, (num_samples,))
crand = jax.jit(rand)
@ -1413,7 +1406,7 @@ class LaxRandomTest(jtu.JaxTestCase):
('test2', 2.0, 3.0))
def testDoublesidedMaxwellSample(self, loc, scale):
num_samples = 10**4
rng = self.seed_prng(0)
rng = self.make_key(0)
rand = lambda key: random.double_sided_maxwell(
rng, loc, scale, (num_samples,))
@ -1450,7 +1443,7 @@ class LaxRandomTest(jtu.JaxTestCase):
samples, lambda x: double_sided_maxwell_cdf(x, loc, scale))
def testRadamacher(self):
rng = self.seed_prng(0)
rng = self.make_key(0)
num_samples = 10**5
rand = lambda x: random.rademacher(x, (num_samples,))
@ -1470,7 +1463,7 @@ class LaxRandomTest(jtu.JaxTestCase):
counts[1] / num_samples, 0.5, rtol=1e-02, atol=1e-02)
def testChoiceShapeIsNotSequenceError(self):
key = self.seed_prng(0)
key = self.make_key(0)
with self.assertRaises(TypeError):
random.choice(key, 5, 2, replace=False)
with self.assertRaises(TypeError):
@ -1478,7 +1471,7 @@ class LaxRandomTest(jtu.JaxTestCase):
def test_eval_shape_big_random_array(self):
def f(x):
return random.normal(self.seed_prng(x), (int(1e12),))
return random.normal(self.make_key(x), (int(1e12),))
with jax.enable_checks(False): # check_jaxpr will materialize array
jax.eval_shape(f, 0) # doesn't error
@ -1493,18 +1486,18 @@ class LaxRandomTest(jtu.JaxTestCase):
self.skipTest("Expected failure: Python int too large.")
type_ = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type_]
args_maker = lambda: [type_(seed)]
f = lambda s: _maybe_unwrap(self.seed_prng(s))
f = lambda s: _maybe_unwrap(self.make_key(s))
self._CompileAndCheck(f, args_maker)
def test_prng_errors(self):
seed = np.iinfo(np.int64).max + 1
with self.assertRaises(OverflowError):
self.seed_prng(seed)
self.make_key(seed)
with self.assertRaises(OverflowError):
jax.jit(self.seed_prng)(seed)
jax.jit(self.make_key)(seed)
def test_random_split_doesnt_device_put_during_tracing(self):
key = self.seed_prng(1).block_until_ready()
key = self.make_key(1).block_until_ready()
with jtu.count_device_put() as count:
jax.jit(random.split)(key)
self.assertLessEqual(count[0], 1) # 1 for the argument device_put
@ -1513,7 +1506,7 @@ class LaxRandomTest(jtu.JaxTestCase):
def test_randint_bounds(self, dtype):
min = np.iinfo(dtype).min
max = np.iinfo(dtype).max
key = self.seed_prng(1701)
key = self.make_key(1701)
shape = (10,)
if np.iinfo(dtype).bits < np.iinfo(dtypes.canonicalize_dtype(int)).bits:
expected = random.randint(key, shape, min, max, dtype)
@ -1522,7 +1515,7 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertRaises(OverflowError, random.randint, key, shape, min - 12345, max + 12345, dtype)
def test_randint_out_of_range(self):
key = self.seed_prng(0)
key = self.make_key(0)
r = random.randint(key, (10,), 255, 256, np.uint8)
self.assertAllClose(r, jnp.full_like(r, 255))
@ -1538,7 +1531,8 @@ class LaxRandomTest(jtu.JaxTestCase):
def test_large_prng(self):
# https://github.com/google/jax/issues/11010
def f():
return random.uniform(random.PRNGKey(3), (308000000, 128), dtype=jnp.bfloat16)
return random.uniform(
self.make_key(3), (308000000, 128), dtype=jnp.bfloat16)
# just lower, don't run, takes too long
jax.jit(f).lower()
@ -1552,7 +1546,7 @@ class LaxRandomTest(jtu.JaxTestCase):
logits_shape.insert(axis % (len(logits_shape_base) + 1), 10)
assert logits_shape[axis] == 10
logits = jnp.ones(logits_shape)
samples = random.categorical(random.PRNGKey(0), logits=logits,
samples = random.categorical(self.make_key(0), logits=logits,
axis=axis, shape=shape)
self.assertEqual(samples.shape, shape)
@ -1560,9 +1554,10 @@ class LaxRandomTest(jtu.JaxTestCase):
df = [0.2, 1., 10., 100.],
dtype=jtu.dtypes.floating)
def testChisquare(self, df, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key, df: random.chisquare(key, df, shape=(10000, ), dtype=dtype)
def rand(key, df):
return random.chisquare(key, df, shape=(10000,), dtype=dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key, df)
@ -1576,7 +1571,7 @@ class LaxRandomTest(jtu.JaxTestCase):
dfden = [1. ,2., 10., 100.],
dtype=jtu.dtypes.floating)
def testF(self, dfnum, dfden, dtype):
key = self.seed_prng(1)
key = self.make_key(1)
rand = lambda key: random.f(key, dfnum, dfden, shape = (10000, ), dtype = dtype)
crand = jax.jit(rand)
@ -1590,7 +1585,7 @@ class LaxRandomTest(jtu.JaxTestCase):
scale= [0.2, 1., 2., 10. ,100.],
dtype=jtu.dtypes.floating)
def testRayleigh(self, scale, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key: random.rayleigh(key, scale, shape = (10000, ), dtype = dtype)
crand = jax.jit(rand)
@ -1604,7 +1599,7 @@ class LaxRandomTest(jtu.JaxTestCase):
mean= [0.2, 1., 2., 10. ,100.],
dtype=jtu.dtypes.floating)
def testWald(self, mean, dtype):
key = self.seed_prng(0)
key = self.make_key(0)
rand = lambda key: random.wald(key, mean, shape=(10000, ), dtype=dtype)
crand = jax.jit(rand)
@ -1618,7 +1613,7 @@ class LaxRandomTest(jtu.JaxTestCase):
p= [0.2, 0.3, 0.4, 0.5 ,0.6],
dtype= [np.int16, np.int32, np.int64])
def testGeometric(self, p, dtype):
key = self.seed_prng(1)
key = self.make_key(1)
rand = lambda key: random.geometric(key, p, shape=(10000, ), dtype=dtype)
crand = jax.jit(rand)
@ -2074,17 +2069,17 @@ double_threefry_prng_impl = prng.PRNGImpl(
@jtu.with_config(jax_default_prng_impl='threefry2x32')
class LaxRandomWithCustomPRNGTest(LaxRandomTest):
def seed_prng(self, seed):
def make_key(self, seed):
return prng.seed_with_impl(double_threefry_prng_impl, seed)
def test_split_shape(self):
key = self.seed_prng(73)
key = self.make_key(73)
keys = random.split(key, 10)
self.assertEqual(keys.shape, (10,))
def test_vmap_fold_in_shape(self):
# broadcast with scalar
keys = random.split(self.seed_prng(73), 2)
keys = random.split(self.make_key(73), 2)
msgs = jnp.arange(3)
out = vmap(lambda i: random.fold_in(keys[0], i))(msgs)
self.assertEqual(out.shape, (3,))
@ -2101,7 +2096,7 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
self.assertEqual(out.shape, (2,))
# nested vmap
keys = random.split(self.seed_prng(73), 2 * 3).reshape((2, 3))
keys = random.split(self.make_key(73), 2 * 3).reshape((2, 3))
msgs = jnp.arange(2 * 3).reshape((2, 3))
out = vmap(vmap(random.fold_in), in_axes=(0, 1))(keys, msgs.T)
self.assertEqual(out.shape, (2, 3))
@ -2109,7 +2104,7 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
self.assertEqual(out.shape, (3, 2))
def test_vmap_split_mapped_key(self):
key = self.seed_prng(73)
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
forloop_keys = [random.split(k) for k in mapped_keys]
vmapped_keys = vmap(random.split)(mapped_keys)
@ -2119,7 +2114,7 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
vk.unsafe_raw_array())
def test_cannot_add(self):
key = self.seed_prng(73)
key = self.make_key(73)
self.assertRaisesRegex(
ValueError, r'dtype=key<.*> is not a valid dtype for JAX type promotion.',
lambda: key + 47)
@ -2127,7 +2122,7 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
@skipIf(np.__version__ == "1.21.0",
"https://github.com/numpy/numpy/issues/19305")
def test_grad_of_prng_key(self):
key = self.seed_prng(73)
key = self.make_key(73)
with self.assertRaisesRegex(TypeError, 'grad requires real- or complex-valued inputs'):
jax.grad(lambda x: 1.)(key)
out = jax.grad(lambda x: 1., allow_int=True)(key)
@ -2136,17 +2131,17 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
@jtu.with_config(jax_default_prng_impl='rbg')
class LaxRandomWithRBGPRNGTest(LaxRandomTest):
def seed_prng(self, seed):
return random.PRNGKey(seed, impl='rbg')
def make_key(self, seed):
return random.rbg_key(seed)
def test_split_shape(self):
key = self.seed_prng(73)
key = self.make_key(73)
keys = random.split(key, 10)
self.assertEqual(keys.shape, (10, *key.shape))
def test_vmap_fold_in_shape(self):
# broadcast with scalar
keys = random.split(self.seed_prng(73), 2)
keys = random.split(self.make_key(73), 2)
msgs = jnp.arange(3)
out = vmap(lambda i: random.fold_in(keys[0], i))(msgs)
@ -2160,7 +2155,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
self.assertEqual(out.shape, keys.shape)
def test_vmap_split_not_mapped_key(self):
key = self.seed_prng(73)
key = self.make_key(73)
single_split_key = random.split(key)
vmapped_keys = vmap(lambda _: random.split(key))(jnp.zeros(3,))
self.assertEqual(vmapped_keys.shape, (3, 2, *key.shape))
@ -2169,7 +2164,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
_prng_key_as_array(single_split_key))
def test_vmap_split_mapped_key(self):
key = self.seed_prng(73)
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
forloop_keys = [random.split(k) for k in mapped_keys]
vmapped_keys = vmap(random.split)(mapped_keys)
@ -2180,7 +2175,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
def test_vmap_random_bits(self):
rand_fun = lambda key: random.randint(key, (), 0, 100)
key = self.seed_prng(73)
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
forloop_rand_nums = [rand_fun(k) for k in mapped_keys]
rand_nums = vmap(rand_fun)(mapped_keys)
@ -2188,7 +2183,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
self.assertArraysEqual(rand_nums, jnp.array(forloop_rand_nums))
def test_cannot_add(self):
key = self.seed_prng(73)
key = self.make_key(73)
if not isinstance(key, random.PRNGKeyArray):
raise SkipTest('relies on typed key arrays')
self.assertRaisesRegex(
@ -2198,7 +2193,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
@skipIf(np.__version__ == "1.21.0",
"https://github.com/numpy/numpy/issues/19305")
def test_grad_of_prng_key(self):
key = self.seed_prng(73)
key = self.make_key(73)
with self.assertRaisesRegex(TypeError, 'grad requires real- or complex-valued inputs'):
jax.grad(lambda x: 1.)(key)
out = jax.grad(lambda x: 1., allow_int=True)(key)
@ -2214,15 +2209,32 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
@jtu.with_config(jax_default_prng_impl='unsafe_rbg')
class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
def seed_prng(self, seed):
def make_key(self, seed):
return random.unsafe_rbg_key(seed)
def like(keys):
return jnp.ones(keys.shape)
def _sampler_unimplemented_with_custom_prng(*args, **kwargs):
raise SkipTest('sampler only implemented for default RNG')
for test_prefix in [
'testPoisson',
'testPoissonBatched',
'testPoissonShape',
'testPoissonZeros',
]:
for attr in dir(LaxRandomTest):
if attr.startswith(test_prefix):
setattr(LaxRandomWithCustomPRNGTest, attr,
_sampler_unimplemented_with_custom_prng)
setattr(LaxRandomWithRBGPRNGTest, attr,
_sampler_unimplemented_with_custom_prng)
setattr(LaxRandomWithUnsafeRBGPRNGTest, attr,
_sampler_unimplemented_with_custom_prng)
class JnpWithKeyArrayTest(jtu.JaxTestCase):
def check_shape(self, func, *args):
like = lambda keys: jnp.ones(keys.shape)
out_key = func(*args)
self.assertIsInstance(out_key, random.KeyArray)
out_like_key = func(*tree_util.tree_map(like, args))
@ -2483,24 +2495,5 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
self.check_against_reference(func, func, keys, fill_value)
def _sampler_unimplemented_with_custom_prng(*args, **kwargs):
raise SkipTest('sampler only implemented for default RNG')
for test_prefix in [
'testPoisson',
'testPoissonBatched',
'testPoissonShape',
'testPoissonZeros',
]:
for attr in dir(LaxRandomTest):
if attr.startswith(test_prefix):
setattr(LaxRandomWithCustomPRNGTest, attr,
_sampler_unimplemented_with_custom_prng)
setattr(LaxRandomWithRBGPRNGTest, attr,
_sampler_unimplemented_with_custom_prng)
setattr(LaxRandomWithUnsafeRBGPRNGTest, attr,
_sampler_unimplemented_with_custom_prng)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())