mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
A jit-able version of jnp.repeat. (#3670)
A new optional keyword-only argument, total_repeat_length, has been added to jnp.repeat. Supplying it fixes the output length ahead of time, which makes jnp.repeat jit-able.
This commit is contained in:
parent
06053212b3
commit
f6f97554f9
@ -2505,72 +2505,61 @@ def indices(dimensions, dtype=int32, sparse=False):
|
||||
return stack(output, 0) if output else array([], dtype=dtype)
|
||||
|
||||
|
||||
def _repeat_scalar(a, repeats, axis=None):
|
||||
if not isscalar(repeats):
|
||||
raise NotImplementedError(
|
||||
"_repeat_scalar implementation only supports scalar repeats")
|
||||
if axis is None or isscalar(a) or len(shape(a)) == 0:
|
||||
a = ravel(a)
|
||||
axis = 0
|
||||
a_shape = list(shape(a))
|
||||
num_dims = len(a_shape)
|
||||
if axis < 0:
|
||||
axis = axis + num_dims
|
||||
# Extra documentation appended to the NumPy docstring of `repeat` (passed as
# `lax_description` to `_wraps`). It describes the JAX-only
# `total_repeat_length` keyword argument.
_TOTAL_REPEAT_LENGTH_DOC = """\
Jax adds the optional `total_repeat_length` parameter which specifies the total
number of repeats, and defaults to `sum(repeats)`. It must be specified for
repeat to be compilable. If `sum(repeats)` is larger than the specified
`total_repeat_length` the remaining values will be discarded. In the case of
`sum(repeats)` being smaller than the specified target length, the final value
will be repeated.
"""
|
||||
|
||||
if axis < 0 or axis >= num_dims:
|
||||
raise ValueError(
|
||||
"axis {} is out of bounds for array of dimension {}".format(
|
||||
axis, num_dims))
|
||||
|
||||
# Broadcasts to [..., X, repeats, ...] and reshapes to [..., X * repeats, ...]
|
||||
broadcast_shape = list(a_shape)
|
||||
broadcast_shape.insert(axis + 1, repeats)
|
||||
broadcast_dims = np.concatenate((np.arange(0, axis + 1),
|
||||
np.arange(axis + 2, num_dims + 1)))
|
||||
a_shape[axis] *= repeats
|
||||
return lax.reshape(
|
||||
lax.broadcast_in_dim(a, broadcast_shape, broadcast_dims),
|
||||
a_shape)
|
||||
|
||||
@_wraps(np.repeat)
|
||||
def repeat(a, repeats, axis=None):
|
||||
# use `_repeat_scalar` when possible
|
||||
if isscalar(repeats):
|
||||
return _repeat_scalar(a, repeats, axis)
|
||||
repeats_raveled = np.ravel(np.array(repeats))
|
||||
if size(repeats_raveled) == 1:
|
||||
return _repeat_scalar(a, repeats_raveled.item(), axis)
|
||||
|
||||
if axis is None or isscalar(a):
|
||||
@_wraps(np.repeat, lax_description=_TOTAL_REPEAT_LENGTH_DOC)
|
||||
def repeat(a, repeats, axis=None, *, total_repeat_length=None):
|
||||
if axis is None:
|
||||
a = ravel(a)
|
||||
axis = 0
|
||||
|
||||
# repeats must match the dimension along the requested axis
|
||||
if repeats_raveled.size != a.shape[axis]:
|
||||
raise ValueError(f"repeats shape {repeats_raveled.shape} does not match "
|
||||
f"the dimension on axis {a.shape[axis]}")
|
||||
repeats = array(repeats)
|
||||
repeats = ravel(repeats)
|
||||
|
||||
# calculating the new shape
|
||||
total = repeats_raveled.sum()
|
||||
if ndim(a) != 0:
|
||||
repeats = broadcast_to(repeats, [a.shape[axis]])
|
||||
|
||||
new_shape = list(a.shape)
|
||||
new_shape[axis] = total
|
||||
a_flattened = ravel(a)
|
||||
# If total_repeat_length is not given, use a default.
|
||||
if total_repeat_length is None:
|
||||
total_repeat_length = sum(repeats)
|
||||
|
||||
# first break down raveled input array into list of chunks; each chunk is the
|
||||
# unit of repeat. then tile the repeats to have same length as the list of
|
||||
# chunks. finally repeat each unit x number of times according to the tiled
|
||||
# repeat list.
|
||||
chunks = _prod(a.shape[:axis+1])
|
||||
a_splitted = split(a_flattened, chunks)
|
||||
repeats_tiled = np.tile(repeats_raveled, chunks // len(repeats_raveled))
|
||||
# Special case when a is a scalar.
|
||||
if ndim(a) == 0:
|
||||
if repeats.shape == (1,):
|
||||
return full([total_repeat_length], a)
|
||||
else:
|
||||
raise ValueError('`repeat` with a scalar parameter `a` is only '
|
||||
'implemented for scalar values of the parameter `repeats`.')
|
||||
|
||||
ret = array([], dtype=a.dtype)
|
||||
for i, repeat in enumerate(repeats_tiled):
|
||||
if repeat != 0:
|
||||
ret = concatenate((ret, tile(a_splitted[i], (repeat,))))
|
||||
# Special case if total_repeat_length is zero.
|
||||
if total_repeat_length == 0:
|
||||
return reshape(array([], dtype=a.dtype), array(a.shape).at[axis].set(0))
|
||||
|
||||
# If repeats is on a zero sized axis, then return the array.
|
||||
if a.shape[axis] == 0:
|
||||
return a
|
||||
|
||||
# Modify repeats from e.g. [1,2,5] -> [0,1,2]
|
||||
exclusive_repeats = roll(repeats, shift=1).at[0].set(0)
|
||||
# Cumsum to get indices of new number in repeated tensor, e.g. [0, 1, 3]
|
||||
scatter_indices = cumsum(exclusive_repeats)
|
||||
# Scatter these onto a zero buffer, e.g. [1,1,0,1,0,0,0,0]
|
||||
block_split_indicators = ops.index_update(
|
||||
x=zeros([total_repeat_length], dtype=int32),
|
||||
idx=scatter_indices,
|
||||
y=1)
|
||||
# Cumsum again to get scatter indices for repeat, e.g. [0,1,1,2,2,2,2,2]
|
||||
gather_indices = cumsum(block_split_indicators) - 1
|
||||
return take(a, gather_indices, axis=axis)
|
||||
|
||||
return reshape(ret, new_shape)
|
||||
|
||||
@_wraps(np.tri)
|
||||
def tri(N, M=None, k=0, dtype=None):
|
||||
|
@ -1319,24 +1319,58 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
|
||||
def _compute_total_repeat_length(self, shape, axis, repeats):
  """Return the expected length of the repeated axis for a `repeat` call.

  Args:
    shape: shape of the array being repeated. A scalar value or an empty
      sequence is treated as a zero-dimensional operand.
    axis: axis along which the repetition happens, or None, in which case
      the length is computed over the product of all dimensions.
    repeats: a scalar, or a list/array with one repeat count per element
      along `axis`.

  Returns:
    `repeats` unchanged when `shape` is a scalar or empty; otherwise the sum
    of `repeats` after broadcasting it to the length of the selected axis.
  """
  # A scalar `shape` carries no axis to repeat along.
  if jnp.ndim(shape) == 0:
    return repeats
  shape = jnp.array(shape)
  # An empty shape likewise has no axis, so `repeats` is passed through.
  if shape.size == 0:
    return repeats
  if axis is None:
    # No axis given: use the total element count. `jnp.prod` replaces the
    # `jnp.product` alias, which has been removed from JAX.
    axis_length = int(jnp.prod(shape))
  else:
    axis_length = int(shape[axis])
  # Convert explicitly so Python lists of repeat counts are accepted, then
  # broadcast a single repeat count to one entry per element along the axis.
  expected_repeats = jnp.broadcast_to(
      jnp.ravel(jnp.asarray(repeats)), [axis_length])
  # Total size is the sum of the per-element repeat counts.
  return jnp.sum(expected_repeats)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape=[{}]_axis={}_repeats={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), axis, repeats),
|
||||
{"testcase_name": "_shape=[{}]_axis={}_repeats={}_fixed_size={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
axis, repeats, fixed_size),
|
||||
"axis": axis, "shape": shape, "dtype": dtype, "repeats": repeats,
|
||||
"rng_factory": jtu.rand_default}
|
||||
"rng_factory": jtu.rand_default, 'fixed_size': fixed_size}
|
||||
for repeats in [0, 1, 2]
|
||||
for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes)
|
||||
for axis in [None] + list(range(-len(shape), max(1, len(shape))))))
|
||||
def testRepeat(self, axis, shape, dtype, repeats, rng_factory):
|
||||
for axis in [None] + list(range(-len(shape), max(1, len(shape))))
|
||||
for fixed_size in [True, False]))
|
||||
def testRepeat(self, axis, shape, dtype, repeats, rng_factory, fixed_size):
|
||||
rng = rng_factory(self.rng())
|
||||
np_fun = lambda arg: np.repeat(arg, repeats=repeats, axis=axis)
|
||||
np_fun = _promote_like_jnp(np_fun)
|
||||
jnp_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis)
|
||||
if fixed_size:
|
||||
total_repeat_length = self._compute_total_repeat_length(
|
||||
shape, axis, repeats)
|
||||
jnp_fun = lambda arg, rep: jnp.repeat(arg, repeats=rep, axis=axis,
|
||||
total_repeat_length=total_repeat_length)
|
||||
jnp_args_maker = lambda: [rng(shape, dtype), repeats]
|
||||
clo_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis,
|
||||
total_repeat_length=total_repeat_length)
|
||||
clo_fun_args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CompileAndCheck(jnp_fun, jnp_args_maker)
|
||||
self._CheckAgainstNumpy(np_fun, clo_fun, clo_fun_args_maker)
|
||||
else:
|
||||
# Now repeats is in a closure, so a constant.
|
||||
jnp_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_ind={}_inv={}_count={}".format(
|
||||
@ -1357,7 +1391,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jnp_fun = lambda x: jnp.unique(x, return_index, return_inverse, return_counts)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
|
||||
def testIssue1233(self):
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_fixed_size={}".format(fixed_size),
|
||||
"fixed_size": fixed_size}
|
||||
for fixed_size in [True, False]))
|
||||
def testNonScalarRepeats(self, fixed_size):
|
||||
'''
|
||||
Following numpy test suite from `test_repeat` at
|
||||
https://github.com/numpy/numpy/blob/master/numpy/core/tests/test_multiarray.py
|
||||
@ -1369,18 +1407,30 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
numpy_ans = np.repeat(m, repeats, axis)
|
||||
|
||||
self.assertAllClose(lax_ans, numpy_ans, rtol=tol, atol=tol)
|
||||
if fixed_size:
|
||||
|
||||
jnp_fun = lambda arg: jnp.repeat(arg, repeats = repeats, axis=axis)
|
||||
# Calculate expected size of the repeated axis.
|
||||
rep_length = self._compute_total_repeat_length(m.shape, axis, repeats)
|
||||
jnp_fun = lambda arg, rep: jnp.repeat(
|
||||
arg, repeats = rep, axis=axis, total_repeat_length=rep_length)
|
||||
else:
|
||||
jnp_fun = lambda arg: jnp.repeat(arg, repeats = repeats, axis=axis)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
m = jnp.array([1,2,3,4,5,6])
|
||||
args_maker = lambda: [m]
|
||||
if fixed_size:
|
||||
args_maker = lambda: [m, repeats]
|
||||
else:
|
||||
args_maker = lambda: [m]
|
||||
|
||||
for repeats in [2, [1,3,2,1,1,2], [1,3,0,1,1,2], [2], jnp.array([1,3,2,1,1,2]), jnp.array([2])]:
|
||||
test_single(m, args_maker, repeats, None)
|
||||
|
||||
m_rect = m.reshape((2,3))
|
||||
args_maker = lambda: [m_rect]
|
||||
if fixed_size:
|
||||
args_maker = lambda: [m_rect, repeats]
|
||||
else:
|
||||
args_maker = lambda: [m_rect]
|
||||
|
||||
for repeats in [2, [2,1], [2], jnp.array([2,1]), jnp.array([2])]:
|
||||
test_single(m_rect, args_maker, repeats, axis=0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user