mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix pjit_test:testWithCustomPRNGKey
This commit is contained in:
parent
6a23ae92ca
commit
3108f05eee
@ -851,7 +851,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
def testWithCustomPRNGKey(self):
|
||||
if not config.jax_enable_custom_prng:
|
||||
raise unittest.SkipTest("test requires jax_enable_custom_prng")
|
||||
key = jax.prng.seed_with_impl(jax.prng.rbg_prng_impl, 87)
|
||||
key = prng.seed_with_impl(prng.rbg_prng_impl, 87)
|
||||
# Make sure this doesn't crash
|
||||
pjit(lambda x: x, in_shardings=None, out_shardings=None)(key)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user