mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Finalize deprecation of xb, xc, & xe symbols in jax.interpreters.xla
PiperOrigin-RevId: 689792265
This commit is contained in:
parent
8c6164a492
commit
d4c46825d6
@ -17,6 +17,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
or with `enable_xla=False` have been deprecated since July 2024, with
|
||||
JAX version 0.4.31. Now we removed support for these use cases. `jax2tf`
|
||||
with native serialization will still be supported.
|
||||
* In `jax.interpreters.xla`, the `xb`, `xc`, and `xe` symbols have been removed
|
||||
after being deprecated in JAX v0.4.31. Instead use `xb = jax.lib.xla_bridge`,
|
||||
`xc = jax.lib.xla_client`, and `xe = jax.lib.xla_extension`.
|
||||
|
||||
## jax 0.4.35 (Oct 22, 2024)
|
||||
|
||||
|
@ -23,26 +23,24 @@ from jax._src.dispatch import (
|
||||
apply_primitive as apply_primitive,
|
||||
)
|
||||
|
||||
from jax._src import xla_bridge as _xb
|
||||
from jax._src.lib import xla_client as _xc
|
||||
|
||||
_xe = _xc._xla
|
||||
Backend = _xe.Client
|
||||
Backend = _xc._xla.Client
|
||||
del _xc
|
||||
|
||||
# Deprecations
|
||||
_deprecations = {
|
||||
# Added 2024-06-28
|
||||
# Finalized 2024-10-24; remove after 2025-01-24
|
||||
"xb": (
|
||||
"jax.interpreters.xla.xb is deprecated. Use jax.lib.xla_bridge instead.",
|
||||
_xb
|
||||
("jax.interpreters.xla.xb was removed in JAX v0.4.36. "
|
||||
"Use jax.lib.xla_bridge instead."), None
|
||||
),
|
||||
"xc": (
|
||||
"jax.interpreters.xla.xc is deprecated. Use jax.lib.xla_client instead.",
|
||||
_xc,
|
||||
("jax.interpreters.xla.xc was removed in JAX v0.4.36. "
|
||||
"Use jax.lib.xla_client instead."), None
|
||||
),
|
||||
"xe": (
|
||||
"jax.interpreters.xla.xe is deprecated. Use jax.lib.xla_extension instead.",
|
||||
_xe,
|
||||
("jax.interpreters.xla.xe was removed in JAX v0.4.36. "
|
||||
"Use jax.lib.xla_extension instead."), None
|
||||
),
|
||||
# Finalized 2024-05-13; remove after 2024-08-13
|
||||
"backend_specific_translations": (
|
||||
@ -82,13 +80,6 @@ _deprecations = {
|
||||
),
|
||||
}
|
||||
|
||||
import typing
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
if typing.TYPE_CHECKING:
|
||||
xb = _xb
|
||||
xc = _xc
|
||||
xe = _xe
|
||||
else:
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del typing
|
||||
|
Loading…
x
Reference in New Issue
Block a user