From 3e06dc8b77a444671ae66de62b47362bd27a8a5b Mon Sep 17 00:00:00 2001 From: Ruoxin Sang Date: Mon, 11 Sep 2023 17:07:51 -0700 Subject: [PATCH] Update `jax_spmd_mode` flag docstring and remove unused `allow_pjit` option. PiperOrigin-RevId: 564543943 --- jax/_src/config.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index f122ffb05..3fbce2ce3 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -793,17 +793,15 @@ enable_memories = config.define_bool_state( spmd_mode = config.define_enum_state( name='jax_spmd_mode', - enum_values=['allow_all', 'allow_jit', 'allow_pjit'], + enum_values=['allow_all', 'allow_jit'], default='allow_jit', help=("Decides whether Math on `jax.Array`'s that are not fully addressable " "(i.e. spans across multiple processes) is allowed. The options are: " - "* allow_pjit: Default, only `pjit` computations are allowed to " - " execute on non-fully addressable `jax.Array`s\n" - "* allow_jit: `pjit` and `jax.jit` computations are allowed to " - " execute on non-fully addressable `jax.Array`s\n" + "* allow_jit: Default, `pjit` and `jax.jit` computations are allowed " + " to execute on non-fully addressable `jax.Array`s\n" "* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, " - " `jax.jit` and all other operations are allowed to " - " execute on non-fully addresable `jax.Array`s.")) + " `jax.jit` and all other operations are allowed to " + " execute on non-fully addresable `jax.Array`s.")) distributed_debug = config.define_bool_state(