mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Deprecate a number of APIs in jax.lib.xla_client.
(Technically these aren't public, so they don't need a deprecation period, but this is the polite thing to do.) PiperOrigin-RevId: 684906277
This commit is contained in:
parent
af50c21225
commit
e9c7ff0b7d
@ -33,6 +33,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
|||||||
the `vectorized` parameter to those functions. The `vmap_method` parameter
|
the `vectorized` parameter to those functions. The `vmap_method` parameter
|
||||||
should be used instead for better defined behavior. See the discussion in
|
should be used instead for better defined behavior. See the discussion in
|
||||||
{jax-issue}`#23881` for more details.
|
{jax-issue}`#23881` for more details.
|
||||||
|
* The semi-public API `jax.lib.xla_client.register_custom_call_target` has
|
||||||
|
been deprecated. Use the JAX FFI instead.
|
||||||
|
* The semi-public APIs `jax.lib.xla_client.dtype_to_etype`,
|
||||||
|
`jax.lib.xla_client.ops`,
|
||||||
|
`jax.lib.xla_client.shape_from_pyval`, `jax.lib.xla_client.PrimitiveType`,
|
||||||
|
`jax.lib.xla_client.Shape`, `jax.lib.xla_client.XlaBuilder`, and
|
||||||
|
`jax.lib.xla_client.XlaComputation` have been deprecated. Use StableHLO
|
||||||
|
instead.
|
||||||
|
|
||||||
## jax 0.4.34 (October 4, 2024)
|
## jax 0.4.34 (October 4, 2024)
|
||||||
|
|
||||||
|
@ -15,13 +15,9 @@
|
|||||||
from jax._src.lax.fft import FftType as _FftType
|
from jax._src.lax.fft import FftType as _FftType
|
||||||
from jax._src.lib import xla_client as _xc
|
from jax._src.lib import xla_client as _xc
|
||||||
|
|
||||||
dtype_to_etype = _xc.dtype_to_etype
|
|
||||||
get_topology_for_devices = _xc.get_topology_for_devices
|
get_topology_for_devices = _xc.get_topology_for_devices
|
||||||
heap_profile = _xc.heap_profile
|
heap_profile = _xc.heap_profile
|
||||||
mlir_api_version = _xc.mlir_api_version
|
mlir_api_version = _xc.mlir_api_version
|
||||||
ops = _xc.ops
|
|
||||||
register_custom_call_target = _xc.register_custom_call_target
|
|
||||||
shape_from_pyval = _xc.shape_from_pyval
|
|
||||||
ArrayImpl = _xc.ArrayImpl
|
ArrayImpl = _xc.ArrayImpl
|
||||||
Client = _xc.Client
|
Client = _xc.Client
|
||||||
CompileOptions = _xc.CompileOptions
|
CompileOptions = _xc.CompileOptions
|
||||||
@ -29,11 +25,7 @@ DeviceAssignment = _xc.DeviceAssignment
|
|||||||
Frame = _xc.Frame
|
Frame = _xc.Frame
|
||||||
HloSharding = _xc.HloSharding
|
HloSharding = _xc.HloSharding
|
||||||
OpSharding = _xc.OpSharding
|
OpSharding = _xc.OpSharding
|
||||||
PrimitiveType = _xc.PrimitiveType
|
|
||||||
Shape = _xc.Shape
|
|
||||||
Traceback = _xc.Traceback
|
Traceback = _xc.Traceback
|
||||||
XlaBuilder = _xc.XlaBuilder
|
|
||||||
XlaComputation = _xc.XlaComputation
|
|
||||||
|
|
||||||
_deprecations = {
|
_deprecations = {
|
||||||
# Added Aug 5 2024
|
# Added Aug 5 2024
|
||||||
@ -69,6 +61,40 @@ _deprecations = {
|
|||||||
),
|
),
|
||||||
_xc.PaddingType,
|
_xc.PaddingType,
|
||||||
),
|
),
|
||||||
|
# Added Oct 11 2024
|
||||||
|
"dtype_to_etype": (
|
||||||
|
"dtype_to_etype is deprecated; use StableHLO instead.",
|
||||||
|
_xc.dtype_to_etype,
|
||||||
|
),
|
||||||
|
"ops": (
|
||||||
|
"ops is deprecated; use StableHLO instead.",
|
||||||
|
_xc.ops,
|
||||||
|
),
|
||||||
|
"register_custom_call_target": (
|
||||||
|
"register_custom_call_target is deprecated; use the JAX FFI instead "
|
||||||
|
"(https://jax.readthedocs.io/en/latest/ffi.html)",
|
||||||
|
_xc.register_custom_call_target,
|
||||||
|
),
|
||||||
|
"shape_from_pyval": (
|
||||||
|
"shape_from_pyval is deprecated; use StableHLO instead.",
|
||||||
|
_xc.shape_from_pyval,
|
||||||
|
),
|
||||||
|
"PrimitiveType": (
|
||||||
|
"PrimitiveType is deprecated; use StableHLO instead.",
|
||||||
|
_xc.PrimitiveType,
|
||||||
|
),
|
||||||
|
"Shape": (
|
||||||
|
"Shape is deprecated; use StableHLO instead.",
|
||||||
|
_xc.Shape,
|
||||||
|
),
|
||||||
|
"XlaBuilder": (
|
||||||
|
"XlaBuilder is deprecated; use StableHLO instead.",
|
||||||
|
_xc.XlaBuilder,
|
||||||
|
),
|
||||||
|
"XlaComputation": (
|
||||||
|
"XlaComputation is deprecated; use StableHLO instead.",
|
||||||
|
_xc.XlaComputation,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
import typing as _typing
|
import typing as _typing
|
||||||
@ -76,9 +102,17 @@ import typing as _typing
|
|||||||
if _typing.TYPE_CHECKING:
|
if _typing.TYPE_CHECKING:
|
||||||
_xla = _xc._xla
|
_xla = _xc._xla
|
||||||
bfloat16 = _xc.bfloat16
|
bfloat16 = _xc.bfloat16
|
||||||
|
dtype_to_etype = _xc.dtype_to_etype
|
||||||
|
ops = _xc.ops
|
||||||
|
register_custom_call_target = _xc.register_custom_call_target
|
||||||
|
shape_from_pyval = _xc.shape_from_pyval
|
||||||
Device = _xc.Device
|
Device = _xc.Device
|
||||||
FftType = _FftType
|
FftType = _FftType
|
||||||
PaddingType = _xc.PaddingType
|
PaddingType = _xc.PaddingType
|
||||||
|
PrimitiveType = _xc.PrimitiveType
|
||||||
|
Shape = _xc.Shape
|
||||||
|
XlaBuilder = _xc.XlaBuilder
|
||||||
|
XlaComputation = _xc.XlaComputation
|
||||||
XlaRuntimeError = _xc.XlaRuntimeError
|
XlaRuntimeError = _xc.XlaRuntimeError
|
||||||
else:
|
else:
|
||||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||||
|
Loading…
x
Reference in New Issue
Block a user