A jit-able version of np.repeat. (#3670)

A new keyword argument has been added to np.repeat, total_repeat_length, that can optionally be supplied to make np.repeat jit-able.
This commit is contained in:
Jonathan Godwin 2020-07-14 18:37:09 +01:00 committed by GitHub
parent 06053212b3
commit f6f97554f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 109 additions and 70 deletions

View File

@ -2505,72 +2505,61 @@ def indices(dimensions, dtype=int32, sparse=False):
return stack(output, 0) if output else array([], dtype=dtype)
def _repeat_scalar(a, repeats, axis=None):
if not isscalar(repeats):
raise NotImplementedError(
"_repeat_scalar implementation only supports scalar repeats")
if axis is None or isscalar(a) or len(shape(a)) == 0:
a = ravel(a)
axis = 0
a_shape = list(shape(a))
num_dims = len(a_shape)
if axis < 0:
axis = axis + num_dims
_TOTAL_REPEAT_LENGTH_DOC = """\
Jax adds the optional `total_repeat_length` parameter which specifies the total
number of repeat, and defaults to sum(repeats). It must be specified for repeat
to be compilable. If `sum(repeats)` is larger than the specified
`total_repeat_length` the remaining values will be discarded. In the case of
`sum(repeats)` being smaller than the specified target length, the final value
will be repeated.
"""
if axis < 0 or axis >= num_dims:
raise ValueError(
"axis {} is out of bounds for array of dimension {}".format(
axis, num_dims))
# Broadcasts to [..., X, repeats, ...] and reshapes to [..., X * repeats, ...]
broadcast_shape = list(a_shape)
broadcast_shape.insert(axis + 1, repeats)
broadcast_dims = np.concatenate((np.arange(0, axis + 1),
np.arange(axis + 2, num_dims + 1)))
a_shape[axis] *= repeats
return lax.reshape(
lax.broadcast_in_dim(a, broadcast_shape, broadcast_dims),
a_shape)
@_wraps(np.repeat)
def repeat(a, repeats, axis=None):
# use `_repeat_scalar` when possible
if isscalar(repeats):
return _repeat_scalar(a, repeats, axis)
repeats_raveled = np.ravel(np.array(repeats))
if size(repeats_raveled) == 1:
return _repeat_scalar(a, repeats_raveled.item(), axis)
if axis is None or isscalar(a):
@_wraps(np.repeat, lax_description=_TOTAL_REPEAT_LENGTH_DOC)
def repeat(a, repeats, axis=None, *, total_repeat_length=None):
if axis is None:
a = ravel(a)
axis = 0
# repeats must match the dimension along the requested axis
if repeats_raveled.size != a.shape[axis]:
raise ValueError(f"repeats shape {repeats_raveled.shape} does not match "
f"the dimension on axis {a.shape[axis]}")
repeats = array(repeats)
repeats = ravel(repeats)
# calculating the new shape
total = repeats_raveled.sum()
if ndim(a) != 0:
repeats = broadcast_to(repeats, [a.shape[axis]])
new_shape = list(a.shape)
new_shape[axis] = total
a_flattened = ravel(a)
# If total_repeat_length is not given, use a default.
if total_repeat_length is None:
total_repeat_length = sum(repeats)
# first break down raveled input array into list of chunks; each chunk is the
# unit of repeat. then tile the repeats to have same length as the list of
# chunks. finally repeat each unit x number of times according to the tiled
# repeat list.
chunks = _prod(a.shape[:axis+1])
a_splitted = split(a_flattened, chunks)
repeats_tiled = np.tile(repeats_raveled, chunks // len(repeats_raveled))
# Special case when a is a scalar.
if ndim(a) == 0:
if repeats.shape == (1,):
return full([total_repeat_length], a)
else:
raise ValueError('`repeat` with a scalar parameter `a` is only '
'implemented for scalar values of the parameter `repeats`.')
ret = array([], dtype=a.dtype)
for i, repeat in enumerate(repeats_tiled):
if repeat != 0:
ret = concatenate((ret, tile(a_splitted[i], (repeat,))))
# Special case if total_repeat_length is zero.
if total_repeat_length == 0:
return reshape(array([], dtype=a.dtype), array(a.shape).at[axis].set(0))
# If repeats is on a zero sized axis, then return the array.
if a.shape[axis] == 0:
return a
# Modify repeats from e.g. [1,2,5] -> [0,1,2]
exclusive_repeats = roll(repeats, shift=1).at[0].set(0)
# Cumsum to get indices of new number in repeated tensor, e.g. [0, 1, 3]
scatter_indices = cumsum(exclusive_repeats)
# Scatter these onto a zero buffer, e.g. [1,1,0,1,0,0,0,0]
block_split_indicators = ops.index_update(
x=zeros([total_repeat_length], dtype=int32),
idx=scatter_indices,
y=1)
# Cumsum again to get scatter indices for repeat, e.g. [0,1,1,2,2,2,2,2]
gather_indices = cumsum(block_split_indicators) - 1
return take(a, gather_indices, axis=axis)
return reshape(ret, new_shape)
@_wraps(np.tri)
def tri(N, M=None, k=0, dtype=None):

View File

@ -1319,24 +1319,58 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def _compute_total_repeat_length(self, shape, axis, repeats):
# Calculate expected size of the repeated axis.
if jnp.ndim(shape) == 0 :
return repeats
shape = jnp.array(shape)
if shape.size == 0:
return repeats
if axis is None:
axis = 0
if jnp.ndim(shape) != 0:
shape = jnp.array([jnp.product(shape)])
# Broadcasting the repeats if a scalar value.
expected_repeats = jnp.broadcast_to(jnp.ravel(repeats),
[shape[axis]])
# Total size will be num_repeats X axis length.
total_repeat_length = jnp.sum(expected_repeats)
return total_repeat_length
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape=[{}]_axis={}_repeats={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis, repeats),
{"testcase_name": "_shape=[{}]_axis={}_repeats={}_fixed_size={}".format(
jtu.format_shape_dtype_string(shape, dtype),
axis, repeats, fixed_size),
"axis": axis, "shape": shape, "dtype": dtype, "repeats": repeats,
"rng_factory": jtu.rand_default}
"rng_factory": jtu.rand_default, 'fixed_size': fixed_size}
for repeats in [0, 1, 2]
for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes)
for axis in [None] + list(range(-len(shape), max(1, len(shape))))))
def testRepeat(self, axis, shape, dtype, repeats, rng_factory):
for axis in [None] + list(range(-len(shape), max(1, len(shape))))
for fixed_size in [True, False]))
def testRepeat(self, axis, shape, dtype, repeats, rng_factory, fixed_size):
rng = rng_factory(self.rng())
np_fun = lambda arg: np.repeat(arg, repeats=repeats, axis=axis)
np_fun = _promote_like_jnp(np_fun)
jnp_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis)
if fixed_size:
total_repeat_length = self._compute_total_repeat_length(
shape, axis, repeats)
jnp_fun = lambda arg, rep: jnp.repeat(arg, repeats=rep, axis=axis,
total_repeat_length=total_repeat_length)
jnp_args_maker = lambda: [rng(shape, dtype), repeats]
clo_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis,
total_repeat_length=total_repeat_length)
clo_fun_args_maker = lambda: [rng(shape, dtype)]
self._CompileAndCheck(jnp_fun, jnp_args_maker)
self._CheckAgainstNumpy(np_fun, clo_fun, clo_fun_args_maker)
else:
# Now repeats is in a closure, so a constant.
jnp_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_ind={}_inv={}_count={}".format(
@ -1357,7 +1391,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp_fun = lambda x: jnp.unique(x, return_index, return_inverse, return_counts)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
def testIssue1233(self):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_fixed_size={}".format(fixed_size),
"fixed_size": fixed_size}
for fixed_size in [True, False]))
def testNonScalarRepeats(self, fixed_size):
'''
Following numpy test suite from `test_repeat` at
https://github.com/numpy/numpy/blob/master/numpy/core/tests/test_multiarray.py
@ -1369,18 +1407,30 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
numpy_ans = np.repeat(m, repeats, axis)
self.assertAllClose(lax_ans, numpy_ans, rtol=tol, atol=tol)
if fixed_size:
jnp_fun = lambda arg: jnp.repeat(arg, repeats = repeats, axis=axis)
# Calculate expected size of the repeated axis.
rep_length = self._compute_total_repeat_length(m.shape, axis, repeats)
jnp_fun = lambda arg, rep: jnp.repeat(
arg, repeats = rep, axis=axis, total_repeat_length=rep_length)
else:
jnp_fun = lambda arg: jnp.repeat(arg, repeats = repeats, axis=axis)
self._CompileAndCheck(jnp_fun, args_maker)
m = jnp.array([1,2,3,4,5,6])
args_maker = lambda: [m]
if fixed_size:
args_maker = lambda: [m, repeats]
else:
args_maker = lambda: [m]
for repeats in [2, [1,3,2,1,1,2], [1,3,0,1,1,2], [2], jnp.array([1,3,2,1,1,2]), jnp.array([2])]:
test_single(m, args_maker, repeats, None)
m_rect = m.reshape((2,3))
args_maker = lambda: [m_rect]
if fixed_size:
args_maker = lambda: [m_rect, repeats]
else:
args_maker = lambda: [m_rect]
for repeats in [2, [2,1], [2], jnp.array([2,1]), jnp.array([2])]:
test_single(m_rect, args_maker, repeats, axis=0)