Fix PRNGKey handling under jit-of-pmap

This commit is contained in:
Jake VanderPlas 2024-05-13 19:04:22 -07:00
parent cd6e012326
commit 5150cfeeb0
2 changed files with 11 additions and 0 deletions

View File

@ -1319,6 +1319,8 @@ def _hlo_shard(aval, axis_env, xs, in_axis):
if aval is core.abstract_token:
return xs
elif isinstance(aval, core.ShapedArray):
if dtypes.issubdtype(aval.dtype, dtypes.extended):
aval = aval.dtype._rules.physical_element_aval(aval.dtype)
x, = xs
dims = list(aval.shape)
zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))

View File

@ -2207,6 +2207,15 @@ class PythonPmapTest(jtu.JaxTestCase):
y = jax.pmap(jax.scipy.linalg.expm)(jnp.array([x, x]))
y.block_until_ready() # doesn't crash
def test_pmap_of_prng_key(self):
# Regression test for https://github.com/google/jax/issues/20392
keys = jax.random.split(jax.random.key(0), jax.device_count())
result1 = jax.pmap(jax.random.bits)(keys)
with jtu.ignore_warning(
category=UserWarning, message="The jitted function foo includes a pmap"):
result2 = jax.jit(jax.pmap(jax.random.bits))(keys)
self.assertArraysEqual(result1, result2)
@jtu.pytest_mark_if_available('multiaccelerator')
class CppPmapTest(PythonPmapTest):