From 02f65bb11aedd725015e4c3d351c5c291f3b2230 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Mon, 21 Oct 2024 21:17:59 +0530 Subject: [PATCH] Update warning message for jit of pmap --- jax/_src/interpreters/pxla.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 6a65f1bd0..01d0abb27 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1868,7 +1868,8 @@ def _raise_warnings_or_errors_for_jit_of_pmap( "does not preserve sharded data representations and instead collects " "input and output arrays onto a single device. " "Consider removing the outer jit unless you know what you're doing. " - "See https://github.com/jax-ml/jax/issues/2926.") + "See https://github.com/jax-ml/jax/issues/2926. Or " + "use jax.experimental.shard_map instead of pmap under jit compilation.") if nreps > xb.device_count(backend): raise ValueError(