add regression test for #8070

This commit is contained in:
Matthew Johnson 2021-10-02 20:52:00 -07:00
parent aa8f2da8dc
commit fa18732746

View File

@ -486,6 +486,14 @@ class PJitTest(jtu.BufferDonationTestCase):
execution.join()
@jtu.with_mesh([('x', 2)])
def testWithCustomPRNGKey(self):
if not config.jax_enable_custom_prng:
raise unittest.SkipTest("test requires jax_enable_custom_prng")
key = jax.prng.seed_with_impl(jax.prng.rbg_prng_impl, 87)
# Make sure this doesn't crash
pjit(lambda x: x, in_axis_resources=(None), out_axis_resources=(None))(key)
def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")