diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 04c361c1e..9c50908b6 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1319,6 +1319,8 @@ def _hlo_shard(aval, axis_env, xs, in_axis): if aval is core.abstract_token: return xs elif isinstance(aval, core.ShapedArray): + if dtypes.issubdtype(aval.dtype, dtypes.extended): + aval = aval.dtype._rules.physical_element_aval(aval.dtype) x, = xs dims = list(aval.shape) zero = mlir.ir_constant(np.zeros((), dtype=np.uint32)) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 6f2ca23ae..de20cda55 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -2207,6 +2207,15 @@ class PythonPmapTest(jtu.JaxTestCase): y = jax.pmap(jax.scipy.linalg.expm)(jnp.array([x, x])) y.block_until_ready() # doesn't crash + def test_pmap_of_prng_key(self): + # Regression test for https://github.com/google/jax/issues/20392 + keys = jax.random.split(jax.random.key(0), jax.device_count()) + result1 = jax.pmap(jax.random.bits)(keys) + with jtu.ignore_warning( + category=UserWarning, message="The jitted function foo includes a pmap"): + result2 = jax.jit(jax.pmap(jax.random.bits))(keys) + self.assertArraysEqual(result1, result2) + @jtu.pytest_mark_if_available('multiaccelerator') class CppPmapTest(PythonPmapTest):