mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Deprecate the contents of jax.lib.xla_extension.
PiperOrigin-RevId: 741145943
This commit is contained in:
parent
875e4795c4
commit
9932ff1f79
@ -22,9 +22,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
instead.
|
||||
* Implemented host callback handlers for CPU and GPU devices using XLA's FFI
|
||||
and removed existing CPU/GPU handlers using XLA's custom call.
|
||||
* All APIs in `jax.lib.xla_extension` are now deprecated.
|
||||
* Several previously-deprecated APIs have been removed, including:
|
||||
* From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`,
|
||||
and `shape_from_pyval`.
|
||||
* From `jax.lib.xla_extension`: `ArrayImpl`, `XlaRuntimeError`.
|
||||
|
||||
## jax 0.5.3 (Mar 19, 2025)
|
||||
|
||||
|
@ -14,42 +14,122 @@
|
||||
|
||||
from jax._src.lib import xla_extension as _xe
|
||||
|
||||
get_distributed_runtime_client = _xe.get_distributed_runtime_client
|
||||
get_distributed_runtime_service = _xe.get_distributed_runtime_service
|
||||
hlo_module_cost_analysis = _xe.hlo_module_cost_analysis
|
||||
hlo_module_to_dot_graph = _xe.hlo_module_to_dot_graph
|
||||
ifrt_proxy = _xe.ifrt_proxy
|
||||
jax_jit = _xe.jax_jit
|
||||
mlir = _xe.mlir
|
||||
pmap_lib = _xe.pmap_lib
|
||||
profiler = _xe.profiler
|
||||
pytree = _xe.pytree
|
||||
Device = _xe.Device
|
||||
DistributedRuntimeClient = _xe.DistributedRuntimeClient
|
||||
HloModule = _xe.HloModule
|
||||
HloPrintOptions = _xe.HloPrintOptions
|
||||
OpSharding = _xe.OpSharding
|
||||
PjitFunctionCache = _xe.PjitFunctionCache
|
||||
PjitFunction = _xe.PjitFunction
|
||||
PmapFunction = _xe.PmapFunction
|
||||
|
||||
_deprecations = {
|
||||
# Added Nov 20 2024
|
||||
"ArrayImpl": (
|
||||
"jax.lib.xla_extension.ArrayImpl is deprecated; use jax.Array instead.",
|
||||
_xe.ArrayImpl,
|
||||
(
|
||||
"jax.lib.xla_extension.ArrayImpl has been removed; use jax.Array"
|
||||
" instead."
|
||||
),
|
||||
None,
|
||||
),
|
||||
"XlaRuntimeError": (
|
||||
"jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError instead.",
|
||||
_xe.XlaRuntimeError,
|
||||
(
|
||||
"jax.lib.xla_extension.XlaRuntimeError has been removed; use"
|
||||
" jax.errors.JaxRuntimeError instead."
|
||||
),
|
||||
None,
|
||||
),
|
||||
# Deprecated March 26 2025.
|
||||
"DistributedRuntimeClient": (
|
||||
(
|
||||
"jax.lib.xla_extension.DistributedRuntimeClient is"
|
||||
" deprecated; use jax.distributed instead."
|
||||
),
|
||||
_xe.DistributedRuntimeClient,
|
||||
),
|
||||
"get_distributed_runtime_client": (
|
||||
(
|
||||
"jax.lib.xla_extension.get_distributed_runtime_client is"
|
||||
" deprecated; use jax.distributed instead."
|
||||
),
|
||||
_xe.get_distributed_runtime_client,
|
||||
),
|
||||
"get_distributed_runtime_service": (
|
||||
(
|
||||
"jax.lib.xla_extension.get_distributed_runtime_service is"
|
||||
" deprecated; use jax.distributed instead."
|
||||
),
|
||||
_xe.get_distributed_runtime_service,
|
||||
),
|
||||
"Device": (
|
||||
"jax.lib.xla_extension.Device is deprecated; use jax.Device instead.",
|
||||
_xe.Device,
|
||||
),
|
||||
"PjitFunctionCache": (
|
||||
"jax.lib.xla_extension.PjitFunctionCache is deprecated.",
|
||||
_xe.PjitFunctionCache,
|
||||
),
|
||||
"ifrt_proxy": (
|
||||
"jax.lib.xla_extension.ifrt_proxy is deprecated.",
|
||||
_xe.ifrt_proxy,
|
||||
),
|
||||
"jax_jit": (
|
||||
"jax.lib.xla_extension.jax_jit is deprecated.",
|
||||
_xe.jax_jit,
|
||||
),
|
||||
"mlir": ("jax.lib.xla_extension.mlir is deprecated.", _xe.mlir),
|
||||
"pmap_lib": ("jax.lib.xla_extension.pmap_lib is deprecated.", _xe.pmap_lib),
|
||||
"profiler": (
|
||||
"jax.lib.xla_extension.profiler is deprecated.",
|
||||
_xe.profiler,
|
||||
),
|
||||
"pytree": (
|
||||
"jax.lib.xla_extension.pytree is deprecated.",
|
||||
_xe.pytree,
|
||||
),
|
||||
"hlo_module_cost_analysis": (
|
||||
"jax.lib.xla_extension.hlo_module_cost_analysis is deprecated.",
|
||||
_xe.hlo_module_cost_analysis,
|
||||
),
|
||||
"hlo_module_to_dot_graph": (
|
||||
"jax.lib.xla_extension.hlo_module_to_dot_graph is deprecated.",
|
||||
_xe.hlo_module_to_dot_graph,
|
||||
),
|
||||
"HloModule": (
|
||||
"jax.lib.xla_extension.HloModule is deprecated.",
|
||||
_xe.HloModule,
|
||||
),
|
||||
"HloPrintOptions": (
|
||||
"jax.lib.xla_extension.HloPrintOptions is deprecated.",
|
||||
_xe.HloPrintOptions,
|
||||
),
|
||||
"OpSharding": (
|
||||
"jax.lib.xla_extension.OpSharding is deprecated.",
|
||||
_xe.OpSharding,
|
||||
),
|
||||
"PjitFunction": (
|
||||
"jax.lib.xla_extension.PjitFunction is deprecated.",
|
||||
_xe.PjitFunction,
|
||||
),
|
||||
"PmapFunction": (
|
||||
"jax.lib.xla_extension.PmapFunction is deprecated.",
|
||||
_xe.PmapFunction,
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
|
||||
if _typing.TYPE_CHECKING:
|
||||
ArrayImpl = _xe.ArrayImpl
|
||||
XlaRuntimeError = _xe.XlaRuntimeError
|
||||
Device = _xe.Device
|
||||
DistributedRuntimeClient = _xe.DistributedRuntimeClient
|
||||
HloModule = _xe.HloModule
|
||||
HloPrintOptions = _xe.HloPrintOptions
|
||||
OpSharding = _xe.OpSharding
|
||||
PjitFunction = _xe.PjitFunction
|
||||
PjitFunctionCache = _xe.PjitFunctionCache
|
||||
PmapFunction = _xe.PmapFunction
|
||||
|
||||
get_distributed_runtime_client = _xe.get_distributed_runtime_client
|
||||
get_distributed_runtime_service = _xe.get_distributed_runtime_service
|
||||
hlo_module_cost_analysis = _xe.hlo_module_cost_analysis
|
||||
hlo_module_to_dot_graph = _xe.hlo_module_to_dot_graph
|
||||
ifrt_proxy = _xe.ifrt_proxy
|
||||
jax_jit = _xe.jax_jit
|
||||
mlir = _xe.mlir
|
||||
pmap_lib = _xe.pmap_lib
|
||||
profiler = _xe.profiler
|
||||
pytree = _xe.pytree
|
||||
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user