mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Update warning message for jit of pmap
This commit is contained in:
parent
f833891c87
commit
02f65bb11a
@ -1868,7 +1868,8 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
|
||||
"does not preserve sharded data representations and instead collects "
|
||||
"input and output arrays onto a single device. "
|
||||
"Consider removing the outer jit unless you know what you're doing. "
|
||||
"See https://github.com/jax-ml/jax/issues/2926.")
|
||||
"See https://github.com/jax-ml/jax/issues/2926. Or "
|
||||
"use jax.experimental.shard_map instead of pmap under jit compilation.")
|
||||
|
||||
if nreps > xb.device_count(backend):
|
||||
raise ValueError(
|
||||
|
Loading…
x
Reference in New Issue
Block a user