Add a new jax.spmd_mode config for preventing unintentional hangs and incorrect results when users pass jax.Arrays that span across multiple processes (i.e. not fully addressable) to jit or jnp operations (that are jitted by default).

Implicitly jitted functions will **always** require a `jax.spmd_mode` context manager for operating on non-fully addressable jax.Array.

Explicitly jitted functions will require the `jax.spmd_mode` config to begin with as we roll out jax.Array since its a new behavior for `jit` (previously jit only worked on single device arrays).
* Overtime (via docs) and as users become more familiar with the new parallelism APIs, we can relax this restriction and allow explicit `jit` to work without needing the config. This can happen when we merge the frontend of `jit` and `pjit`.

PiperOrigin-RevId: 485075693
This commit is contained in:
Yash Katariya 2022-10-31 09:46:46 -07:00 committed by jax authors
parent f3ddd565c3
commit ca1f58e37b
3 changed files with 31 additions and 5 deletions

View File

@ -57,6 +57,7 @@ from jax._src.config import (
transfer_guard_host_to_device as transfer_guard_host_to_device,
transfer_guard_device_to_device as transfer_guard_device_to_device,
transfer_guard_device_to_host as transfer_guard_device_to_host,
spmd_mode as spmd_mode,
)
from .core import eval_context as ensure_compile_time_eval
from jax._src.environment_info import print_environment_info as print_environment_info

View File

@ -710,6 +710,23 @@ jax_array = config.define_bool_state(
'used.'))
spmd_mode = config.define_enum_state(
name='jax_spmd_mode',
enum_values=['allow_all', 'allow_jit', 'allow_pjit'],
# TODO(yashkatariya): Default to `allow_jit` when the training wheels come
# off.
default='allow_pjit',
help=("Decides whether Math on `jax.Array`'s that are not fully addressable "
"(i.e. spans across multiple processes) is allowed. The options are: "
"* allow_pjit: Default, only `pjit` computations are allowed to "
" execute on non-fully addressable `jax.Array`s\n"
"* allow_jit: `pjit` and `jax.jit` computations are allowed to "
" execute on non-fully addressable `jax.Array`s\n"
"* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, "
" `jax.jit` and all other operations are allowed to "
" execute on non-fully addresable `jax.Array`s."))
distributed_debug = config.define_bool_state(
name='jax_distributed_debug',
default=False,

View File

@ -39,7 +39,6 @@ import itertools as it
import logging
import operator as op
import sys
import warnings
import threading
import types
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
@ -2810,9 +2809,16 @@ def lower_sharding_computation(
if d.process_index == d.client.process_index()]
if len(device_assignment) != len(local_device_assignment):
check_multihost_collective_allowlist(jaxpr)
# TODO(yashkatariya): Raise an error here and add a context manager.
if config.jax_array and api_name == 'jit':
warnings.warn(
# TODO(yashkatariya): Once jit and pjit's frontend is merged, use the
# argument on jit `_allow_multiprocess` (which will be added later) instead
# of the `api_name` check here.
# Furthermore, `allow_jit` is not allowed yet because `allow_jit` only
# allows explicit `jax.jit` to work but not implicitly jitted `jnp`.
# operations. This restriction will be relaxed in the future when the
# default value of `spmd_mode` config changes to `allow_jit`.
if (config.jax_array and api_name == 'jit' and
config.jax_spmd_mode != 'allow_all'):
raise RuntimeError(
"Running operations on `Array`s that are not fully addressable by this "
"process (i.e. `Array`s with data sharded across multiple devices and "
"processes.) is dangerous. Its very important that all processes run "
@ -2820,7 +2826,9 @@ def lower_sharding_computation(
"can lead to hangs.\n"
"If youre not already familiar with JAXs multi-process "
"programming model, please read "
"https://jax.readthedocs.io/en/latest/multi_process.html.")
"https://jax.readthedocs.io/en/latest/multi_process.html\n"
"To fix this error, run your `jitted` computation inside "
"`with jax.spmd_mode('allow_all'):` context manager.")
has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)