mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #22885 from jakevdp:dep-xla
PiperOrigin-RevId: 659724044
This commit is contained in:
commit
0ab4d68511
@ -14,7 +14,6 @@
|
||||
|
||||
from jax._src.lib import xla_client as _xc
|
||||
|
||||
_xla = _xc._xla # TODO(jakevdp): deprecate this in favor of jax.lib.xla_extension
|
||||
bfloat16 = _xc.bfloat16 # TODO(jakevdp): deprecate this in favor of ml_dtypes.bfloat16
|
||||
|
||||
dtype_to_etype = _xc.dtype_to_etype
|
||||
@ -42,4 +41,20 @@ XlaBuilder = _xc.XlaBuilder
|
||||
XlaComputation = _xc.XlaComputation
|
||||
XlaRuntimeError = _xc.XlaRuntimeError
|
||||
|
||||
_deprecations = {
|
||||
# Added Aug 5 2024
|
||||
"_xla" : (
|
||||
"jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.",
|
||||
_xc._xla
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
if _typing.TYPE_CHECKING:
|
||||
_xla = _xc._xla
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del _typing
|
||||
del _xc
|
||||
|
Loading…
x
Reference in New Issue
Block a user