mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Add a new `jax.spmd_mode` config for preventing unintentional hangs and incorrect results when users pass `jax.Array`s that span across multiple processes (i.e. are not fully addressable) to `jit` or `jnp` operations (which are jitted by default).
Implicitly jitted functions will **always** require a `jax.spmd_mode` context manager for operating on non-fully addressable `jax.Array`s.
Explicitly jitted functions will require the `jax.spmd_mode` config to begin with as we roll out `jax.Array`, since it's a new behavior for `jit` (previously `jit` only worked on single-device arrays).
* Over time (via docs), and as users become more familiar with the new parallelism APIs, we can relax this restriction and allow explicit `jit` to work without needing the config. This can happen when we merge the frontends of `jit` and `pjit`.
PiperOrigin-RevId: 485075693
This commit is contained in:
parent
f3ddd565c3
commit
ca1f58e37b
@ -57,6 +57,7 @@ from jax._src.config import (
|
||||
transfer_guard_host_to_device as transfer_guard_host_to_device,
|
||||
transfer_guard_device_to_device as transfer_guard_device_to_device,
|
||||
transfer_guard_device_to_host as transfer_guard_device_to_host,
|
||||
spmd_mode as spmd_mode,
|
||||
)
|
||||
from .core import eval_context as ensure_compile_time_eval
|
||||
from jax._src.environment_info import print_environment_info as print_environment_info
|
||||
|
@ -710,6 +710,23 @@ jax_array = config.define_bool_state(
|
||||
'used.'))
|
||||
|
||||
|
||||
# Enum config state (`jax_spmd_mode`) selecting which kinds of computation are
# allowed to run on `jax.Array`s that are not fully addressable, i.e. arrays
# that span multiple processes. Accepted values are listed in `enum_values`
# and described in the help text below.
spmd_mode = config.define_enum_state(
    name='jax_spmd_mode',
    enum_values=['allow_all', 'allow_jit', 'allow_pjit'],
    # TODO(yashkatariya): Default to `allow_jit` when the training wheels come
    # off.
    default='allow_pjit',
    # The intro sentence ends with a newline so that every option bullet,
    # including the first one, starts on its own line in the rendered help.
    help=("Decides whether Math on `jax.Array`'s that are not fully addressable "
          "(i.e. spans across multiple processes) is allowed. The options are:\n"
          "* allow_pjit: Default, only `pjit` computations are allowed to "
          " execute on non-fully addressable `jax.Array`s\n"
          "* allow_jit: `pjit` and `jax.jit` computations are allowed to "
          " execute on non-fully addressable `jax.Array`s\n"
          "* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, "
          " `jax.jit` and all other operations are allowed to "
          " execute on non-fully addressable `jax.Array`s."))
|
||||
|
||||
|
||||
distributed_debug = config.define_bool_state(
|
||||
name='jax_distributed_debug',
|
||||
default=False,
|
||||
|
@ -39,7 +39,6 @@ import itertools as it
|
||||
import logging
|
||||
import operator as op
|
||||
import sys
|
||||
import warnings
|
||||
import threading
|
||||
import types
|
||||
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
|
||||
@ -2810,9 +2809,16 @@ def lower_sharding_computation(
|
||||
if d.process_index == d.client.process_index()]
|
||||
if len(device_assignment) != len(local_device_assignment):
|
||||
check_multihost_collective_allowlist(jaxpr)
|
||||
# TODO(yashkatariya): Raise an error here and add a context manager.
|
||||
if config.jax_array and api_name == 'jit':
|
||||
warnings.warn(
|
||||
# TODO(yashkatariya): Once jit and pjit's frontend is merged, use the
|
||||
# argument on jit `_allow_multiprocess` (which will be added later) instead
|
||||
# of the `api_name` check here.
|
||||
# Furthermore, `allow_jit` is not allowed yet because `allow_jit` only
|
||||
# allows explicit `jax.jit` to work but not implicitly jitted `jnp`
|
||||
# operations. This restriction will be relaxed in the future when the
|
||||
# default value of `spmd_mode` config changes to `allow_jit`.
|
||||
if (config.jax_array and api_name == 'jit' and
|
||||
config.jax_spmd_mode != 'allow_all'):
|
||||
raise RuntimeError(
|
||||
"Running operations on `Array`s that are not fully addressable by this "
|
||||
"process (i.e. `Array`s with data sharded across multiple devices and "
|
||||
"processes.) is dangerous. It’s very important that all processes run "
|
||||
@ -2820,7 +2826,9 @@ def lower_sharding_computation(
|
||||
"can lead to hangs.\n"
|
||||
"If you’re not already familiar with JAX’s multi-process "
|
||||
"programming model, please read "
|
||||
"https://jax.readthedocs.io/en/latest/multi_process.html.")
|
||||
"https://jax.readthedocs.io/en/latest/multi_process.html\n"
|
||||
"To fix this error, run your `jitted` computation inside "
|
||||
"`with jax.spmd_mode('allow_all'):` context manager.")
|
||||
|
||||
has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
|
||||
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
|
||||
|
Loading…
x
Reference in New Issue
Block a user