Only raise the warning if jax_array is enabled and the code is coming from jit.

PiperOrigin-RevId: 482053435
This commit is contained in:
Yash Katariya 2022-10-18 16:35:10 -07:00 committed by jax authors
parent 5617a02fa4
commit c9a60f9410
2 changed files with 11 additions and 10 deletions

View File

@ -2812,15 +2812,17 @@ def lower_sharding_computation(
if d.process_index == process_index]
if len(device_assignment) != len(local_device_assignment):
check_multihost_collective_allowlist(jaxpr)
warnings.warn(
"Running operations on `Array`s that are not fully addressable by this "
"process (i.e. `Array`s with data sharded across multiple devices and "
"processes.) is dangerous. Its very important that all processes run "
"the same cross-process computations in the same order otherwise it "
"can lead to hangs.\n"
"If youre not already familiar with JAXs multi-process "
"programming model, please read "
"https://jax.readthedocs.io/en/latest/multi_process.html.")
# TODO(yashkatariya): Raise an error here and add a context manager.
if config.jax_array and api_name == 'jit':
warnings.warn(
"Running operations on `Array`s that are not fully addressable by this "
"process (i.e. `Array`s with data sharded across multiple devices and "
"processes.) is dangerous. Its very important that all processes run "
"the same cross-process computations in the same order otherwise it "
"can lead to hangs.\n"
"If youre not already familiar with JAXs multi-process "
"programming model, please read "
"https://jax.readthedocs.io/en/latest/multi_process.html.")
has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

View File

@ -19,6 +19,5 @@ filterwarnings =
# numpy uses distutils which is deprecated
ignore:The distutils.* is deprecated.*:DeprecationWarning
ignore:`sharded_jit` is deprecated. Please use `pjit` instead.*:DeprecationWarning
ignore:Running operations on `Array`s that are not fully addressable by this process.*:UserWarning
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
addopts = --doctest-glob="*.rst"