diff --git a/CHANGELOG.md b/CHANGELOG.md index 0633ce1a2..1774e2382 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,18 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.12 +* Deprecations + * The following APIs have been removed after a 3 month deprecation period, in + accordance with the {ref}`api-compatibility` policy: + * `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation + of `numpy.alltrue` in NumPy version 1.25.0. + * `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation + of `numpy.sometrue` in NumPy version 1.25.0. + * `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation + of `numpy.product` in NumPy version 1.25.0. + * `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation + of `numpy.cumproduct` in NumPy version 1.25.0. + ## jaxlib 0.4.12 ## jax 0.4.11 (May 31, 2023) diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index e0ac99782..467331ff8 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -50,7 +50,6 @@ namespace; they are listed below. add all allclose - alltrue amax amin angle @@ -125,7 +124,6 @@ namespace; they are listed below. cross csingle cumprod - cumproduct cumsum deg2rad degrees @@ -318,7 +316,6 @@ namespace; they are listed below. power printoptions prod - product promote_types ptp put @@ -361,7 +358,6 @@ namespace; they are listed below. single sinh size - sometrue sort sort_complex split diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index e9858fe0b..e76364457 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -295,11 +295,8 @@ def any(a: ArrayLike, axis: Axis = None, out: None = None, return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, where=where) -product = prod amin = min amax = max -alltrue = all -sometrue = any def _axis_size(a: ArrayLike, axis: Union[int, Sequence[int]]): if not isinstance(axis, (tuple, list)): @@ -683,7 +680,6 @@ def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array cumsum = _make_cumulative_reduction(np.cumsum, lax.cumsum, fill_nan=False) cumprod = _make_cumulative_reduction(np.cumprod, lax.cumprod, fill_nan=False) -cumproduct = cumprod nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum, fill_nan=True, fill_value=0) nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod, diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 91b66a0de..26e4d7ab6 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -287,7 +287,6 @@ from jax._src.numpy.polynomial import ( ) from jax._src.numpy.reductions import ( - alltrue as alltrue, amin as amin, amax as amax, any as any, @@ -296,7 +295,6 @@ from jax._src.numpy.reductions import ( count_nonzero as count_nonzero, cumsum as cumsum, cumprod as cumprod, - cumproduct as cumproduct, max as max, mean as mean, median as median, @@ -315,10 +313,8 @@ from jax._src.numpy.reductions import ( nanvar as nanvar, percentile as percentile, prod as prod, - product as product, ptp as ptp, quantile as quantile, - sometrue as sometrue, std as std, sum as sum, var as var, @@ -441,11 +437,32 @@ _deprecations = { "jax.numpy.DeviceArray is deprecated. Use jax.Array.", ndarray, ), + # Added June 2, 2023: + "alltrue": ( + "jax.numpy.alltrue is deprecated. Use jax.numpy.all", + all, + ), + "cumproduct": ( + "jax.numpy.cumproduct is deprecated. Use jax.numpy.cumprod", + cumprod, + ), + "product": ( + "jax.numpy.product is deprecated. Use jax.numpy.prod", + prod, + ), + "sometrue": ( + "jax.numpy.sometrue is deprecated. Use jax.numpy.any", + any, + ), } import typing if typing.TYPE_CHECKING: from jax._src.basearray import Array as DeviceArray + alltrue = all + cumproduct = cumprod + product = prod + sometrue = any else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations)