Remove several deprecated APIs

This commit is contained in:
Jake VanderPlas 2023-07-11 12:42:32 -07:00
parent a29d4bcd33
commit 21f6736005
5 changed files with 15 additions and 94 deletions

View File

@ -13,11 +13,6 @@ Remember to align the itemized text with the first line of an item within a list
https://jax.readthedocs.io/en/latest/deprecation.html https://jax.readthedocs.io/en/latest/deprecation.html
* JAX now requires NumPy 1.22 or newer as per * JAX now requires NumPy 1.22 or newer as per
https://jax.readthedocs.io/en/latest/deprecation.html https://jax.readthedocs.io/en/latest/deprecation.html
* `jax.interpreters.pxla.device_put` has been removed. This was deprecated in
JAX version 0.4.6: use `jax.device_put` instead.
* `jax.interpreters.pxla.make_sharded_device_array` has been removed. This was
deprecated in JAX version 0.4.6: use `jax.make_array_from_single_device_arrays`
instead.
* Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is * Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is
no longer supported, after being deprecated in JAX version 0.4.7. no longer supported, after being deprecated in JAX version 0.4.7.
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)` For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
@ -26,6 +21,17 @@ Remember to align the itemized text with the first line of an item within a list
* `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead. * `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead.
* `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead. * `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead.
* `jax.Array.split`: use {func}`jax.numpy.split` instead. * `jax.Array.split`: use {func}`jax.numpy.split` instead.
* The following APIs have been removed after previous deprecation:
* `jax.ad`: use {mod}`jax.interpreters.ad`.
* `jax.curry`: use ``curry = lambda f: partial(partial, f)``.
* `jax.partial_eval`: use {mod}`jax.interpreters.partial_eval`.
* `jax.pxla`: use {mod}`jax.interpreters.pxla`.
* `jax.xla`: use {mod}`jax.interpreters.xla`.
* `jax.ShapedArray`: use {class}`jax.core.ShapedArray`.
* `jax.interpreters.pxla.device_put`: use {func}`jax.device_put`.
* `jax.interpreters.pxla.make_sharded_device_array`: use {func}`jax.make_array_from_single_device_arrays`.
* `jax.interpreters.pxla.ShardedDeviceArray`: use {class}`jax.Array`.
* `jax.numpy.DeviceArray`: use {class}`jax.Array`.
* Breaking changes * Breaking changes
* JAX now requires ml_dtypes version 0.2.0 or newer. * JAX now requires ml_dtypes version 0.2.0 or newer.

View File

@ -82,7 +82,6 @@ from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
from jax._src.api import clear_backends as clear_backends from jax._src.api import clear_backends as clear_backends
from jax._src.api import clear_caches as clear_caches from jax._src.api import clear_caches as clear_caches
from jax._src.custom_derivatives import closure_convert as closure_convert from jax._src.custom_derivatives import closure_convert as closure_convert
from jax._src.util import curry as _deprecated_curry
from jax._src.custom_derivatives import custom_gradient as custom_gradient from jax._src.custom_derivatives import custom_gradient as custom_gradient
from jax._src.custom_derivatives import custom_jvp as custom_jvp from jax._src.custom_derivatives import custom_jvp as custom_jvp
from jax._src.custom_derivatives import custom_vjp as custom_vjp from jax._src.custom_derivatives import custom_vjp as custom_vjp
@ -95,7 +94,6 @@ from jax._src.api import device_put_replicated as device_put_replicated
from jax._src.xla_bridge import devices as devices from jax._src.xla_bridge import devices as devices
from jax._src.api import disable_jit as disable_jit from jax._src.api import disable_jit as disable_jit
from jax._src.api import eval_shape as eval_shape from jax._src.api import eval_shape as eval_shape
from jax._src.api_util import flatten_fun_nokwargs as _deprecated_flatten_fun_nokwargs
from jax._src.dtypes import float0 as float0 from jax._src.dtypes import float0 as float0
from jax._src.api import grad as grad from jax._src.api import grad as grad
from jax._src.api import hessian as hessian from jax._src.api import hessian as hessian
@ -120,19 +118,15 @@ from jax._src.xla_bridge import process_count as process_count
from jax._src.xla_bridge import process_index as process_index from jax._src.xla_bridge import process_index as process_index
from jax._src.callback import pure_callback_api as pure_callback from jax._src.callback import pure_callback_api as pure_callback
from jax._src.ad_checkpoint import checkpoint_wrapper as remat from jax._src.ad_checkpoint import checkpoint_wrapper as remat
from jax._src.core import ShapedArray as _deprecated_ShapedArray
from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
from jax._src.api import value_and_grad as value_and_grad from jax._src.api import value_and_grad as value_and_grad
from jax._src.api import vjp as vjp from jax._src.api import vjp as vjp
from jax._src.api import vmap as vmap from jax._src.api import vmap as vmap
from jax._src.api import xla_computation as xla_computation from jax._src.api import xla_computation as xla_computation
from jax.interpreters import ad as _deprecated_ad # Force import, allowing jax.interpreters.* to be used after import jax.
import jax.interpreters.batching from jax.interpreters import ad, batching, mlir, partial_eval, pxla, xla
import jax.interpreters.mlir del ad, batching, mlir, partial_eval, pxla, xla
from jax.interpreters import partial_eval as _deprecated_partial_eval
from jax.interpreters import pxla as _deprecated_pxla
from jax.interpreters import xla as _deprecated_xla
from jax._src.array import ( from jax._src.array import (
make_array_from_single_device_arrays as make_array_from_single_device_arrays, make_array_from_single_device_arrays as make_array_from_single_device_arrays,
@ -191,47 +185,11 @@ _deprecations = {
"jax.abstract_arrays is deprecated. Refer to jax.core.", "jax.abstract_arrays is deprecated. Refer to jax.core.",
_deprecated_abstract_arrays _deprecated_abstract_arrays
), ),
# Added 28 March 2023
"ShapedArray": (
"jax.ShapedArray is deprecated. Use jax.core.ShapedArray",
_deprecated_ShapedArray,
),
"ad": (
"jax.ad is deprecated. Use jax.interpreters.ad",
_deprecated_ad,
),
"partial_eval": (
"jax.partial_eval is deprecated. Use jax.interpreters.partial_eval",
_deprecated_partial_eval,
),
"pxla": (
"jax.pxla is deprecated. Use jax.interpreters.pxla",
_deprecated_pxla,
),
"xla": (
"jax.xla is deprecated. Use jax.interpreters.xla",
_deprecated_xla,
),
"curry": (
"jax.curry is deprecated. Use curry = lambda f: partial(partial, f)",
_deprecated_curry,
),
"flatten_fun_nokwargs": (
"jax.flatten_fun_nokwargs is deprecated. Use jax.api_util.flatten_fun_nokwargs.",
_deprecated_flatten_fun_nokwargs,
),
} }
import typing as _typing import typing as _typing
if _typing.TYPE_CHECKING: if _typing.TYPE_CHECKING:
from jax._src import abstract_arrays as abstract_arrays from jax._src import abstract_arrays as abstract_arrays
from jax._src.core import ShapedArray as ShapedArray
from jax.interpreters import ad as ad
from jax.interpreters import partial_eval as partial_eval
from jax.interpreters import pxla as pxla
from jax.interpreters import xla as xla
from jax._src.util import curry as curry
from jax._src.api_util import flatten_fun_nokwargs as flatten_fun_nokwargs
else: else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations) __getattr__ = _deprecation_getattr(__name__, _deprecations)

View File

@ -24,7 +24,7 @@ import itertools as it
import logging import logging
import math import math
from typing import (Any, Callable, NamedTuple, Optional, Sequence, Union, from typing import (Any, Callable, NamedTuple, Optional, Sequence, Union,
Iterable, TYPE_CHECKING, cast, TypeVar) Iterable, cast, TypeVar)
import numpy as np import numpy as np
@ -331,15 +331,6 @@ def make_sharded_device_array(
aval.shape, sharding, device_buffers) # type: ignore aval.shape, sharding, device_buffers) # type: ignore
if TYPE_CHECKING:
ShardedDeviceArray = Any
else:
class ShardedDeviceArray(object):
def __init__(self):
raise RuntimeError("ShardedDeviceArray is a backward compatibility shim "
"and cannot be instantiated.")
def _hashable_index(idx): def _hashable_index(idx):
return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx) return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx)

View File

@ -110,31 +110,3 @@ from jax._src.sharding_specs import (
sharding_spec_sharding_proto as sharding_spec_sharding_proto, sharding_spec_sharding_proto as sharding_spec_sharding_proto,
spec_to_indices as spec_to_indices, spec_to_indices as spec_to_indices,
) )
# Deprecations
from jax._src.interpreters.pxla import (
ShardedDeviceArray as _deprecated_ShardedDeviceArray,
)
_deprecations = {
# Added March 15, 2023:
"ShardedDeviceArray": (
(
"jax.interpreters.pxla.ShardedDeviceArray is deprecated. Use "
"jax.Array."
),
_deprecated_ShardedDeviceArray,
),
}
import typing
if typing.TYPE_CHECKING:
from jax._src.interpreters.pxla import (
ShardedDeviceArray as ShardedDeviceArray,
)
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing

View File

@ -425,11 +425,6 @@ del register_jax_array_methods
# Deprecations # Deprecations
_deprecations = { _deprecations = {
# Added March 14, 2023:
"DeviceArray": (
"jax.numpy.DeviceArray is deprecated. Use jax.Array.",
ndarray,
),
# Added June 2, 2023: # Added June 2, 2023:
"alltrue": ( "alltrue": (
"jax.numpy.alltrue is deprecated. Use jax.numpy.all", "jax.numpy.alltrue is deprecated. Use jax.numpy.all",
@ -451,7 +446,6 @@ _deprecations = {
import typing import typing
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from jax._src.basearray import Array as DeviceArray
alltrue = all alltrue = all
cumproduct = cumprod cumproduct = cumprod
product = prod product = prod