mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Deprecate jax.numpy functions alltrue, sometrue, product, cumproduct
This commit is contained in:
parent
6a89abcc76
commit
3bef6214bb
12
CHANGELOG.md
12
CHANGELOG.md
@ -8,6 +8,18 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jax 0.4.12
|
||||
|
||||
* Deprecations
|
||||
* The following APIs have been removed after a 3 month deprecation period, in
|
||||
accordance with the {ref}`api-compatibility` policy:
|
||||
* `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation
|
||||
of `numpy.alltrue` in NumPy version 1.25.0.
|
||||
* `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation
|
||||
of `numpy.sometrue` in NumPy version 1.25.0.
|
||||
* `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation
|
||||
of `numpy.product` in NumPy version 1.25.0.
|
||||
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation
|
||||
of `numpy.cumproduct` in NumPy version 1.25.0.
|
||||
|
||||
## jaxlib 0.4.12
|
||||
|
||||
## jax 0.4.11 (May 31, 2023)
|
||||
|
@ -50,7 +50,6 @@ namespace; they are listed below.
|
||||
add
|
||||
all
|
||||
allclose
|
||||
alltrue
|
||||
amax
|
||||
amin
|
||||
angle
|
||||
@ -125,7 +124,6 @@ namespace; they are listed below.
|
||||
cross
|
||||
csingle
|
||||
cumprod
|
||||
cumproduct
|
||||
cumsum
|
||||
deg2rad
|
||||
degrees
|
||||
@ -318,7 +316,6 @@ namespace; they are listed below.
|
||||
power
|
||||
printoptions
|
||||
prod
|
||||
product
|
||||
promote_types
|
||||
ptp
|
||||
put
|
||||
@ -361,7 +358,6 @@ namespace; they are listed below.
|
||||
single
|
||||
sinh
|
||||
size
|
||||
sometrue
|
||||
sort
|
||||
sort_complex
|
||||
split
|
||||
|
@ -295,11 +295,8 @@ def any(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out,
|
||||
keepdims=keepdims, where=where)
|
||||
|
||||
product = prod
|
||||
amin = min
|
||||
amax = max
|
||||
alltrue = all
|
||||
sometrue = any
|
||||
|
||||
def _axis_size(a: ArrayLike, axis: Union[int, Sequence[int]]):
|
||||
if not isinstance(axis, (tuple, list)):
|
||||
@ -683,7 +680,6 @@ def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array
|
||||
|
||||
cumsum = _make_cumulative_reduction(np.cumsum, lax.cumsum, fill_nan=False)
|
||||
cumprod = _make_cumulative_reduction(np.cumprod, lax.cumprod, fill_nan=False)
|
||||
cumproduct = cumprod
|
||||
nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum,
|
||||
fill_nan=True, fill_value=0)
|
||||
nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
|
||||
|
@ -287,7 +287,6 @@ from jax._src.numpy.polynomial import (
|
||||
)
|
||||
|
||||
from jax._src.numpy.reductions import (
|
||||
alltrue as alltrue,
|
||||
amin as amin,
|
||||
amax as amax,
|
||||
any as any,
|
||||
@ -296,7 +295,6 @@ from jax._src.numpy.reductions import (
|
||||
count_nonzero as count_nonzero,
|
||||
cumsum as cumsum,
|
||||
cumprod as cumprod,
|
||||
cumproduct as cumproduct,
|
||||
max as max,
|
||||
mean as mean,
|
||||
median as median,
|
||||
@ -315,10 +313,8 @@ from jax._src.numpy.reductions import (
|
||||
nanvar as nanvar,
|
||||
percentile as percentile,
|
||||
prod as prod,
|
||||
product as product,
|
||||
ptp as ptp,
|
||||
quantile as quantile,
|
||||
sometrue as sometrue,
|
||||
std as std,
|
||||
sum as sum,
|
||||
var as var,
|
||||
@ -441,11 +437,32 @@ _deprecations = {
|
||||
"jax.numpy.DeviceArray is deprecated. Use jax.Array.",
|
||||
ndarray,
|
||||
),
|
||||
# Added June 2, 2023:
|
||||
"alltrue": (
|
||||
"jax.numpy.alltrue is deprecated. Use jax.numpy.all",
|
||||
all,
|
||||
),
|
||||
"cumproduct": (
|
||||
"jax.numpy.cumproduct is deprecated. Use jax.numpy.cumprod",
|
||||
cumprod,
|
||||
),
|
||||
"product": (
|
||||
"jax.numpy.product is deprecated. Use jax.numpy.prod",
|
||||
prod,
|
||||
),
|
||||
"sometrue": (
|
||||
"jax.numpy.sometrue is deprecated. Use jax.numpy.any",
|
||||
any,
|
||||
),
|
||||
}
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
from jax._src.basearray import Array as DeviceArray
|
||||
alltrue = all
|
||||
cumproduct = cumprod
|
||||
product = prod
|
||||
sometrue = any
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
|
Loading…
x
Reference in New Issue
Block a user