mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Update student-t sampling to use correct key for gamma
This commit is contained in:
parent
64145393b2
commit
fbbdc35d5e
@ -1487,7 +1487,7 @@ def _t(key, df, shape, dtype) -> Array:
|
||||
n = normal(key_n, shape, dtype)
|
||||
two = _lax_const(n, 2)
|
||||
half_df = lax.div(df, two)
|
||||
g = gamma(key_n, half_df, shape, dtype)
|
||||
g = gamma(key_g, half_df, shape, dtype)
|
||||
return n * jnp.sqrt(half_df / g)
|
||||
|
||||
|
||||
|
@ -1136,7 +1136,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
)
|
||||
@jtu.skip_on_devices("cpu", "tpu") # TODO(phawkins): slow compilation times
|
||||
def testT(self, df, dtype):
|
||||
key = self.seed_prng(0)
|
||||
key = self.seed_prng(1)
|
||||
rand = lambda key, df: random.t(key, df, (10000,), dtype)
|
||||
crand = jax.jit(rand)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user