mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
jax.numpy: add wrapper for np.insert
This commit is contained in:
parent
0851e05efc
commit
b895f530d0
@ -4415,6 +4415,57 @@ def delete(arr, obj, axis=None):
|
||||
raise ValueError(f"np.delete(arr, obj): got obj.dtype={obj.dtype}; must be integer or bool.")
|
||||
return arr[tuple(slice(None) for i in range(axis)) + (mask,)]
|
||||
|
||||
@_wraps(np.insert)
|
||||
def insert(arr, obj, values, axis=None):
|
||||
_check_arraylike("insert", arr, 0 if isinstance(obj, slice) else obj, values)
|
||||
arr = asarray(arr)
|
||||
values = asarray(values)
|
||||
|
||||
if axis is None:
|
||||
arr = ravel(arr)
|
||||
axis = 0
|
||||
axis = core.concrete_or_error(None, axis, "axis argument of jnp.insert()")
|
||||
axis = _canonicalize_axis(axis, arr.ndim)
|
||||
if isinstance(obj, slice):
|
||||
indices = arange(*obj.indices(arr.shape[axis]))
|
||||
else:
|
||||
indices = asarray(obj)
|
||||
|
||||
if indices.ndim > 1:
|
||||
raise ValueError("jnp.insert(): obj must be a slice, a one-dimensional "
|
||||
f"array, or a scalar; got {obj}")
|
||||
if not np.issubdtype(indices.dtype, np.integer):
|
||||
if indices.size == 0 and not isinstance(obj, ndarray):
|
||||
indices = indices.astype(int)
|
||||
else:
|
||||
# Note: np.insert allows boolean inputs but the behavior is deprecated.
|
||||
raise ValueError("jnp.insert(): index array must be "
|
||||
f"integer typed; got {obj}")
|
||||
values = array(values, ndmin=arr.ndim, dtype=arr.dtype, copy=False)
|
||||
|
||||
if indices.size == 1:
|
||||
index = ravel(indices)[0]
|
||||
if indices.ndim == 0:
|
||||
values = moveaxis(values, 0, axis)
|
||||
indices = full(values.shape[axis], index)
|
||||
n_input = arr.shape[axis]
|
||||
n_insert = 0 if len(indices) == 0 else _max(values.shape[axis], len(indices))
|
||||
out_shape = list(arr.shape)
|
||||
out_shape[axis] += n_insert
|
||||
out = zeros_like(arr, shape=tuple(out_shape))
|
||||
|
||||
indices = where(indices < 0, indices + n_input, indices)
|
||||
indices = clip(indices, 0, n_input)
|
||||
|
||||
values_ind = indices.at[argsort(indices)].add(arange(n_insert))
|
||||
arr_mask = ones(n_input + n_insert, dtype=bool).at[values_ind].set(False)
|
||||
arr_ind = where(arr_mask, size=n_input)[0]
|
||||
|
||||
out = out.at[(slice(None),) * axis + (values_ind,)].set(values)
|
||||
out = out.at[(slice(None),) * axis + (arr_ind,)].set(arr)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@_wraps(np.apply_along_axis)
|
||||
def apply_along_axis(func1d, axis: int, arr, *args, **kwargs):
|
||||
|
@ -185,6 +185,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
in1d as in1d,
|
||||
inf as inf,
|
||||
inner as inner,
|
||||
insert as insert,
|
||||
int16 as int16,
|
||||
int32 as int32,
|
||||
int64 as int64,
|
||||
|
@ -2152,6 +2152,60 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_axis={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), axis),
|
||||
"dtype": dtype, "shape": shape, "axis": axis}
|
||||
for shape in nonempty_nonscalar_array_shapes
|
||||
for dtype in all_dtypes
|
||||
for axis in [None] + list(range(-len(shape), len(shape)))))
|
||||
def testInsertInteger(self, shape, dtype, axis):
|
||||
x = jnp.empty(shape)
|
||||
max_ind = x.size if axis is None else x.shape[axis]
|
||||
rng = jtu.rand_default(self.rng())
|
||||
i_rng = jtu.rand_int(self.rng(), -max_ind, max_ind)
|
||||
args_maker = lambda: [rng(shape, dtype), i_rng((), np.int32), rng((), dtype)]
|
||||
np_fun = lambda *args: np.insert(*args, axis=axis)
|
||||
jnp_fun = lambda *args: jnp.insert(*args, axis=axis)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_axis={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), axis),
|
||||
"dtype": dtype, "shape": shape, "axis": axis}
|
||||
for shape in nonempty_nonscalar_array_shapes
|
||||
for dtype in all_dtypes
|
||||
for axis in [None] + list(range(-len(shape), len(shape)))))
|
||||
def testInsertSlice(self, shape, dtype, axis):
|
||||
x = jnp.empty(shape)
|
||||
max_ind = x.size if axis is None else x.shape[axis]
|
||||
rng = jtu.rand_default(self.rng())
|
||||
i_rng = jtu.rand_int(self.rng(), -max_ind, max_ind)
|
||||
slc = slice(i_rng((), jnp.int32).item(), i_rng((), jnp.int32).item())
|
||||
args_maker = lambda: [rng(shape, dtype), rng((), dtype)]
|
||||
np_fun = lambda x, val: np.insert(x, slc, val, axis=axis)
|
||||
jnp_fun = lambda x, val: jnp.insert(x, slc, val, axis=axis)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.parameters([
|
||||
[[[1, 1], [2, 2], [3, 3]], 1, 5, None],
|
||||
[[[1, 1], [2, 2], [3, 3]], 1, 5, 1],
|
||||
[[[1, 1], [2, 2], [3, 3]], 1, [1, 2, 3], 1],
|
||||
[[[1, 1], [2, 2], [3, 3]], [1], [[1],[2],[3]], 1],
|
||||
[[1, 1, 2, 2, 3, 3], [2, 2], [5, 6], None],
|
||||
[[1, 1, 2, 2, 3, 3], slice(2, 4), [5, 6], None],
|
||||
[[1, 1, 2, 2, 3, 3], [2, 2], [7.13, False], None],
|
||||
[[[0, 1, 2, 3], [4, 5, 6, 7]], (1, 3), 999, 1]
|
||||
])
|
||||
def testInsertExamples(self, arr, index, values, axis):
|
||||
# Test examples from the np.insert docstring
|
||||
args_maker = lambda: (
|
||||
np.asarray(arr), index if isinstance(index, slice) else np.array(index),
|
||||
np.asarray(values), axis)
|
||||
self._CheckAgainstNumpy(np.insert, jnp.insert, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_axis={}_out_dims={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
|
Loading…
x
Reference in New Issue
Block a user