Update student-t sampling to use correct key for gamma

This commit is contained in:
tennessee_wallaceh 2023-02-21 14:52:05 +00:00
parent 64145393b2
commit fbbdc35d5e
2 changed files with 2 additions and 2 deletions

View File

@ -1487,7 +1487,7 @@ def _t(key, df, shape, dtype) -> Array:
n = normal(key_n, shape, dtype)
two = _lax_const(n, 2)
half_df = lax.div(df, two)
g = gamma(key_n, half_df, shape, dtype)
g = gamma(key_g, half_df, shape, dtype)
return n * jnp.sqrt(half_df / g)

View File

@ -1136,7 +1136,7 @@ class LaxRandomTest(jtu.JaxTestCase):
)
@jtu.skip_on_devices("cpu", "tpu") # TODO(phawkins): slow compilation times
def testT(self, df, dtype):
key = self.seed_prng(0)
key = self.seed_prng(1)
rand = lambda key, df: random.t(key, df, (10000,), dtype)
crand = jax.jit(rand)