mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove several deprecated APIs
This commit is contained in:
parent
a29d4bcd33
commit
21f6736005
16
CHANGELOG.md
16
CHANGELOG.md
@ -13,11 +13,6 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
https://jax.readthedocs.io/en/latest/deprecation.html
|
||||
* JAX now requires NumPy 1.22 or newer as per
|
||||
https://jax.readthedocs.io/en/latest/deprecation.html
|
||||
* `jax.interpreters.pxla.device_put` has been removed. This was deprecated in
|
||||
JAX version 0.4.6: use `jax.device_put` instead.
|
||||
* `jax.interpreters.pxla.make_sharded_device_array` has been removed. This was
|
||||
deprecated in JAX version 0.4.6: use `jax.make_array_from_single_device_arrays`
|
||||
instead.
|
||||
* Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is
|
||||
no longer supported, after being deprecated in JAX version 0.4.7.
|
||||
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
|
||||
@ -26,6 +21,17 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead.
|
||||
* `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead.
|
||||
* `jax.Array.split`: use {func}`jax.numpy.split` instead.
|
||||
* The following APIs have been removed after previous deprecation:
|
||||
* `jax.ad`: use {mod}`jax.interpreters.ad`.
|
||||
* `jax.curry`: use ``curry = lambda f: partial(partial, f)``.
|
||||
* `jax.partial_eval`: use {mod}`jax.interpreters.partial_eval`.
|
||||
* `jax.pxla`: use {mod}`jax.interpreters.pxla`.
|
||||
* `jax.xla`: use {mod}`jax.interpreters.xla`.
|
||||
* `jax.ShapedArray`: use {class}`jax.core.ShapedArray`.
|
||||
* `jax.interpreters.pxla.device_put`: use {func}`jax.device_put`.
|
||||
* `jax.interpreters.pxla.make_sharded_device_array`: use {func}`jax.make_array_from_single_device_arrays`.
|
||||
* `jax.interpreters.pxla.ShardedDeviceArray`: use {class}`jax.Array`.
|
||||
* `jax.numpy.DeviceArray`: use {class}`jax.Array`.
|
||||
|
||||
* Breaking changes
|
||||
* JAX now requires ml_dtypes version 0.2.0 or newer.
|
||||
|
@ -82,7 +82,6 @@ from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
|
||||
from jax._src.api import clear_backends as clear_backends
|
||||
from jax._src.api import clear_caches as clear_caches
|
||||
from jax._src.custom_derivatives import closure_convert as closure_convert
|
||||
from jax._src.util import curry as _deprecated_curry
|
||||
from jax._src.custom_derivatives import custom_gradient as custom_gradient
|
||||
from jax._src.custom_derivatives import custom_jvp as custom_jvp
|
||||
from jax._src.custom_derivatives import custom_vjp as custom_vjp
|
||||
@ -95,7 +94,6 @@ from jax._src.api import device_put_replicated as device_put_replicated
|
||||
from jax._src.xla_bridge import devices as devices
|
||||
from jax._src.api import disable_jit as disable_jit
|
||||
from jax._src.api import eval_shape as eval_shape
|
||||
from jax._src.api_util import flatten_fun_nokwargs as _deprecated_flatten_fun_nokwargs
|
||||
from jax._src.dtypes import float0 as float0
|
||||
from jax._src.api import grad as grad
|
||||
from jax._src.api import hessian as hessian
|
||||
@ -120,19 +118,15 @@ from jax._src.xla_bridge import process_count as process_count
|
||||
from jax._src.xla_bridge import process_index as process_index
|
||||
from jax._src.callback import pure_callback_api as pure_callback
|
||||
from jax._src.ad_checkpoint import checkpoint_wrapper as remat
|
||||
from jax._src.core import ShapedArray as _deprecated_ShapedArray
|
||||
from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
|
||||
from jax._src.api import value_and_grad as value_and_grad
|
||||
from jax._src.api import vjp as vjp
|
||||
from jax._src.api import vmap as vmap
|
||||
from jax._src.api import xla_computation as xla_computation
|
||||
|
||||
from jax.interpreters import ad as _deprecated_ad
|
||||
import jax.interpreters.batching
|
||||
import jax.interpreters.mlir
|
||||
from jax.interpreters import partial_eval as _deprecated_partial_eval
|
||||
from jax.interpreters import pxla as _deprecated_pxla
|
||||
from jax.interpreters import xla as _deprecated_xla
|
||||
# Force import, allowing jax.interpreters.* to be used after import jax.
|
||||
from jax.interpreters import ad, batching, mlir, partial_eval, pxla, xla
|
||||
del ad, batching, mlir, partial_eval, pxla, xla
|
||||
|
||||
from jax._src.array import (
|
||||
make_array_from_single_device_arrays as make_array_from_single_device_arrays,
|
||||
@ -191,47 +185,11 @@ _deprecations = {
|
||||
"jax.abstract_arrays is deprecated. Refer to jax.core.",
|
||||
_deprecated_abstract_arrays
|
||||
),
|
||||
# Added 28 March 2023
|
||||
"ShapedArray": (
|
||||
"jax.ShapedArray is deprecated. Use jax.core.ShapedArray",
|
||||
_deprecated_ShapedArray,
|
||||
),
|
||||
"ad": (
|
||||
"jax.ad is deprecated. Use jax.interpreters.ad",
|
||||
_deprecated_ad,
|
||||
),
|
||||
"partial_eval": (
|
||||
"jax.partial_eval is deprecated. Use jax.interpreters.partial_eval",
|
||||
_deprecated_partial_eval,
|
||||
),
|
||||
"pxla": (
|
||||
"jax.pxla is deprecated. Use jax.interpreters.pxla",
|
||||
_deprecated_pxla,
|
||||
),
|
||||
"xla": (
|
||||
"jax.xla is deprecated. Use jax.interpreters.xla",
|
||||
_deprecated_xla,
|
||||
),
|
||||
"curry": (
|
||||
"jax.curry is deprecated. Use curry = lambda f: partial(partial, f)",
|
||||
_deprecated_curry,
|
||||
),
|
||||
"flatten_fun_nokwargs": (
|
||||
"jax.flatten_fun_nokwargs is deprecated. Use jax.api_util.flatten_fun_nokwargs.",
|
||||
_deprecated_flatten_fun_nokwargs,
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
if _typing.TYPE_CHECKING:
|
||||
from jax._src import abstract_arrays as abstract_arrays
|
||||
from jax._src.core import ShapedArray as ShapedArray
|
||||
from jax.interpreters import ad as ad
|
||||
from jax.interpreters import partial_eval as partial_eval
|
||||
from jax.interpreters import pxla as pxla
|
||||
from jax.interpreters import xla as xla
|
||||
from jax._src.util import curry as curry
|
||||
from jax._src.api_util import flatten_fun_nokwargs as flatten_fun_nokwargs
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
|
@ -24,7 +24,7 @@ import itertools as it
|
||||
import logging
|
||||
import math
|
||||
from typing import (Any, Callable, NamedTuple, Optional, Sequence, Union,
|
||||
Iterable, TYPE_CHECKING, cast, TypeVar)
|
||||
Iterable, cast, TypeVar)
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -331,15 +331,6 @@ def make_sharded_device_array(
|
||||
aval.shape, sharding, device_buffers) # type: ignore
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
ShardedDeviceArray = Any
|
||||
else:
|
||||
class ShardedDeviceArray(object):
|
||||
def __init__(self):
|
||||
raise RuntimeError("ShardedDeviceArray is a backward compatibility shim "
|
||||
"and cannot be instantiated.")
|
||||
|
||||
|
||||
def _hashable_index(idx):
|
||||
return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx)
|
||||
|
||||
|
@ -110,31 +110,3 @@ from jax._src.sharding_specs import (
|
||||
sharding_spec_sharding_proto as sharding_spec_sharding_proto,
|
||||
spec_to_indices as spec_to_indices,
|
||||
)
|
||||
|
||||
# Deprecations
|
||||
|
||||
from jax._src.interpreters.pxla import (
|
||||
ShardedDeviceArray as _deprecated_ShardedDeviceArray,
|
||||
)
|
||||
|
||||
_deprecations = {
|
||||
# Added March 15, 2023:
|
||||
"ShardedDeviceArray": (
|
||||
(
|
||||
"jax.interpreters.pxla.ShardedDeviceArray is deprecated. Use "
|
||||
"jax.Array."
|
||||
),
|
||||
_deprecated_ShardedDeviceArray,
|
||||
),
|
||||
}
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
from jax._src.interpreters.pxla import (
|
||||
ShardedDeviceArray as ShardedDeviceArray,
|
||||
)
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del typing
|
||||
|
@ -425,11 +425,6 @@ del register_jax_array_methods
|
||||
# Deprecations
|
||||
|
||||
_deprecations = {
|
||||
# Added March 14, 2023:
|
||||
"DeviceArray": (
|
||||
"jax.numpy.DeviceArray is deprecated. Use jax.Array.",
|
||||
ndarray,
|
||||
),
|
||||
# Added June 2, 2023:
|
||||
"alltrue": (
|
||||
"jax.numpy.alltrue is deprecated. Use jax.numpy.all",
|
||||
@ -451,7 +446,6 @@ _deprecations = {
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
from jax._src.basearray import Array as DeviceArray
|
||||
alltrue = all
|
||||
cumproduct = cumprod
|
||||
product = prod
|
||||
|
Loading…
x
Reference in New Issue
Block a user