mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Copybara import of the project:
-- 8cf6a6acd151007935b0c3093df05ef036bb0244 by Jake VanderPlas <jakevdp@google.com>: Remove several deprecated APIs COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16110 from jakevdp:deprecations 8cf6a6acd151007935b0c3093df05ef036bb0244 PiperOrigin-RevId: 534897394
This commit is contained in:
parent
1831b3cd95
commit
399e4ee87f
15
CHANGELOG.md
15
CHANGELOG.md
@ -8,6 +8,21 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jax 0.4.11
|
||||
|
||||
* Deprecations
|
||||
* The following APIs have been removed after a 3 month deprecation period, in
|
||||
accordance with the {ref}`api-compatibility` policy:
|
||||
- `jax.experimental.PartitionSpec`: use `jax.sharding.PartitionSpec`.
|
||||
- `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh`
|
||||
- `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
|
||||
- `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
|
||||
- `jax.experimental.pjit.FROM_GDA`. Instead pass sharded `jax.Array` objects
|
||||
as input and remove the optional `in_shardings` argument to `pjit`.
|
||||
- `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
|
||||
- `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`
|
||||
- `jax.interpreters.xla.Device`: use `jax.Device`.
|
||||
- `jax.interpreters.xla.DeviceArray`: use `jax.Array` instead,
|
||||
|
||||
|
||||
## jaxlib 0.4.11
|
||||
|
||||
## jax 0.4.10 (May 11, 2023)
|
||||
|
@ -15,8 +15,6 @@
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
|
||||
# TODO(https://github.com/google/jax/issues/13487): Remove PartitionSpec in
|
||||
# 3 months from `jax.experimental.PartitionSpec`.
|
||||
from jax.experimental.x64_context import (
|
||||
enable_x64 as enable_x64,
|
||||
disable_x64 as disable_x64,
|
||||
@ -24,29 +22,3 @@ from jax.experimental.x64_context import (
|
||||
from jax._src.callback import (
|
||||
io_callback as io_callback
|
||||
)
|
||||
|
||||
# Deprecations
|
||||
|
||||
from jax._src.partition_spec import (
|
||||
PartitionSpec as _deprecated_PartitionSpec,
|
||||
)
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
from jax._src.partition_spec import (
|
||||
PartitionSpec as PartitionSpec,
|
||||
)
|
||||
del typing
|
||||
|
||||
_deprecations = {
|
||||
# Added Feb 8, 2023:
|
||||
"PartitionSpec": (
|
||||
("jax.experimental.PartitionSpec is deprecated. Use "
|
||||
"jax.sharding.PartitionSpec."),
|
||||
_deprecated_PartitionSpec,
|
||||
),
|
||||
}
|
||||
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
|
@ -27,28 +27,3 @@ from jax._src.mesh import (
|
||||
ResourceEnv as ResourceEnv,
|
||||
thread_resources as thread_resources,
|
||||
)
|
||||
|
||||
# Deprecations
|
||||
|
||||
from jax._src.maps import (
|
||||
Mesh as _deprecated_Mesh,
|
||||
)
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
from jax._src.maps import (
|
||||
Mesh as Mesh,
|
||||
)
|
||||
del typing
|
||||
|
||||
_deprecations = {
|
||||
# Added Feb 13, 2023:
|
||||
"Mesh": (
|
||||
"jax.experimental.maps.Mesh is deprecated. Use jax.sharding.Mesh.",
|
||||
_deprecated_Mesh,
|
||||
),
|
||||
}
|
||||
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
|
@ -33,46 +33,3 @@ from jax._src.pjit import (_get_op_sharding_from_executable,
|
||||
_get_pspec_from_executable, _pjit_lower_cached,
|
||||
_pjit_lower, _pjit_jaxpr,
|
||||
_process_in_axis_resources)
|
||||
|
||||
from jax._src.sharding_impls import (
|
||||
NamedSharding as _deprecated_NamedSharding,
|
||||
)
|
||||
from jax._src.partition_spec import (
|
||||
PartitionSpec as _deprecated_PartitionSpec,
|
||||
)
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
from jax._src.sharding_impls import NamedSharding as NamedSharding
|
||||
from jax._src.partition_spec import PartitionSpec as PartitionSpec
|
||||
del typing
|
||||
|
||||
_deprecations = {
|
||||
# Added Feb 13, 2023:
|
||||
"NamedSharding": (
|
||||
(
|
||||
"jax.experimental.pjit.NamedSharding is deprecated. Use "
|
||||
"jax.sharding.NamedSharding."
|
||||
),
|
||||
_deprecated_NamedSharding,
|
||||
),
|
||||
"PartitionSpec": (
|
||||
(
|
||||
"jax.experimental.pjit.PartitionSpec is deprecated. Use "
|
||||
"jax.sharding.PartitionSpec."
|
||||
),
|
||||
_deprecated_PartitionSpec,
|
||||
),
|
||||
"FROM_GDA": (
|
||||
(
|
||||
"jax.experimental.pjit.FROM_GDA has been removed. Please pass in"
|
||||
" sharded jax.Arrays as input and remove the in_shardings argument"
|
||||
" to pjit since pjit will infer the shardings from jax.Array."
|
||||
),
|
||||
None,
|
||||
),
|
||||
}
|
||||
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
|
@ -114,29 +114,12 @@ from jax._src.sharding_specs import (
|
||||
|
||||
# Deprecations
|
||||
|
||||
from jax._src.mesh import Mesh as _deprecated_Mesh
|
||||
from jax._src.interpreters.pxla import (
|
||||
ShardedDeviceArray as _deprecated_ShardedDeviceArray,
|
||||
make_sharded_device_array as _deprecated_make_sharded_device_array,
|
||||
)
|
||||
from jax._src.partition_spec import (
|
||||
PartitionSpec as _deprecated_PartitionSpec,
|
||||
)
|
||||
|
||||
_deprecations = {
|
||||
# Added Feb 8, 2023:
|
||||
"Mesh": (
|
||||
"jax.interpreters.pxla.Mesh is deprecated. Use jax.sharding.Mesh.",
|
||||
_deprecated_Mesh,
|
||||
),
|
||||
# Added Feb 8, 2023:
|
||||
"PartitionSpec": (
|
||||
(
|
||||
"jax.interpreters.pxla.PartitionSpec is deprecated. Use "
|
||||
"jax.sharding.PartitionSpec."
|
||||
),
|
||||
_deprecated_PartitionSpec,
|
||||
),
|
||||
# Added March 15, 2023:
|
||||
"ShardedDeviceArray": (
|
||||
(
|
||||
@ -165,13 +148,11 @@ _deprecations = {
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
from jax._src.mesh import Mesh as Mesh
|
||||
from jax._src.interpreters.pxla import (
|
||||
ShardedDeviceArray as ShardedDeviceArray,
|
||||
device_put as device_put,
|
||||
make_sharded_device_array as make_sharded_device_array,
|
||||
)
|
||||
from jax._src.partition_spec import PartitionSpec as PartitionSpec
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
|
@ -62,17 +62,6 @@ Buffer = _deprecated_DeviceArray
|
||||
|
||||
_deprecations = {
|
||||
# Added Feb 9, 2023:
|
||||
"Device": (
|
||||
"jax.interpreters.xla.Device is deprecated. Use jax.Device instead.",
|
||||
_deprecated_Device,
|
||||
),
|
||||
"DeviceArray": (
|
||||
(
|
||||
"jax.interpreters.xla.DeviceArray is deprecated. Use jax.Array"
|
||||
" instead."
|
||||
),
|
||||
_deprecated_DeviceArray,
|
||||
),
|
||||
"Buffer": (
|
||||
(
|
||||
"jax.interpreters.xla.Buffer is deprecated. Use jax.Array"
|
||||
@ -102,10 +91,6 @@ del _deprecation_getattr
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
Device = xc.Device
|
||||
from jax._src.interpreters.xla import (
|
||||
DeviceArray as DeviceArray,
|
||||
)
|
||||
from jax._src.api import device_put as device_put
|
||||
from jax._src.interpreters.xla import xla_call_p as xla_call_p
|
||||
del typing
|
||||
|
Loading…
x
Reference in New Issue
Block a user