mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Deprecate passing NdArrays with ndim != 1 and non-arraylike inputs to jnp.trim_zeros
This commit is contained in:
parent
b904599b98
commit
2714469397
@ -12,6 +12,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
|
||||
## jax 0.4.34
|
||||
|
||||
* Deprecations
|
||||
* In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike arguments
|
||||
with `ndim != 1` are now deprecated, and in the future will result in an error.
|
||||
|
||||
* Deletion:
|
||||
* `jax.xla_computation` is deleted. It's been 3 months since it's deprecation
|
||||
in 0.4.30 JAX release.
|
||||
|
@ -132,3 +132,4 @@ register('jax-numpy-clip-args')
|
||||
register('jax-numpy-linalg-matrix_rank-tol')
|
||||
register('jax-numpy-linalg-pinv-rcond')
|
||||
register('jax-numpy-quantile-interpolation')
|
||||
register('jax-numpy-trimzeros-not-1d-array')
|
||||
|
@ -7018,7 +7018,7 @@ def diagflat(v: ArrayLike, k: int = 0) -> Array:
|
||||
return res
|
||||
|
||||
|
||||
def trim_zeros(filt, trim='fb'):
|
||||
def trim_zeros(filt: ArrayLike, trim: str ='fb') -> Array:
|
||||
"""Trim leading and/or trailing zeros of the input array.
|
||||
|
||||
JAX implementation of :func:`numpy.trim_zeros`.
|
||||
@ -7040,14 +7040,26 @@ def trim_zeros(filt, trim='fb'):
|
||||
>>> jnp.trim_zeros(x)
|
||||
Array([2, 0, 1, 4, 3], dtype=int32)
|
||||
"""
|
||||
filt = core.concrete_or_error(asarray, filt,
|
||||
"Error arose in the `filt` argument of trim_zeros()")
|
||||
nz = (filt == 0)
|
||||
# Non-array inputs are deprecated 2024-09-11
|
||||
util.check_arraylike("trim_zeros", filt, emit_warning=True)
|
||||
core.concrete_or_error(None, filt,
|
||||
"Error arose in the `filt` argument of trim_zeros()")
|
||||
filt_arr = jax.numpy.asarray(filt)
|
||||
del filt
|
||||
if filt_arr.ndim != 1:
|
||||
# Added on 2024-09-11
|
||||
if deprecations.is_accelerated("jax-numpy-trimzeros-not-1d-array"):
|
||||
raise TypeError(f"'filt' must be 1-D array, but received {filt_arr.ndim}-D array.")
|
||||
warnings.warn(
|
||||
"Passing arrays with ndim != 1 to jnp.trim_zeros() is deprecated. Currently, it "
|
||||
"works with Arrays having ndim != 1. In the future this will result in an error.",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
nz = (filt_arr == 0)
|
||||
if reductions.all(nz):
|
||||
return empty(0, _dtype(filt))
|
||||
start = argmin(nz) if 'f' in trim.lower() else 0
|
||||
end = argmin(nz[::-1]) if 'b' in trim.lower() else 0
|
||||
return filt[start:len(filt) - end]
|
||||
return empty(0, filt_arr.dtype)
|
||||
start: Array | int = argmin(nz) if 'f' in trim.lower() else 0
|
||||
end: Array | int = argmin(nz[::-1]) if 'b' in trim.lower() else 0
|
||||
return filt_arr[start:len(filt_arr) - end]
|
||||
|
||||
|
||||
def trim_zeros_tol(filt, tol, trim='fb'):
|
||||
|
@ -1478,6 +1478,12 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jnp_fun = lambda arg1: jnp.trim_zeros(arg1, trim)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
|
||||
def testTrimZerosNotOneDArray(self):
|
||||
# TODO: make this an error after the deprecation period.
|
||||
with self.assertWarnsRegex(DeprecationWarning,
|
||||
r"Passing arrays with ndim != 1 to jnp.trim_zeros\(\)"):
|
||||
jnp.trim_zeros(jnp.array([[0.0, 1.0, 0.0],[2.0, 4.5, 0.0]]))
|
||||
|
||||
@jtu.sample_product(
|
||||
rank=(1, 2),
|
||||
dtype=default_dtypes,
|
||||
|
Loading…
x
Reference in New Issue
Block a user