mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
add regression test for #8070
This commit is contained in:
parent
aa8f2da8dc
commit
fa18732746
@ -486,6 +486,14 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
|
||||
execution.join()
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
def testWithCustomPRNGKey(self):
|
||||
if not config.jax_enable_custom_prng:
|
||||
raise unittest.SkipTest("test requires jax_enable_custom_prng")
|
||||
key = jax.prng.seed_with_impl(jax.prng.rbg_prng_impl, 87)
|
||||
# Make sure this doesn't crash
|
||||
pjit(lambda x: x, in_axis_resources=(None), out_axis_resources=(None))(key)
|
||||
|
||||
def spec_regex(s):
|
||||
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user