mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Enable tests related to the Gamma distribution for non-default PRNG implementations only when jax_enable_custom_prng is enabled, for consistency with other tests.
PiperOrigin-RevId: 440300882
This commit is contained in:
parent
58bdcb89e8
commit
0c02f7935a
@ -869,14 +869,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.expon().cdf)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_a={}_dtype={}_prng={}".format(a, np.dtype(dtype).name,
|
||||
prng_name),
|
||||
"a": a, "dtype": dtype, "prng_impl": prng_impl}
|
||||
for prng_name, prng_impl in PRNG_IMPLS
|
||||
{"testcase_name": "_a={}_dtype={}".format(a, np.dtype(dtype).name),
|
||||
"a": a, "dtype": dtype}
|
||||
for a in [0.1, 1., 10.]
|
||||
for dtype in jtu.dtypes.floating))
|
||||
def testGammaVsLogGamma(self, prng_impl, a, dtype):
|
||||
key = prng.seed_with_impl(prng_impl, 0)
|
||||
def testGammaVsLogGamma(self, a, dtype):
|
||||
key = self.seed_prng(0)
|
||||
rand_gamma = lambda key, a: random.gamma(key, a, (10000,), dtype)
|
||||
rand_loggamma = lambda key, a: random.loggamma(key, a, (10000,), dtype)
|
||||
crand_loggamma = jax.jit(rand_loggamma)
|
||||
@ -885,14 +883,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(rand_gamma(key, a), jnp.exp(crand_loggamma(key, a)))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_a={}_dtype={}_prng={}".format(a, np.dtype(dtype).name,
|
||||
prng_name),
|
||||
"a": a, "dtype": dtype, "prng_impl": prng_impl}
|
||||
for prng_name, prng_impl in PRNG_IMPLS
|
||||
{"testcase_name": "_a={}_dtype={}".format(a, np.dtype(dtype).name),
|
||||
"a": a, "dtype": dtype}
|
||||
for a in [0.1, 1., 10.]
|
||||
for dtype in jtu.dtypes.floating))
|
||||
def testGamma(self, prng_impl, a, dtype):
|
||||
key = prng.seed_with_impl(prng_impl, 0)
|
||||
def testGamma(self, a, dtype):
|
||||
key = self.seed_prng(0)
|
||||
rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
@ -908,13 +904,12 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
assert x.shape == (3, 2)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_a={}_prng={}_logspace={}".format(alpha, prng_name, log_space),
|
||||
"alpha": alpha, "log_space": log_space, "prng_impl": prng_impl}
|
||||
for prng_name, prng_impl in PRNG_IMPLS
|
||||
{"testcase_name": "_a={}_logspace={}".format(alpha, log_space),
|
||||
"alpha": alpha, "log_space": log_space}
|
||||
for log_space in [True, False]
|
||||
for alpha in [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]))
|
||||
def testGammaGrad(self, log_space, prng_impl, alpha):
|
||||
rng = prng.seed_with_impl(prng_impl, 0)
|
||||
def testGammaGrad(self, log_space, alpha):
|
||||
rng = self.seed_prng(0)
|
||||
alphas = np.full((100,), alpha)
|
||||
z = random.gamma(rng, alphas)
|
||||
if log_space:
|
||||
@ -1609,18 +1604,10 @@ def _sampler_unimplemented_with_custom_prng(*args, **kwargs):
|
||||
raise SkipTest('sampler only implemented for default RNG')
|
||||
|
||||
for test_prefix in [
|
||||
'testBeta',
|
||||
'testDirichlet',
|
||||
'testGamma',
|
||||
'testGammaGrad',
|
||||
'testGammaGradType',
|
||||
'testGammaShape',
|
||||
'testIssue1789',
|
||||
'testPoisson',
|
||||
'testPoissonBatched',
|
||||
'testPoissonShape',
|
||||
'testPoissonZeros',
|
||||
'testT',
|
||||
]:
|
||||
for attr in dir(LaxRandomTest):
|
||||
if attr.startswith(test_prefix):
|
||||
|
Loading…
x
Reference in New Issue
Block a user