mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Update jax_spmd_mode
flag docstring and remove unused allow_pjit
option.
PiperOrigin-RevId: 564543943
This commit is contained in:
parent
d4adf0095f
commit
3e06dc8b77
@ -793,17 +793,15 @@ enable_memories = config.define_bool_state(
|
||||
|
||||
spmd_mode = config.define_enum_state(
|
||||
name='jax_spmd_mode',
|
||||
enum_values=['allow_all', 'allow_jit', 'allow_pjit'],
|
||||
enum_values=['allow_all', 'allow_jit'],
|
||||
default='allow_jit',
|
||||
help=("Decides whether Math on `jax.Array`'s that are not fully addressable "
|
||||
"(i.e. spans across multiple processes) is allowed. The options are: "
|
||||
"* allow_pjit: Default, only `pjit` computations are allowed to "
|
||||
" execute on non-fully addressable `jax.Array`s\n"
|
||||
"* allow_jit: `pjit` and `jax.jit` computations are allowed to "
|
||||
" execute on non-fully addressable `jax.Array`s\n"
|
||||
"* allow_jit: Default, `pjit` and `jax.jit` computations are allowed "
|
||||
" to execute on non-fully addressable `jax.Array`s\n"
|
||||
"* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, "
|
||||
" `jax.jit` and all other operations are allowed to "
|
||||
" execute on non-fully addresable `jax.Array`s."))
|
||||
" `jax.jit` and all other operations are allowed to "
|
||||
" execute on non-fully addresable `jax.Array`s."))
|
||||
|
||||
|
||||
distributed_debug = config.define_bool_state(
|
||||
|
Loading…
x
Reference in New Issue
Block a user