Merge pull request #22181 from jakevdp:xla-abbrevs

PiperOrigin-RevId: 648701764
This commit is contained in:
jax authors 2024-07-02 06:45:51 -07:00
commit 92ebb533bd

View File

@ -23,14 +23,27 @@ from jax._src.dispatch import (
apply_primitive as apply_primitive,
)
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src import xla_bridge as _xb
from jax._src.lib import xla_client as _xc
xe = xc._xla
Backend = xe.Client
_xe = _xc._xla
Backend = _xe.Client
# Deprecations
_deprecations = {
# Added 2024-06-28
"xb": (
"jax.interpreters.xla.xb is deprecated. Use jax.lib.xla_bridge instead.",
_xb
),
"xc": (
"jax.interpreters.xla.xc is deprecated. Use jax.lib.xla_client instead.",
_xc,
),
"xe": (
"jax.interpreters.xla.xe is deprecated. Use jax.lib.xla_extension instead.",
_xe,
),
# Finalized 2024-05-13; remove after 2024-08-13
"backend_specific_translations": (
"jax.interpreters.xla.backend_specific_translations is deprecated. "
@ -69,6 +82,13 @@ _deprecations = {
),
}
import typing
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
if typing.TYPE_CHECKING:
xb = _xb
xc = _xc
xe = _xe
else:
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing