From e9c7ff0b7dcecfef9396b737a22be11b1574ce2a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 11 Oct 2024 11:42:08 -0700 Subject: [PATCH] Deprecate a number of APIs in jax.lib.xla_client. (Technically these aren't public, so they don't need a deprecation period, but this is the polite thing to do.) PiperOrigin-RevId: 684906277 --- CHANGELOG.md | 8 +++++++ jax/lib/xla_client.py | 50 ++++++++++++++++++++++++++++++++++++------- 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bd72fe6c3..7fb6c1bd0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. the `vectorized` parameter to those functions. The `vmap_method` parameter should be used instead for better defined behavior. See the discussion in {jax-issue}`#23881` for more details. + * The semi-public API `jax.lib.xla_client.register_custom_call_target` has + been deprecated. Use the JAX FFI instead. + * The semi-public APIs `jax.lib.xla_client.dtype_to_etype`, + `jax.lib.xla_client.ops`, + `jax.lib.xla_client.shape_from_pyval`, `jax.lib.xla_client.PrimitiveType`, + `jax.lib.xla_client.Shape`, `jax.lib.xla_client.XlaBuilder`, and + `jax.lib.xla_client.XlaComputation` have been deprecated. Use StableHLO + instead. ## jax 0.4.34 (October 4, 2024) diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index ff256e89d..aaf379103 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -15,13 +15,9 @@ from jax._src.lax.fft import FftType as _FftType from jax._src.lib import xla_client as _xc -dtype_to_etype = _xc.dtype_to_etype get_topology_for_devices = _xc.get_topology_for_devices heap_profile = _xc.heap_profile mlir_api_version = _xc.mlir_api_version -ops = _xc.ops -register_custom_call_target = _xc.register_custom_call_target -shape_from_pyval = _xc.shape_from_pyval ArrayImpl = _xc.ArrayImpl Client = _xc.Client CompileOptions = _xc.CompileOptions @@ -29,11 +25,7 @@ DeviceAssignment = _xc.DeviceAssignment Frame = _xc.Frame HloSharding = _xc.HloSharding OpSharding = _xc.OpSharding -PrimitiveType = _xc.PrimitiveType -Shape = _xc.Shape Traceback = _xc.Traceback -XlaBuilder = _xc.XlaBuilder -XlaComputation = _xc.XlaComputation _deprecations = { # Added Aug 5 2024 @@ -69,6 +61,40 @@ _deprecations = { ), _xc.PaddingType, ), + # Added Oct 11 2024 + "dtype_to_etype": ( + "dtype_to_etype is deprecated; use StableHLO instead.", + _xc.dtype_to_etype, + ), + "ops": ( + "ops is deprecated; use StableHLO instead.", + _xc.ops, + ), + "register_custom_call_target": ( + "register_custom_call_target is deprecated; use the JAX FFI instead " + "(https://jax.readthedocs.io/en/latest/ffi.html)", + _xc.register_custom_call_target, + ), + "shape_from_pyval": ( + "shape_from_pyval is deprecated; use StableHLO instead.", + _xc.shape_from_pyval, + ), + "PrimitiveType": ( + "PrimitiveType is deprecated; use StableHLO instead.", + _xc.PrimitiveType, + ), + "Shape": ( + "Shape is deprecated; use StableHLO instead.", + _xc.Shape, + ), + "XlaBuilder": ( + "XlaBuilder is deprecated; use StableHLO instead.", + _xc.XlaBuilder, + ), + "XlaComputation": ( + "XlaComputation is deprecated; use StableHLO instead.", + _xc.XlaComputation, + ), } import typing as _typing @@ -76,9 +102,17 @@ import typing as _typing if _typing.TYPE_CHECKING: _xla = _xc._xla bfloat16 = _xc.bfloat16 + dtype_to_etype = _xc.dtype_to_etype + ops = _xc.ops + register_custom_call_target = _xc.register_custom_call_target + shape_from_pyval = _xc.shape_from_pyval Device = _xc.Device FftType = _FftType PaddingType = _xc.PaddingType + PrimitiveType = _xc.PrimitiveType + Shape = _xc.Shape + XlaBuilder = _xc.XlaBuilder + XlaComputation = _xc.XlaComputation XlaRuntimeError = _xc.XlaRuntimeError else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr