Enable tests related to the Gamma distribution for non-default PRNG implementations only when jax_enable_custom_prng is enabled, for consistency with other tests.

PiperOrigin-RevId: 440300882
This commit is contained in:
Joan Puigcerver 2022-04-08 01:08:24 -07:00 committed by jax authors
parent 58bdcb89e8
commit 0c02f7935a

View File

@ -869,14 +869,12 @@ class LaxRandomTest(jtu.JaxTestCase):
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.expon().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_a={}_dtype={}_prng={}".format(a, np.dtype(dtype).name,
prng_name),
"a": a, "dtype": dtype, "prng_impl": prng_impl}
for prng_name, prng_impl in PRNG_IMPLS
{"testcase_name": "_a={}_dtype={}".format(a, np.dtype(dtype).name),
"a": a, "dtype": dtype}
for a in [0.1, 1., 10.]
for dtype in jtu.dtypes.floating))
def testGammaVsLogGamma(self, prng_impl, a, dtype):
key = prng.seed_with_impl(prng_impl, 0)
def testGammaVsLogGamma(self, a, dtype):
key = self.seed_prng(0)
rand_gamma = lambda key, a: random.gamma(key, a, (10000,), dtype)
rand_loggamma = lambda key, a: random.loggamma(key, a, (10000,), dtype)
crand_loggamma = jax.jit(rand_loggamma)
@ -885,14 +883,12 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertAllClose(rand_gamma(key, a), jnp.exp(crand_loggamma(key, a)))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_a={}_dtype={}_prng={}".format(a, np.dtype(dtype).name,
prng_name),
"a": a, "dtype": dtype, "prng_impl": prng_impl}
for prng_name, prng_impl in PRNG_IMPLS
{"testcase_name": "_a={}_dtype={}".format(a, np.dtype(dtype).name),
"a": a, "dtype": dtype}
for a in [0.1, 1., 10.]
for dtype in jtu.dtypes.floating))
def testGamma(self, prng_impl, a, dtype):
key = prng.seed_with_impl(prng_impl, 0)
def testGamma(self, a, dtype):
key = self.seed_prng(0)
rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
crand = jax.jit(rand)
@ -908,13 +904,12 @@ class LaxRandomTest(jtu.JaxTestCase):
assert x.shape == (3, 2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_a={}_prng={}_logspace={}".format(alpha, prng_name, log_space),
"alpha": alpha, "log_space": log_space, "prng_impl": prng_impl}
for prng_name, prng_impl in PRNG_IMPLS
{"testcase_name": "_a={}_logspace={}".format(alpha, log_space),
"alpha": alpha, "log_space": log_space}
for log_space in [True, False]
for alpha in [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]))
def testGammaGrad(self, log_space, prng_impl, alpha):
rng = prng.seed_with_impl(prng_impl, 0)
def testGammaGrad(self, log_space, alpha):
rng = self.seed_prng(0)
alphas = np.full((100,), alpha)
z = random.gamma(rng, alphas)
if log_space:
@ -1609,18 +1604,10 @@ def _sampler_unimplemented_with_custom_prng(*args, **kwargs):
raise SkipTest('sampler only implemented for default RNG')
for test_prefix in [
'testBeta',
'testDirichlet',
'testGamma',
'testGammaGrad',
'testGammaGradType',
'testGammaShape',
'testIssue1789',
'testPoisson',
'testPoissonBatched',
'testPoissonShape',
'testPoissonZeros',
'testT',
]:
for attr in dir(LaxRandomTest):
if attr.startswith(test_prefix):