mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add new cumulative_sum function to numpy and array_api
This commit is contained in:
parent
adbb11f9fe
commit
ceeb975735
@ -9,8 +9,9 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
## jax 0.4.27
|
## jax 0.4.27
|
||||||
|
|
||||||
* New Functionality
|
* New Functionality
|
||||||
* Added {func}`jax.numpy.unstack`, following the addition of this function in
|
* Added {func}`jax.numpy.unstack` and {func}`jax.numpy.cumulative_sum`,
|
||||||
the array API 2023 standard, soon to be adopted by NumPy.
|
following their addition in the array API 2023 standard, soon to be
|
||||||
|
adopted by NumPy.
|
||||||
|
|
||||||
* Changes
|
* Changes
|
||||||
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
|
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
|
||||||
|
@ -138,6 +138,7 @@ namespace; they are listed below.
|
|||||||
csingle
|
csingle
|
||||||
cumprod
|
cumprod
|
||||||
cumsum
|
cumsum
|
||||||
|
cumulative_sum
|
||||||
deg2rad
|
deg2rad
|
||||||
degrees
|
degrees
|
||||||
delete
|
delete
|
||||||
|
@ -26,7 +26,7 @@ import numpy as np
|
|||||||
|
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax._src import api
|
from jax._src import api
|
||||||
from jax._src import core
|
from jax._src import core, config
|
||||||
from jax._src import dtypes
|
from jax._src import dtypes
|
||||||
from jax._src.numpy import ufuncs
|
from jax._src.numpy import ufuncs
|
||||||
from jax._src.numpy.util import (
|
from jax._src.numpy.util import (
|
||||||
@ -708,6 +708,42 @@ nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum,
|
|||||||
nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
|
nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
|
||||||
fill_nan=True, fill_value=1)
|
fill_nan=True, fill_value=1)
|
||||||
|
|
||||||
|
@implements(getattr(np, 'cumulative_sum', None))
|
||||||
|
def cumulative_sum(
|
||||||
|
x: ArrayLike, /, *, axis: int | None = None,
|
||||||
|
dtype: DTypeLike | None = None,
|
||||||
|
include_initial: bool = False) -> Array:
|
||||||
|
check_arraylike("cumulative_sum", x)
|
||||||
|
x = lax_internal.asarray(x)
|
||||||
|
if x.ndim == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"The input must be non-scalar to take a cumulative sum, however a "
|
||||||
|
"scalar value or scalar array was given."
|
||||||
|
)
|
||||||
|
if axis is None:
|
||||||
|
axis = 0
|
||||||
|
if x.ndim > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"The input array has rank {x.ndim}, however axis was not set to an "
|
||||||
|
"explicit value. The axis argument is only optional for one-dimensional "
|
||||||
|
"arrays.")
|
||||||
|
|
||||||
|
axis = _canonicalize_axis(axis, x.ndim)
|
||||||
|
dtypes.check_user_dtype_supported(dtype)
|
||||||
|
kind = x.dtype.kind
|
||||||
|
if (dtype is None and kind in {'i', 'u'}
|
||||||
|
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
|
||||||
|
dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
|
||||||
|
x = x.astype(dtype=dtype or x.dtype)
|
||||||
|
out = cumsum(x, axis=axis)
|
||||||
|
if include_initial:
|
||||||
|
zeros_shape = list(x.shape)
|
||||||
|
zeros_shape[axis] = 1
|
||||||
|
out = lax_internal.concatenate(
|
||||||
|
[lax_internal.full(zeros_shape, 0, dtype=out.dtype), out],
|
||||||
|
dimension=axis)
|
||||||
|
return out
|
||||||
|
|
||||||
# Quantiles
|
# Quantiles
|
||||||
@implements(np.quantile, skip_params=['out', 'overwrite_input'])
|
@implements(np.quantile, skip_params=['out', 'overwrite_input'])
|
||||||
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
|
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
|
||||||
|
@ -204,6 +204,7 @@ from jax.experimental.array_api._sorting_functions import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from jax.experimental.array_api._statistical_functions import (
|
from jax.experimental.array_api._statistical_functions import (
|
||||||
|
cumulative_sum as cumulative_sum,
|
||||||
max as max,
|
max as max,
|
||||||
mean as mean,
|
mean as mean,
|
||||||
min as min,
|
min as min,
|
||||||
|
@ -18,6 +18,10 @@ from jax.experimental.array_api._data_type_functions import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False):
|
||||||
|
"""Calculates the cumulative sum of elements in the input array x."""
|
||||||
|
return jax.numpy.cumulative_sum(x, axis=axis, dtype=dtype, include_initial=include_initial)
|
||||||
|
|
||||||
def max(x, /, *, axis=None, keepdims=False):
|
def max(x, /, *, axis=None, keepdims=False):
|
||||||
"""Calculates the maximum value of the input array x."""
|
"""Calculates the maximum value of the input array x."""
|
||||||
return jax.numpy.max(x, axis=axis, keepdims=keepdims)
|
return jax.numpy.max(x, axis=axis, keepdims=keepdims)
|
||||||
|
@ -296,6 +296,7 @@ from jax._src.numpy.reductions import (
|
|||||||
count_nonzero as count_nonzero,
|
count_nonzero as count_nonzero,
|
||||||
cumsum as cumsum,
|
cumsum as cumsum,
|
||||||
cumprod as cumprod,
|
cumprod as cumprod,
|
||||||
|
cumulative_sum as cumulative_sum,
|
||||||
max as max,
|
max as max,
|
||||||
mean as mean,
|
mean as mean,
|
||||||
median as median,
|
median as median,
|
||||||
|
@ -241,6 +241,9 @@ def cumprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
|
|||||||
cumproduct = cumprod
|
cumproduct = cumprod
|
||||||
def cumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
|
def cumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
|
||||||
out: None = ...) -> Array: ...
|
out: None = ...) -> Array: ...
|
||||||
|
def cumulative_sum(x: ArrayLike, /, *, axis: int | None = ...,
|
||||||
|
dtype: DTypeLike | None = ...,
|
||||||
|
include_initial: bool = ...) -> Array: ...
|
||||||
|
|
||||||
def deg2rad(x: ArrayLike, /) -> Array: ...
|
def deg2rad(x: ArrayLike, /) -> Array: ...
|
||||||
degrees = rad2deg
|
degrees = rad2deg
|
||||||
|
@ -68,6 +68,7 @@ MAIN_NAMESPACE = {
|
|||||||
'copysign',
|
'copysign',
|
||||||
'cos',
|
'cos',
|
||||||
'cosh',
|
'cosh',
|
||||||
|
'cumulative_sum',
|
||||||
'divide',
|
'divide',
|
||||||
'e',
|
'e',
|
||||||
'empty',
|
'empty',
|
||||||
|
@ -770,5 +770,64 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
|||||||
self.assertAllClose(expected, actual, atol=0)
|
self.assertAllClose(expected, actual, atol=0)
|
||||||
|
|
||||||
|
|
||||||
|
@jtu.sample_product(
|
||||||
|
[dict(shape=shape, axis=axis)
|
||||||
|
for shape in all_shapes
|
||||||
|
for axis in list(
|
||||||
|
range(-len(shape), len(shape))
|
||||||
|
) + ([None] if len(shape) == 1 else [])],
|
||||||
|
dtype=all_dtypes + [None],
|
||||||
|
out_dtype=all_dtypes,
|
||||||
|
include_initial=[False, True],
|
||||||
|
)
|
||||||
|
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||||
|
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
|
||||||
|
def testCumulativeSum(self, shape, axis, dtype, out_dtype, include_initial):
|
||||||
|
rng = jtu.rand_some_zero(self.rng())
|
||||||
|
|
||||||
|
def np_mock_op(x, axis=None, dtype=None, include_initial=False):
|
||||||
|
kind = x.dtype.kind
|
||||||
|
if (dtype is None and kind in {'i', 'u'}
|
||||||
|
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
|
||||||
|
dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
|
||||||
|
axis = axis or 0
|
||||||
|
x = x.astype(dtype=dtype or x.dtype)
|
||||||
|
out = jnp.cumsum(x, axis=axis)
|
||||||
|
if include_initial:
|
||||||
|
zeros_shape = list(x.shape)
|
||||||
|
zeros_shape[axis] = 1
|
||||||
|
out = jnp.concat([jnp.zeros(zeros_shape, dtype=out.dtype), out], axis=axis)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as
|
||||||
|
# input because we rely on JAX-specific casting behavior
|
||||||
|
args_maker = lambda: [jnp.array(rng(shape, dtype))]
|
||||||
|
np_op = getattr(np, "cumulative_sum", np_mock_op)
|
||||||
|
kwargs = dict(axis=axis, dtype=out_dtype, include_initial=include_initial)
|
||||||
|
|
||||||
|
np_fun = lambda x: np_op(x, **kwargs)
|
||||||
|
jnp_fun = lambda x: jnp.cumulative_sum(x, **kwargs)
|
||||||
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||||
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
|
|
||||||
|
|
||||||
|
@jtu.sample_product(
|
||||||
|
shape=filter(lambda x: len(x) != 1, all_shapes), dtype=all_dtypes,
|
||||||
|
include_initial=[False, True])
|
||||||
|
def testCumulativeSumErrors(self, shape, dtype, include_initial):
|
||||||
|
rng = jtu.rand_some_zero(self.rng())
|
||||||
|
x = rng(shape, dtype)
|
||||||
|
rank = jnp.asarray(x).ndim
|
||||||
|
if rank == 0:
|
||||||
|
msg = r"The input must be non-scalar to take"
|
||||||
|
with self.assertRaisesRegex(ValueError, msg):
|
||||||
|
jnp.cumulative_sum(x, include_initial=include_initial)
|
||||||
|
elif rank > 1:
|
||||||
|
msg = r"The input array has rank \d*, however"
|
||||||
|
with self.assertRaisesRegex(ValueError, msg):
|
||||||
|
jnp.cumulative_sum(x, include_initial=include_initial)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user