Deprecate jax.numpy functions alltrue, sometrue, product, cumproduct

This commit is contained in:
Jake VanderPlas 2023-06-02 04:10:46 -07:00
parent 6a89abcc76
commit 3bef6214bb
4 changed files with 33 additions and 12 deletions

View File

@ -8,6 +8,18 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.12
* Deprecations
* The following APIs have been removed after a 3 month deprecation period, in
accordance with the {ref}`api-compatibility` policy:
* `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation
of `numpy.alltrue` in NumPy version 1.25.0.
* `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation
of `numpy.sometrue` in NumPy version 1.25.0.
* `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation
of `numpy.product` in NumPy version 1.25.0.
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation
of `numpy.cumproduct` in NumPy version 1.25.0.
## jaxlib 0.4.12
## jax 0.4.11 (May 31, 2023)

View File

@ -50,7 +50,6 @@ namespace; they are listed below.
add
all
allclose
alltrue
amax
amin
angle
@ -125,7 +124,6 @@ namespace; they are listed below.
cross
csingle
cumprod
cumproduct
cumsum
deg2rad
degrees
@ -318,7 +316,6 @@ namespace; they are listed below.
power
printoptions
prod
product
promote_types
ptp
put
@ -361,7 +358,6 @@ namespace; they are listed below.
single
sinh
size
sometrue
sort
sort_complex
split

View File

@ -295,11 +295,8 @@ def any(a: ArrayLike, axis: Axis = None, out: None = None,
return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out,
keepdims=keepdims, where=where)
product = prod
amin = min
amax = max
alltrue = all
sometrue = any
def _axis_size(a: ArrayLike, axis: Union[int, Sequence[int]]):
if not isinstance(axis, (tuple, list)):
@ -683,7 +680,6 @@ def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array
cumsum = _make_cumulative_reduction(np.cumsum, lax.cumsum, fill_nan=False)
cumprod = _make_cumulative_reduction(np.cumprod, lax.cumprod, fill_nan=False)
cumproduct = cumprod
nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum,
fill_nan=True, fill_value=0)
nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,

View File

@ -287,7 +287,6 @@ from jax._src.numpy.polynomial import (
)
from jax._src.numpy.reductions import (
alltrue as alltrue,
amin as amin,
amax as amax,
any as any,
@ -296,7 +295,6 @@ from jax._src.numpy.reductions import (
count_nonzero as count_nonzero,
cumsum as cumsum,
cumprod as cumprod,
cumproduct as cumproduct,
max as max,
mean as mean,
median as median,
@ -315,10 +313,8 @@ from jax._src.numpy.reductions import (
nanvar as nanvar,
percentile as percentile,
prod as prod,
product as product,
ptp as ptp,
quantile as quantile,
sometrue as sometrue,
std as std,
sum as sum,
var as var,
@ -441,11 +437,32 @@ _deprecations = {
"jax.numpy.DeviceArray is deprecated. Use jax.Array.",
ndarray,
),
# Added June 2, 2023:
"alltrue": (
"jax.numpy.alltrue is deprecated. Use jax.numpy.all",
all,
),
"cumproduct": (
"jax.numpy.cumproduct is deprecated. Use jax.numpy.cumprod",
cumprod,
),
"product": (
"jax.numpy.product is deprecated. Use jax.numpy.prod",
prod,
),
"sometrue": (
"jax.numpy.sometrue is deprecated. Use jax.numpy.any",
any,
),
}
import typing
if typing.TYPE_CHECKING:
from jax._src.basearray import Array as DeviceArray
alltrue = all
cumproduct = cumprod
product = prod
sometrue = any
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)