Deprecate a number of APIs in jax.lib.xla_client.

(Technically these aren't public, so they don't need a deprecation period, but this is the polite thing to do.)

PiperOrigin-RevId: 684906277
This commit is contained in:
Peter Hawkins 2024-10-11 11:42:08 -07:00 committed by jax authors
parent af50c21225
commit e9c7ff0b7d
2 changed files with 50 additions and 8 deletions

View File

@ -33,6 +33,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
the `vectorized` parameter to those functions. The `vmap_method` parameter
should be used instead for better defined behavior. See the discussion in
{jax-issue}`#23881` for more details.
* The semi-public API `jax.lib.xla_client.register_custom_call_target` has
been deprecated. Use the JAX FFI instead.
* The semi-public APIs `jax.lib.xla_client.dtype_to_etype`,
`jax.lib.xla_client.ops`,
`jax.lib.xla_client.shape_from_pyval`, `jax.lib.xla_client.PrimitiveType`,
`jax.lib.xla_client.Shape`, `jax.lib.xla_client.XlaBuilder`, and
`jax.lib.xla_client.XlaComputation` have been deprecated. Use StableHLO
instead.
## jax 0.4.34 (October 4, 2024)

View File

@ -15,13 +15,9 @@
from jax._src.lax.fft import FftType as _FftType
from jax._src.lib import xla_client as _xc
dtype_to_etype = _xc.dtype_to_etype
get_topology_for_devices = _xc.get_topology_for_devices
heap_profile = _xc.heap_profile
mlir_api_version = _xc.mlir_api_version
ops = _xc.ops
register_custom_call_target = _xc.register_custom_call_target
shape_from_pyval = _xc.shape_from_pyval
ArrayImpl = _xc.ArrayImpl
Client = _xc.Client
CompileOptions = _xc.CompileOptions
@ -29,11 +25,7 @@ DeviceAssignment = _xc.DeviceAssignment
Frame = _xc.Frame
HloSharding = _xc.HloSharding
OpSharding = _xc.OpSharding
PrimitiveType = _xc.PrimitiveType
Shape = _xc.Shape
Traceback = _xc.Traceback
XlaBuilder = _xc.XlaBuilder
XlaComputation = _xc.XlaComputation
_deprecations = {
# Added Aug 5 2024
@ -69,6 +61,40 @@ _deprecations = {
),
_xc.PaddingType,
),
# Added Oct 11 2024
"dtype_to_etype": (
"dtype_to_etype is deprecated; use StableHLO instead.",
_xc.dtype_to_etype,
),
"ops": (
"ops is deprecated; use StableHLO instead.",
_xc.ops,
),
"register_custom_call_target": (
"register_custom_call_target is deprecated; use the JAX FFI instead "
"(https://jax.readthedocs.io/en/latest/ffi.html)",
_xc.register_custom_call_target,
),
"shape_from_pyval": (
"shape_from_pyval is deprecated; use StableHLO instead.",
_xc.shape_from_pyval,
),
"PrimitiveType": (
"PrimitiveType is deprecated; use StableHLO instead.",
_xc.PrimitiveType,
),
"Shape": (
"Shape is deprecated; use StableHLO instead.",
_xc.Shape,
),
"XlaBuilder": (
"XlaBuilder is deprecated; use StableHLO instead.",
_xc.XlaBuilder,
),
"XlaComputation": (
"XlaComputation is deprecated; use StableHLO instead.",
_xc.XlaComputation,
),
}
import typing as _typing
@ -76,9 +102,17 @@ import typing as _typing
if _typing.TYPE_CHECKING:
_xla = _xc._xla
bfloat16 = _xc.bfloat16
dtype_to_etype = _xc.dtype_to_etype
ops = _xc.ops
register_custom_call_target = _xc.register_custom_call_target
shape_from_pyval = _xc.shape_from_pyval
Device = _xc.Device
FftType = _FftType
PaddingType = _xc.PaddingType
PrimitiveType = _xc.PrimitiveType
Shape = _xc.Shape
XlaBuilder = _xc.XlaBuilder
XlaComputation = _xc.XlaComputation
XlaRuntimeError = _xc.XlaRuntimeError
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr