fix pjit_test:testWithCustomPRNGKey

This commit is contained in:
Jake VanderPlas 2023-04-25 10:52:15 -07:00
parent 6a23ae92ca
commit 3108f05eee

View File

@ -851,7 +851,7 @@ class PJitTest(jtu.BufferDonationTestCase):
def testWithCustomPRNGKey(self):
if not config.jax_enable_custom_prng:
raise unittest.SkipTest("test requires jax_enable_custom_prng")
key = jax.prng.seed_with_impl(jax.prng.rbg_prng_impl, 87)
key = prng.seed_with_impl(prng.rbg_prng_impl, 87)
# Make sure this doesn't crash
pjit(lambda x: x, in_shardings=None, out_shardings=None)(key)