diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index e88a4247a..23415c696 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -14,7 +14,6 @@ from jax._src.lib import xla_client as _xc -_xla = _xc._xla # TODO(jakevdp): deprecate this in favor of jax.lib.xla_extension bfloat16 = _xc.bfloat16 # TODO(jakevdp): deprecate this in favor of ml_dtypes.bfloat16 dtype_to_etype = _xc.dtype_to_etype @@ -42,4 +41,20 @@ XlaBuilder = _xc.XlaBuilder XlaComputation = _xc.XlaComputation XlaRuntimeError = _xc.XlaRuntimeError +_deprecations = { + # Added Aug 5 2024 + "_xla" : ( + "jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.", + _xc._xla + ), +} + +import typing as _typing +if _typing.TYPE_CHECKING: + _xla = _xc._xla +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing del _xc