diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 6a65f1bd0..01d0abb27 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1868,7 +1868,8 @@ def _raise_warnings_or_errors_for_jit_of_pmap( "does not preserve sharded data representations and instead collects " "input and output arrays onto a single device. " "Consider removing the outer jit unless you know what you're doing. " - "See https://github.com/jax-ml/jax/issues/2926.") + "See https://github.com/jax-ml/jax/issues/2926. Or " + "use jax.experimental.shard_map instead of pmap under jit compilation.") if nreps > xb.device_count(backend): raise ValueError(