From 3551fcc077aa357d46caaca2a34ffb5295b6577b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 31 Jul 2024 22:59:49 -0700 Subject: [PATCH] Deprecate several APIs in jax.lib.xla_bridge PiperOrigin-RevId: 658274719 --- CHANGELOG.md | 4 ++++ jax/lib/xla_bridge.py | 33 ++++++++++++++++++++++++++++++--- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a7e83f0a4..348972f23 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Deprecations * Complex inputs to {func}`jax.numpy.clip` and {func}`jax.numpy.hypot` are no longer allowed, after being deprecated since JAX v0.4.27. + * Deprecated the following APIs: + * `jax.lib.xla_bridge.xla_client`: use {mod}`jax.lib.xla_client` directly. + * `jax.lib.xla_bridge.get_backend`: use {func}`jax.extend.backend.get_backend`. + * `jax.lib.xla_bridge.default_backend`: use {func}`jax.extend.backend.default_backend`. ## jaxlib 0.4.32 diff --git a/jax/lib/xla_bridge.py b/jax/lib/xla_bridge.py index 285eb3f06..e83547969 100644 --- a/jax/lib/xla_bridge.py +++ b/jax/lib/xla_bridge.py @@ -14,12 +14,39 @@ # ruff: noqa: F401 from jax._src.xla_bridge import ( - default_backend as default_backend, - get_backend as get_backend, - xla_client as xla_client, + default_backend as _deprecated_default_backend, + get_backend as _deprecated_get_backend, + xla_client as _deprecated_xla_client, _backends as _backends, ) from jax._src.compiler import ( get_compile_options as get_compile_options, ) + +_deprecations = { + # Added July 31, 2024 + "xla_client": ( + "jax.lib.xla_bridge.xla_client is deprecated; use jax.lib.xla_client directly.", + _deprecated_xla_client + ), + "get_backend": ( + "jax.lib.xla_bridge.get_backend is deprecated; use jax.extend.backend.get_backend.", + _deprecated_get_backend + ), + "default_backend": ( + "jax.lib.xla_bridge.default_backend is deprecated; use jax.extend.backend.default_backend.", + _deprecated_default_backend + ), +} + +import typing as _typing +if _typing.TYPE_CHECKING: + from jax._src.xla_bridge import default_backend as default_backend + from jax._src.xla_bridge import get_backend as get_backend + from jax._src.xla_bridge import xla_client as xla_client +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing