mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
add custom_vjp
upgrade flag to jax
module and fix doc rendering for upgrade flags
This commit is contained in:
parent
ab8dd4e623
commit
e0ddd43933
@ -7,14 +7,15 @@ JAX configuration
|
||||
:toctree: _autosummary
|
||||
|
||||
config
|
||||
enable_checks
|
||||
check_tracer_leaks
|
||||
checking_leaks
|
||||
enable_custom_prng
|
||||
debug_nans
|
||||
debug_infs
|
||||
log_compiles
|
||||
default_matmul_precision
|
||||
default_prng_impl
|
||||
enable_checks
|
||||
enable_custom_prng
|
||||
enable_custom_vjp_by_custom_transpose
|
||||
log_compiles
|
||||
numpy_rank_promotion
|
||||
transfer_guard
|
||||
|
@ -43,6 +43,7 @@ from jax._src.config import (
|
||||
check_tracer_leaks as check_tracer_leaks,
|
||||
checking_leaks as checking_leaks,
|
||||
enable_custom_prng as enable_custom_prng,
|
||||
enable_custom_vjp_by_custom_transpose as enable_custom_vjp_by_custom_transpose,
|
||||
debug_nans as debug_nans,
|
||||
debug_infs as debug_infs,
|
||||
log_compiles as log_compiles,
|
||||
|
@ -57,7 +57,8 @@ def int_env(varname: str, default: int) -> int:
|
||||
UPGRADE_BOOL_HELP = (
|
||||
" This will be enabled by default in future versions of JAX, at which "
|
||||
"point all uses of the flag will be considered deprecated (following "
|
||||
"https://jax.readthedocs.io/en/latest/api_compatibility.html).")
|
||||
"the `API compatibility policy "
|
||||
"<https://jax.readthedocs.io/en/latest/api_compatibility.html>`_).")
|
||||
|
||||
UPGRADE_BOOL_EXTRA_DESC = " (transient)"
|
||||
|
||||
@ -568,6 +569,7 @@ distributed_debug = config.define_bool_state(
|
||||
enable_custom_prng = config.define_bool_state(
|
||||
name='jax_enable_custom_prng',
|
||||
default=False,
|
||||
upgrade=True,
|
||||
help=('Enables an internal upgrade that allows one to define custom '
|
||||
'pseudo-random number generator implementations.'))
|
||||
|
||||
@ -581,6 +583,7 @@ default_prng_impl = config.define_enum_state(
|
||||
enable_custom_vjp_by_custom_transpose = config.define_bool_state(
|
||||
name='jax_enable_custom_vjp_by_custom_transpose',
|
||||
default=False,
|
||||
upgrade=True,
|
||||
help=('Enables an internal upgrade that implements `jax.custom_vjp` by '
|
||||
'reduction to `jax.custom_jvp` and `jax.custom_transpose`.'))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user