Finalize deprecations of jax.interpreters.ad config & source_info_util

These have been raising a DeprecationWarning since JAX 0.4.19, released 2023 Oct 19. I've left the undefined symbols in place for now, as they will raise an informative AttributeError.

PiperOrigin-RevId: 616931120
This commit is contained in:
Jake VanderPlas 2024-03-18 13:32:10 -07:00 committed by jax authors
parent bc363de8a5
commit 154403c03d
2 changed files with 6 additions and 17 deletions

View File

@ -18,6 +18,9 @@ Remember to align the itemized text with the first line of an item within a list
that cannot be converted to a JAX array now results in an exception.
* The deprecated flag `jax_parallel_functions_output_gda` has been removed.
This flag was long deprecated and did nothing; its use was a no-op.
* The previously-deprecated imports `jax.interpreters.ad.config` and
`jax.interpreters.ad.source_info_util` have now been removed. Use `jax.config`
and `jax.extend.source_info_util` instead.
## jaxlib 0.4.26

View File

@ -73,32 +73,18 @@ from jax._src.interpreters.ad import (
zeros_like_p as zeros_like_p,
)
from jax import config as _deprecated_config
from jax._src import source_info_util as _deprecated_source_info_util
_deprecations = {
# Added Oct 13, 2023:
# Finalized Mar 18, 2024; remove after June 18, 2024
"config": (
"jax.interpreters.ad.config is deprecated. Use jax.config directly.",
_deprecated_config,
None,
),
"source_info_util": (
"jax.interpreters.ad.source_info_util is deprecated. Use jax.extend.source_info_util.",
_deprecated_source_info_util,
None,
),
}
import typing
if typing.TYPE_CHECKING:
config = _deprecated_config
source_info_util = _deprecated_source_info_util
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
del _deprecated_config
del _deprecated_source_info_util
def backward_pass(jaxpr, reduce_axes, transform_stack,
consts, primals_in, cotangents_in):
if reduce_axes: