Deprecate jax.interpreters xb, xc, xe abbreviations.

Instead, import directly as jax.lib.xla_bridge, jax.lib.xla_client, jax.lib.xla_extension.
This commit is contained in:
Jake VanderPlas 2024-06-28 10:47:37 -07:00
parent 8c889b50c0
commit 251dfcad3c

View File

@ -23,14 +23,27 @@ from jax._src.dispatch import (
apply_primitive as apply_primitive,
)
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src import xla_bridge as _xb
from jax._src.lib import xla_client as _xc
xe = xc._xla
Backend = xe.Client
_xe = _xc._xla
Backend = _xe.Client
# Deprecations
_deprecations = {
# Added 2024-06-28
"xb": (
"jax.interpreters.xla.xb is deprecated. Use jax.lib.xla_bridge instead.",
_xb
),
"xc": (
"jax.interpreters.xla.xc is deprecated. Use jax.lib.xla_client instead.",
_xc,
),
"xe": (
"jax.interpreters.xla.xe is deprecated. Use jax.lib.xla_extension instead.",
_xe,
),
# Finalized 2024-05-13; remove after 2024-08-13
"backend_specific_translations": (
"jax.interpreters.xla.backend_specific_translations is deprecated. "
@ -69,6 +82,13 @@ _deprecations = {
),
}
import typing
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
if typing.TYPE_CHECKING:
xb = _xb
xc = _xc
xe = _xe
else:
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing