Update jax_spmd_mode flag docstring and remove unused allow_pjit option.

PiperOrigin-RevId: 564543943
This commit is contained in:
Ruoxin Sang 2023-09-11 17:07:51 -07:00 committed by jax authors
parent d4adf0095f
commit 3e06dc8b77

View File

@ -793,17 +793,15 @@ enable_memories = config.define_bool_state(
spmd_mode = config.define_enum_state(
name='jax_spmd_mode',
enum_values=['allow_all', 'allow_jit', 'allow_pjit'],
enum_values=['allow_all', 'allow_jit'],
default='allow_jit',
help=("Decides whether Math on `jax.Array`'s that are not fully addressable "
"(i.e. spans across multiple processes) is allowed. The options are: "
"* allow_pjit: Default, only `pjit` computations are allowed to "
" execute on non-fully addressable `jax.Array`s\n"
"* allow_jit: `pjit` and `jax.jit` computations are allowed to "
" execute on non-fully addressable `jax.Array`s\n"
"* allow_jit: Default, `pjit` and `jax.jit` computations are allowed "
" to execute on non-fully addressable `jax.Array`s\n"
"* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, "
" `jax.jit` and all other operations are allowed to "
" execute on non-fully addresable `jax.Array`s."))
" `jax.jit` and all other operations are allowed to "
" execute on non-fully addresable `jax.Array`s."))
distributed_debug = config.define_bool_state(