diff --git a/CHANGELOG.md b/CHANGELOG.md index aa62a450e..7b0e516e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,21 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.11 +* Deprecations + * The following APIs have been removed after a 3 month deprecation period, in + accordance with the {ref}`api-compatibility` policy: + - `jax.experimental.PartitionSpec`: use `jax.sharding.PartitionSpec`. + - `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh` + - `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`. + - `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`. + - `jax.experimental.pjit.FROM_GDA`. Instead pass sharded `jax.Array` objects + as input and remove the optional `in_shardings` argument to `pjit`. + - `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`. + - `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh` + - `jax.interpreters.xla.Device`: use `jax.Device`. + - `jax.interpreters.xla.DeviceArray`: use `jax.Array` instead, + + ## jaxlib 0.4.11 ## jax 0.4.10 (May 11, 2023) diff --git a/jax/experimental/__init__.py b/jax/experimental/__init__.py index deef5f8f6..2addcc29b 100644 --- a/jax/experimental/__init__.py +++ b/jax/experimental/__init__.py @@ -15,8 +15,6 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/google/jax/issues/7570 -# TODO(https://github.com/google/jax/issues/13487): Remove PartitionSpec in -# 3 months from `jax.experimental.PartitionSpec`. from jax.experimental.x64_context import ( enable_x64 as enable_x64, disable_x64 as disable_x64, @@ -24,29 +22,3 @@ from jax.experimental.x64_context import ( from jax._src.callback import ( io_callback as io_callback ) - -# Deprecations - -from jax._src.partition_spec import ( - PartitionSpec as _deprecated_PartitionSpec, -) - -import typing -if typing.TYPE_CHECKING: - from jax._src.partition_spec import ( - PartitionSpec as PartitionSpec, - ) -del typing - -_deprecations = { - # Added Feb 8, 2023: - "PartitionSpec": ( - ("jax.experimental.PartitionSpec is deprecated. Use " - "jax.sharding.PartitionSpec."), - _deprecated_PartitionSpec, - ), -} - -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py index 32425aa91..afdd11dbb 100644 --- a/jax/experimental/maps.py +++ b/jax/experimental/maps.py @@ -27,28 +27,3 @@ from jax._src.mesh import ( ResourceEnv as ResourceEnv, thread_resources as thread_resources, ) - -# Deprecations - -from jax._src.maps import ( - Mesh as _deprecated_Mesh, -) - -import typing -if typing.TYPE_CHECKING: - from jax._src.maps import ( - Mesh as Mesh, - ) -del typing - -_deprecations = { - # Added Feb 13, 2023: - "Mesh": ( - "jax.experimental.maps.Mesh is deprecated. Use jax.sharding.Mesh.", - _deprecated_Mesh, - ), -} - -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index a2b8cb5cc..8b08342da 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -33,46 +33,3 @@ from jax._src.pjit import (_get_op_sharding_from_executable, _get_pspec_from_executable, _pjit_lower_cached, _pjit_lower, _pjit_jaxpr, _process_in_axis_resources) - -from jax._src.sharding_impls import ( - NamedSharding as _deprecated_NamedSharding, -) -from jax._src.partition_spec import ( - PartitionSpec as _deprecated_PartitionSpec, -) - -import typing -if typing.TYPE_CHECKING: - from jax._src.sharding_impls import NamedSharding as NamedSharding - from jax._src.partition_spec import PartitionSpec as PartitionSpec -del typing - -_deprecations = { - # Added Feb 13, 2023: - "NamedSharding": ( - ( - "jax.experimental.pjit.NamedSharding is deprecated. Use " - "jax.sharding.NamedSharding." - ), - _deprecated_NamedSharding, - ), - "PartitionSpec": ( - ( - "jax.experimental.pjit.PartitionSpec is deprecated. Use " - "jax.sharding.PartitionSpec." - ), - _deprecated_PartitionSpec, - ), - "FROM_GDA": ( - ( - "jax.experimental.pjit.FROM_GDA has been removed. Please pass in" - " sharded jax.Arrays as input and remove the in_shardings argument" - " to pjit since pjit will infer the shardings from jax.Array." - ), - None, - ), -} - -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 60968541e..54f627952 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -114,29 +114,12 @@ from jax._src.sharding_specs import ( # Deprecations -from jax._src.mesh import Mesh as _deprecated_Mesh from jax._src.interpreters.pxla import ( ShardedDeviceArray as _deprecated_ShardedDeviceArray, make_sharded_device_array as _deprecated_make_sharded_device_array, ) -from jax._src.partition_spec import ( - PartitionSpec as _deprecated_PartitionSpec, -) _deprecations = { - # Added Feb 8, 2023: - "Mesh": ( - "jax.interpreters.pxla.Mesh is deprecated. Use jax.sharding.Mesh.", - _deprecated_Mesh, - ), - # Added Feb 8, 2023: - "PartitionSpec": ( - ( - "jax.interpreters.pxla.PartitionSpec is deprecated. Use " - "jax.sharding.PartitionSpec." - ), - _deprecated_PartitionSpec, - ), # Added March 15, 2023: "ShardedDeviceArray": ( ( @@ -165,13 +148,11 @@ _deprecations = { import typing if typing.TYPE_CHECKING: - from jax._src.mesh import Mesh as Mesh from jax._src.interpreters.pxla import ( ShardedDeviceArray as ShardedDeviceArray, device_put as device_put, make_sharded_device_array as make_sharded_device_array, ) - from jax._src.partition_spec import PartitionSpec as PartitionSpec else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index e751a8318..06d0c81e0 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -62,17 +62,6 @@ Buffer = _deprecated_DeviceArray _deprecations = { # Added Feb 9, 2023: - "Device": ( - "jax.interpreters.xla.Device is deprecated. Use jax.Device instead.", - _deprecated_Device, - ), - "DeviceArray": ( - ( - "jax.interpreters.xla.DeviceArray is deprecated. Use jax.Array" - " instead." - ), - _deprecated_DeviceArray, - ), "Buffer": ( ( "jax.interpreters.xla.Buffer is deprecated. Use jax.Array" @@ -102,10 +91,6 @@ del _deprecation_getattr import typing if typing.TYPE_CHECKING: - Device = xc.Device - from jax._src.interpreters.xla import ( - DeviceArray as DeviceArray, - ) from jax._src.api import device_put as device_put from jax._src.interpreters.xla import xla_call_p as xla_call_p del typing