1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-20 13:56:07 +00:00

Finalize some deprecations in jax.core, jax.lib.xla_bridge, and jax.lib.xla_client.

This commit is contained in:
Jake VanderPlas 2024-12-11 09:50:33 -08:00
parent 01206f839b
commit f858a71461
5 changed files with 25 additions and 31 deletions

@ -19,6 +19,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
APIs of the same name in {mod}`jax.extend.core`; see the documentation for
{mod}`jax.extend` for information on the compatibility guarantees of these
semi-public extensions.
* Several previously-deprecated APIs have been removed, including:
* from {mod}`jax.core`: `check_eqn`, `check_type`, `check_valid_jaxtype`, and
`non_negative_dim`.
* from {mod}`jax.lib.xla_bridge`: `xla_client` and `default_backend`.
* from {mod}`jax.lib.xla_client`: `_xla` and `bfloat16`.
## jax 0.4.37 (Dec 9, 2024)

@ -11,7 +11,6 @@ jax.lib.xla_bridge
.. autosummary::
:toctree: _autosummary
default_backend
get_backend
get_compile_options

@ -160,13 +160,16 @@ _deprecations = {
"Var": ("jax.core.Var is deprecated. Use jax.extend.core.Var instead, "
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
_src_core.Var),
# Added 2024-08-14
"check_eqn": ("jax.core.check_eqn is deprecated.", _src_core.check_eqn),
"check_type": ("jax.core.check_type is deprecated.", _src_core.check_type),
# Finalized 2024-12-11; remove after 2025-3-11
"check_eqn": ("jax.core.check_eqn was removed in JAX v0.4.38.", None),
"check_type": ("jax.core.check_type was removed in JAX v0.4.38.", None),
"check_valid_jaxtype": (
("jax.core.check_valid_jaxtype is deprecated. Instead, you can manually"
("jax.core.check_valid_jaxtype was removed in JAX v0.4.38. Instead, you can manually"
" raise an error if core.valid_jaxtype() returns False."),
_src_core.check_valid_jaxtype),
None),
"non_negative_dim": (
"jax.core.non_negative_dim was removed in JAX v0.4.38. Use max_dim(..., 0).", None,
),
# Finalized 2024-09-25; remove after 2024-12-25
"pp_aval": ("jax.core.pp_aval was removed in JAX v0.4.34.", None),
"pp_eqn": ("jax.core.pp_eqn was removed in JAX v0.4.34.", None),
@ -180,10 +183,6 @@ _deprecations = {
"pp_kv_pairs": ("jax.core.pp_kv_pairs was removed in JAX v0.4.34.", None),
"pp_var": ("jax.core.pp_var was removed in JAX v0.4.34.", None),
"pp_vars": ("jax.core.pp_vars was removed in JAX v0.4.34.", None),
# Added Jan 8, 2024
"non_negative_dim": (
"jax.core.non_negative_dim is deprecated. Use max_dim(..., 0).", _src_core.non_negative_dim,
),
}
import typing
@ -207,9 +206,6 @@ if typing.TYPE_CHECKING:
Var = _src_core.Var
axis_frame = _src_core.axis_frame
call_p = _src_core.call_p
check_eqn = _src_core.check_eqn
check_type = _src_core.check_type
check_valid_jaxtype = _src_core.check_valid_jaxtype
closed_call_p = _src_core.closed_call_p
concrete_aval = _src_core.concrete_aval
dedup_referents = _src_core.dedup_referents
@ -223,7 +219,6 @@ if typing.TYPE_CHECKING:
lattice_join = _src_core.lattice_join
leaked_tracer_error = _src_core.leaked_tracer_error
maybe_find_leaked_tracers = _src_core.maybe_find_leaked_tracers
non_negative_dim = _src_core.non_negative_dim
raise_to_shaped = _src_core.raise_to_shaped
raise_to_shaped_mappings = _src_core.raise_to_shaped_mappings
reset_trace_state = _src_core.reset_trace_state

@ -14,9 +14,7 @@
# ruff: noqa: F401
from jax._src.xla_bridge import (
default_backend as _deprecated_default_backend,
get_backend as _deprecated_get_backend,
xla_client as _deprecated_xla_client,
)
from jax._src.compiler import (
@ -25,25 +23,24 @@ from jax._src.compiler import (
_deprecations = {
# Added July 31, 2024
"xla_client": (
"jax.lib.xla_bridge.xla_client is deprecated; use jax.lib.xla_client directly.",
_deprecated_xla_client
),
"get_backend": (
"jax.lib.xla_bridge.get_backend is deprecated; use jax.extend.backend.get_backend.",
_deprecated_get_backend
),
# Finalized 2024-12-11; remove after 2025-3-11
"xla_client": (
"jax.lib.xla_bridge.xla_client was removed in JAX v0.4.38; use jax.lib.xla_client directly.",
None
),
"default_backend": (
"jax.lib.xla_bridge.default_backend is deprecated; use jax.default_backend.",
_deprecated_default_backend
"jax.lib.xla_bridge.default_backend was removed in JAX v0.4.38; use jax.default_backend.",
None
),
}
import typing as _typing
if _typing.TYPE_CHECKING:
from jax._src.xla_bridge import default_backend as default_backend
from jax._src.xla_bridge import get_backend as get_backend
from jax._src.xla_bridge import xla_client as xla_client
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)

@ -27,14 +27,14 @@ OpSharding = _xc.OpSharding
Traceback = _xc.Traceback
_deprecations = {
# Added Aug 5 2024
# Finalized 2024-12-11; remove after 2025-3-11
"_xla": (
"jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.",
_xc._xla,
"jax.lib.xla_client._xla was removed in JAX v0.4.38; use jax.lib.xla_extension.",
None,
),
"bfloat16": (
"jax.lib.xla_client.bfloat16 is deprecated; use ml_dtypes.bfloat16.",
_xc.bfloat16,
"jax.lib.xla_client.bfloat16 was removed in JAX v0.4.38; use ml_dtypes.bfloat16.",
None,
),
# Added Sep 26 2024
"Device": (
@ -104,8 +104,6 @@ _deprecations = {
import typing as _typing
if _typing.TYPE_CHECKING:
_xla = _xc._xla
bfloat16 = _xc.bfloat16
dtype_to_etype = _xc.dtype_to_etype
ops = _xc.ops
register_custom_call_target = _xc.register_custom_call_target