Copybara import of the project:

--
8cf6a6acd151007935b0c3093df05ef036bb0244 by Jake VanderPlas <jakevdp@google.com>:

Remove several deprecated APIs

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16110 from jakevdp:deprecations 8cf6a6acd151007935b0c3093df05ef036bb0244
PiperOrigin-RevId: 534897394
This commit is contained in:
Jake Vanderplas 2023-05-24 10:35:01 -07:00 committed by jax authors
parent 1831b3cd95
commit 399e4ee87f
6 changed files with 15 additions and 130 deletions

View File

@ -8,6 +8,21 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.11
* Deprecations
* The following APIs have been removed after a 3 month deprecation period, in
accordance with the {ref}`api-compatibility` policy:
- `jax.experimental.PartitionSpec`: use `jax.sharding.PartitionSpec`.
- `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh`
- `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
- `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
- `jax.experimental.pjit.FROM_GDA`. Instead pass sharded `jax.Array` objects
as input and remove the optional `in_shardings` argument to `pjit`.
- `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
- `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`
- `jax.interpreters.xla.Device`: use `jax.Device`.
- `jax.interpreters.xla.DeviceArray`: use `jax.Array` instead,
## jaxlib 0.4.11
## jax 0.4.10 (May 11, 2023)

View File

@ -15,8 +15,6 @@
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
# TODO(https://github.com/google/jax/issues/13487): Remove PartitionSpec in
# 3 months from `jax.experimental.PartitionSpec`.
from jax.experimental.x64_context import (
enable_x64 as enable_x64,
disable_x64 as disable_x64,
@ -24,29 +22,3 @@ from jax.experimental.x64_context import (
from jax._src.callback import (
io_callback as io_callback
)
# Deprecations
from jax._src.partition_spec import (
PartitionSpec as _deprecated_PartitionSpec,
)
import typing
if typing.TYPE_CHECKING:
from jax._src.partition_spec import (
PartitionSpec as PartitionSpec,
)
del typing
_deprecations = {
# Added Feb 8, 2023:
"PartitionSpec": (
("jax.experimental.PartitionSpec is deprecated. Use "
"jax.sharding.PartitionSpec."),
_deprecated_PartitionSpec,
),
}
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr

View File

@ -27,28 +27,3 @@ from jax._src.mesh import (
ResourceEnv as ResourceEnv,
thread_resources as thread_resources,
)
# Deprecations
from jax._src.maps import (
Mesh as _deprecated_Mesh,
)
import typing
if typing.TYPE_CHECKING:
from jax._src.maps import (
Mesh as Mesh,
)
del typing
_deprecations = {
# Added Feb 13, 2023:
"Mesh": (
"jax.experimental.maps.Mesh is deprecated. Use jax.sharding.Mesh.",
_deprecated_Mesh,
),
}
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr

View File

@ -33,46 +33,3 @@ from jax._src.pjit import (_get_op_sharding_from_executable,
_get_pspec_from_executable, _pjit_lower_cached,
_pjit_lower, _pjit_jaxpr,
_process_in_axis_resources)
from jax._src.sharding_impls import (
NamedSharding as _deprecated_NamedSharding,
)
from jax._src.partition_spec import (
PartitionSpec as _deprecated_PartitionSpec,
)
import typing
if typing.TYPE_CHECKING:
from jax._src.sharding_impls import NamedSharding as NamedSharding
from jax._src.partition_spec import PartitionSpec as PartitionSpec
del typing
_deprecations = {
# Added Feb 13, 2023:
"NamedSharding": (
(
"jax.experimental.pjit.NamedSharding is deprecated. Use "
"jax.sharding.NamedSharding."
),
_deprecated_NamedSharding,
),
"PartitionSpec": (
(
"jax.experimental.pjit.PartitionSpec is deprecated. Use "
"jax.sharding.PartitionSpec."
),
_deprecated_PartitionSpec,
),
"FROM_GDA": (
(
"jax.experimental.pjit.FROM_GDA has been removed. Please pass in"
" sharded jax.Arrays as input and remove the in_shardings argument"
" to pjit since pjit will infer the shardings from jax.Array."
),
None,
),
}
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr

View File

@ -114,29 +114,12 @@ from jax._src.sharding_specs import (
# Deprecations
from jax._src.mesh import Mesh as _deprecated_Mesh
from jax._src.interpreters.pxla import (
ShardedDeviceArray as _deprecated_ShardedDeviceArray,
make_sharded_device_array as _deprecated_make_sharded_device_array,
)
from jax._src.partition_spec import (
PartitionSpec as _deprecated_PartitionSpec,
)
_deprecations = {
# Added Feb 8, 2023:
"Mesh": (
"jax.interpreters.pxla.Mesh is deprecated. Use jax.sharding.Mesh.",
_deprecated_Mesh,
),
# Added Feb 8, 2023:
"PartitionSpec": (
(
"jax.interpreters.pxla.PartitionSpec is deprecated. Use "
"jax.sharding.PartitionSpec."
),
_deprecated_PartitionSpec,
),
# Added March 15, 2023:
"ShardedDeviceArray": (
(
@ -165,13 +148,11 @@ _deprecations = {
import typing
if typing.TYPE_CHECKING:
from jax._src.mesh import Mesh as Mesh
from jax._src.interpreters.pxla import (
ShardedDeviceArray as ShardedDeviceArray,
device_put as device_put,
make_sharded_device_array as make_sharded_device_array,
)
from jax._src.partition_spec import PartitionSpec as PartitionSpec
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)

View File

@ -62,17 +62,6 @@ Buffer = _deprecated_DeviceArray
_deprecations = {
# Added Feb 9, 2023:
"Device": (
"jax.interpreters.xla.Device is deprecated. Use jax.Device instead.",
_deprecated_Device,
),
"DeviceArray": (
(
"jax.interpreters.xla.DeviceArray is deprecated. Use jax.Array"
" instead."
),
_deprecated_DeviceArray,
),
"Buffer": (
(
"jax.interpreters.xla.Buffer is deprecated. Use jax.Array"
@ -102,10 +91,6 @@ del _deprecation_getattr
import typing
if typing.TYPE_CHECKING:
Device = xc.Device
from jax._src.interpreters.xla import (
DeviceArray as DeviceArray,
)
from jax._src.api import device_put as device_put
from jax._src.interpreters.xla import xla_call_p as xla_call_p
del typing