Merge pull request #22885 from jakevdp:dep-xla

PiperOrigin-RevId: 659724044
This commit is contained in:
jax authors 2024-08-05 16:43:46 -07:00
commit 0ab4d68511

View File

@ -14,7 +14,6 @@
from jax._src.lib import xla_client as _xc
_xla = _xc._xla # TODO(jakevdp): deprecate this in favor of jax.lib.xla_extension
bfloat16 = _xc.bfloat16 # TODO(jakevdp): deprecate this in favor of ml_dtypes.bfloat16
dtype_to_etype = _xc.dtype_to_etype
@ -42,4 +41,20 @@ XlaBuilder = _xc.XlaBuilder
XlaComputation = _xc.XlaComputation
XlaRuntimeError = _xc.XlaRuntimeError
_deprecations = {
# Added Aug 5 2024
"_xla" : (
"jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.",
_xc._xla
),
}
import typing as _typing
if _typing.TYPE_CHECKING:
_xla = _xc._xla
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing
del _xc