Update warning message for jit of pmap

This commit is contained in:
rajasekharporeddy 2024-10-21 21:17:59 +05:30
parent f833891c87
commit 02f65bb11a

View File

@ -1868,7 +1868,8 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
"does not preserve sharded data representations and instead collects "
"input and output arrays onto a single device. "
"Consider removing the outer jit unless you know what you're doing. "
"See https://github.com/jax-ml/jax/issues/2926.")
"See https://github.com/jax-ml/jax/issues/2926. Or "
"use jax.experimental.shard_map instead of pmap under jit compilation.")
if nreps > xb.device_count(backend):
raise ValueError(