add custom_vjp upgrade flag to jax module and fix doc rendering for upgrade flags

This commit is contained in:
Roy Frostig 2022-03-28 17:17:33 -07:00
parent ab8dd4e623
commit e0ddd43933
3 changed files with 9 additions and 4 deletions

View File

@ -7,14 +7,15 @@ JAX configuration
:toctree: _autosummary
config
enable_checks
check_tracer_leaks
checking_leaks
enable_custom_prng
debug_nans
debug_infs
log_compiles
default_matmul_precision
default_prng_impl
enable_checks
enable_custom_prng
enable_custom_vjp_by_custom_transpose
log_compiles
numpy_rank_promotion
transfer_guard

View File

@ -43,6 +43,7 @@ from jax._src.config import (
check_tracer_leaks as check_tracer_leaks,
checking_leaks as checking_leaks,
enable_custom_prng as enable_custom_prng,
enable_custom_vjp_by_custom_transpose as enable_custom_vjp_by_custom_transpose,
debug_nans as debug_nans,
debug_infs as debug_infs,
log_compiles as log_compiles,

View File

@ -57,7 +57,8 @@ def int_env(varname: str, default: int) -> int:
UPGRADE_BOOL_HELP = (
" This will be enabled by default in future versions of JAX, at which "
"point all uses of the flag will be considered deprecated (following "
"https://jax.readthedocs.io/en/latest/api_compatibility.html).")
"the `API compatibility policy "
"<https://jax.readthedocs.io/en/latest/api_compatibility.html>`_).")
UPGRADE_BOOL_EXTRA_DESC = " (transient)"
@ -568,6 +569,7 @@ distributed_debug = config.define_bool_state(
enable_custom_prng = config.define_bool_state(
name='jax_enable_custom_prng',
default=False,
upgrade=True,
help=('Enables an internal upgrade that allows one to define custom '
'pseudo-random number generator implementations.'))
@ -581,6 +583,7 @@ default_prng_impl = config.define_enum_state(
enable_custom_vjp_by_custom_transpose = config.define_bool_state(
name='jax_enable_custom_vjp_by_custom_transpose',
default=False,
upgrade=True,
help=('Enables an internal upgrade that implements `jax.custom_vjp` by '
'reduction to `jax.custom_jvp` and `jax.custom_transpose`.'))