[sparse] make bcoo_sum_duplicates a primitive

This commit is contained in:
Jake VanderPlas 2022-04-19 15:34:49 -07:00
parent 27d32e4a74
commit edae0ac31f
4 changed files with 215 additions and 59 deletions

View File

@ -204,6 +204,8 @@ from jax.experimental.sparse.bcoo import (
bcoo_sort_indices as bcoo_sort_indices,
bcoo_sort_indices_p as bcoo_sort_indices_p,
bcoo_spdot_general_p as bcoo_spdot_general_p,
bcoo_sum_duplicates as bcoo_sum_duplicates,
bcoo_sum_duplicates_p as bcoo_sum_duplicates_p,
bcoo_todense as bcoo_todense,
bcoo_todense_p as bcoo_todense_p,
bcoo_transpose as bcoo_transpose,

View File

@ -79,53 +79,21 @@ def _bcoo_nse(mat, n_batch=0, n_dense=0):
mask = mask.sum(list(range(n_batch, mask.ndim)))
return mask.max()
# TODO(jakevdp): add a custom autodiff rule that errors if remove_zeros=True, because
# it produces wrong values. See https://github.com/google/jax/issues/10163
def _bcoo_sum_duplicates(data, indices, shape, nse=None, remove_zeros=True):
if nse is None and isinstance(jnp.array(0), core.Tracer):
raise ValueError("When used with JIT, vmap, or another transform, sum_duplicates() "
"requires passing a non-None value for the nse argument.")
# TODO(jakevdp) this can be problematic when used with autodiff; see
# https://github.com/google/jax/issues/10163. Should this be a primitive?
# Alternatively, maybe roll this into bcoo_sum_duplicates as an optional argument.
def bcoo_eliminate_zeros(mat, nse=None):
data, indices, shape = mat.data, mat.indices, mat.shape
props = _validate_bcoo(data, indices, shape)
f = functools.partial(_bcoo_sum_duplicates_unbatched, shape=shape[props.n_batch:],
nse=nse, remove_zeros=remove_zeros)
mask = (data == 0).all(tuple(range(props.n_batch + 1, data.ndim)))
dims_to_contract = tuple(i for i, s in enumerate(indices.shape[:props.n_batch]) if s == 1)
mask = mask.all(dims_to_contract, keepdims=True)
fill_value = jnp.array(shape[props.n_batch:props.n_batch + props.n_sparse], dtype=indices.dtype)
f = lambda i, m: jnp.where(m[:, None], fill_value[None, :], i)
for _ in range(props.n_batch):
f = broadcasting_vmap(f)
data_unique, indices_unique, nse_out = f(data, indices)
if nse is None:
nse = jnp.max(nse_out)
data_unique = lax.slice_in_dim(data_unique, 0, nse, axis=props.n_batch)
indices_unique = lax.slice_in_dim(indices_unique, 0, nse, axis=props.n_batch)
return data_unique, indices_unique
def _bcoo_sum_duplicates_unbatched(data, indices, *, shape, nse, remove_zeros):
props = _validate_bcoo(data, indices, shape)
assert props.n_batch == 0
if not props.n_sparse:
nse = 1 if nse is None else nse
data_unique = jnp.zeros_like(data, shape=(nse, *data.shape[1:])).at[0].set(data.sum(0))
indices_unique = jnp.zeros_like(indices, shape=(nse, 0))
return data_unique, indices_unique, nse
fill_value = jnp.expand_dims(jnp.array(shape[:props.n_sparse], dtype=indices.dtype),
range(indices.ndim - 1))
out_of_bounds = (indices >= fill_value).any(-1, keepdims=True)
if remove_zeros:
data_all_zero = (data == 0).all(range(props.n_batch + 1, data.ndim))[:, None]
out_of_bounds = out_of_bounds | data_all_zero
indices = jnp.where(out_of_bounds, fill_value, indices)
if nse is None:
indices_unique, inv_idx, nse = _unique(
indices, axis=0, return_inverse=True, return_true_size=True,
size=props.nse, fill_value=fill_value)
nse = nse - (indices == fill_value).any()
else:
indices_unique, inv_idx = jnp.unique(
indices, axis=0, return_inverse=True,
size=nse, fill_value=fill_value)
data_shape = [indices_unique.shape[0], *data.shape[1:]]
data_unique = jnp.zeros(data_shape, data.dtype).at[inv_idx].add(data)
oob_mask = jnp.all(indices_unique == fill_value, 1)
data_unique = jnp.where(oob_mask[(...,) + props.n_dense * (None,)], 0, data_unique)
return data_unique, indices_unique, nse
f = vmap(f)
indices = f(indices, mask)
return bcoo_sum_duplicates(BCOO((data, indices), shape=shape), nse=nse)
def _unbatch_bcoo(data, indices, shape):
n_batch = _validate_bcoo(data, indices, shape).n_batch
@ -1043,8 +1011,8 @@ def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices,
assert max(rhs_contracting, default=-1) < rhs.n_sparse
out_shape = (
[s for i, s in enumerate(lhs_shape) if i not in lhs_contracting] +
[s for i, s in enumerate(rhs_shape) if i not in rhs_contracting])
*(s for i, s in enumerate(lhs_shape) if i not in lhs_contracting),
*(s for i, s in enumerate(rhs_shape) if i not in rhs_contracting))
lhs_i = lhs_indices[:, jnp.array(lhs_contracting, dtype=int)]
rhs_i = rhs_indices[:, jnp.array(rhs_contracting, dtype=int)]
@ -1070,8 +1038,9 @@ def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices,
out_indices = out_indices.at[:, :, lhs_j.shape[-1]:].set(rhs_j[None, :])
out_indices = out_indices.reshape(len(out_data), out_indices.shape[-1])
out_nse = (lhs.nse if lhs_j.shape[1] else 1) * (rhs.nse if rhs_j.shape[1] else 1)
# Note: remove_zeros=True is incompatible with autodiff.
return _bcoo_sum_duplicates(out_data, out_indices, out_shape, nse=out_nse, remove_zeros=False)
# Note: we do not eliminate zeros here, because it can cause issues with autodiff.
# See https://github.com/google/jax/issues/10163.
return _bcoo_sum_duplicates(out_data, out_indices, spinfo=BCOOInfo(shape=out_shape), nse=out_nse)
@bcoo_spdot_general_p.def_impl
def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: BCOOInfo, rhs_spinfo: BCOOInfo, dimension_numbers):
@ -1283,7 +1252,6 @@ def _bcoo_sort_indices_batching_rule(batched_args, batch_dims, *, spinfo):
out_axes = (0, None)
return (data_out, indices_out), out_axes
def _bcoo_sort_indices_jvp(primals, tangents, *, spinfo):
props = _validate_bcoo(*primals, spinfo.shape)
if props.n_sparse == 0:
@ -1311,6 +1279,155 @@ ad.primitive_jvps[bcoo_sort_indices_p] = _bcoo_sort_indices_jvp
batching.primitive_batchers[bcoo_sort_indices_p] = _bcoo_sort_indices_batching_rule
mlir.register_lowering(bcoo_sort_indices_p, _bcoo_sort_indices_mhlo)
#----------------------------------------------------------------------
# bcoo_sum_duplicates
# Utility to sum duplicate indices in a BCOO array representation.
# Primitive behind bcoo_sum_duplicates; it is bound with the (data, indices) buffers.
bcoo_sum_duplicates_p = core.Primitive("bcoo_sum_duplicates")
# The primitive produces two outputs (data, indices) rather than a single array.
bcoo_sum_duplicates_p.multiple_results = True
def bcoo_sum_duplicates(mat, nse=None):
  """Merge entries of a BCOO array that share an index by summing their data.

  The returned array has sorted indices.

  Args:
    mat : BCOO array
    nse : integer (optional). Number of specified elements of the result. A value
      must be given for this function to be usable under JIT and other JAX
      transformations. When omitted, the optimal nse is derived from the contents
      of the data and index arrays. When the given nse exceeds what is needed, the
      data and index arrays are padded with standard fill values; when it is too
      small, data elements are dropped from the result.

  Returns:
    mat_out : BCOO array with sorted indices and no duplicate indices.
  """
  summed = _bcoo_sum_duplicates(mat.data, mat.indices, spinfo=mat._info, nse=nse)
  out_data, out_indices = summed
  return BCOO((out_data, out_indices), shape=mat.shape)
def _bcoo_sum_duplicates(data, indices, *, spinfo, nse):
  """Bind bcoo_sum_duplicates_p, first forcing a non-None nse to a concrete integer."""
  static_nse = nse
  if static_nse is not None:
    # nse determines output shapes, so a traced value is rejected here.
    static_nse = core.concrete_or_error(
        operator.index, static_nse, "nse argument of bcoo_sum_duplicates.")
  return bcoo_sum_duplicates_p.bind(data, indices, spinfo=spinfo, nse=static_nse)
@bcoo_sum_duplicates_p.def_impl
def _bcoo_sum_duplicates_impl(data, indices, *, spinfo, nse):
  """Concrete implementation of the bcoo_sum_duplicates primitive.

  Args:
    data: data buffer of the BCOO array.
    indices: index buffer of the BCOO array.
    spinfo: sparse info object; only ``spinfo.shape`` is read here.
    nse: number of specified elements of the output, or None to derive it from
      the indices (1 when there are no sparse dimensions, otherwise the largest
      per-batch count reported by the unbatched helper).

  Returns:
    A ``(data_out, indices_out)`` pair in which the data of entries that map to
    the same output slot have been added together.
  """
  props = _validate_bcoo(data, indices, spinfo.shape)
  # Deduplicate the indices of each batch independently.
  f = functools.partial(_bcoo_sum_duplicates_unbatched, shape=spinfo.shape[props.n_batch:])
  for _ in range(props.n_batch):
    f = vmap(f)
  indices_out, mapping, nse_batched = f(indices)
  if nse is None:
    nse = 1 if props.n_sparse == 0 else nse_batched.max()
  indices_out = _adjust_indices_nse(indices_out, nse=nse, shape=spinfo.shape)
  if props.n_sparse == 0:
    # Without sparse dimensions every entry refers to the same element.
    data = data.sum(props.n_batch, keepdims=True)
  # The buffer is accumulated into with a scatter-add, so it must start from
  # explicit zeros instead of relying on how jnp.empty happens to initialize.
  data_out = jnp.zeros((*map(max, indices.shape[:props.n_batch], data.shape[:props.n_batch]),
                        nse, *data.shape[props.n_batch + 1:]), dtype=data.dtype)
  # mode='drop' discards contributions whose target slot lies beyond nse.
  permute = lambda d_out, m, d: d_out.at[m].add(d, mode='drop')
  for _ in range(props.n_batch):
    permute = broadcasting_vmap(permute)
  data_out = permute(data_out, mapping, data)
  return data_out, indices_out
def _adjust_indices_nse(indices, *, nse, shape):
  """Truncate or pad ``indices`` along its second-to-last axis to length ``nse``.

  Padding rows are filled with the sizes of the sparse dimensions of ``shape``.
  """
  props = _validate_bcoo_indices(indices, shape)
  if nse <= props.nse:
    return indices[..., :nse, :]
  sparse_sizes = shape[props.n_batch:props.n_batch + props.n_sparse]
  n_pad = nse - props.nse
  padding = lax.broadcast_in_dim(
      operand=jnp.array(sparse_sizes, dtype=indices.dtype),
      shape=(*indices.shape[:-2], n_pad, indices.shape[-1]),
      broadcast_dimensions=(indices.ndim - 1,),
  )
  return lax.concatenate([indices, padding], dimension=indices.ndim - 2)
def _bcoo_sum_duplicates_unbatched(indices, *, shape):
  """Deduplicate the rows of an unbatched index array.

  Returns:
    A ``(unique_indices, mapping, nse)`` triple, where ``mapping`` gives for each
    input row its position in ``unique_indices`` and ``nse`` is the count of
    unique rows, not counting the out-of-bounds fill row when one is present.
  """
  props = _validate_bcoo_indices(indices, shape)
  if not props.n_sparse:
    # No sparse dimensions: every entry collapses onto a single output slot.
    return (jnp.zeros_like(indices, shape=(1, props.n_sparse)),
            jnp.zeros(1, dtype='int32'),
            1)
  sentinel = jnp.expand_dims(jnp.array(shape[:props.n_sparse], dtype=indices.dtype), (0,))
  # Rows with any out-of-bounds coordinate are replaced wholesale by the sentinel row.
  row_is_oob = (indices >= sentinel).any(-1, keepdims=True)
  indices = jnp.where(row_is_oob, sentinel, indices)
  unique_rows, inverse, true_size = _unique(
      indices, axis=0, return_inverse=True, return_true_size=True,
      size=props.nse, fill_value=sentinel)
  has_sentinel = (indices == sentinel).any()
  return unique_rows, inverse, true_size - has_sentinel
@bcoo_sum_duplicates_p.def_abstract_eval
def _bcoo_sum_duplicates_abstract_eval(data, indices, *, spinfo, nse):
  """Abstract evaluation rule: output avals for (data, indices) given a static nse.

  Raises:
    ValueError: if ``nse`` is None, since the output shapes cannot be determined.
  """
  if nse is None:
    raise ValueError("bcoo_sum_duplicates: nse must be specified when using the function within "
                     "jit, vmap, and other transformations requiring abstract evaluation.")
  props = _validate_bcoo(data, indices, spinfo.shape)
  n_batch = props.n_batch
  # Each batch dimension of the data output is the larger of the data/indices sizes.
  batch_shape = tuple(map(max, indices.shape[:n_batch], data.shape[:n_batch]))
  data_aval = core.ShapedArray(
      (*batch_shape, nse, *data.shape[n_batch + 1:]),
      data.dtype, weak_type=data.weak_type)
  indices_aval = core.ShapedArray(
      (*indices.shape[:n_batch], nse, props.n_sparse),
      dtype=indices.dtype, weak_type=indices.weak_type)
  return data_aval, indices_aval
def _bcoo_sum_duplicates_batching_rule(batched_args, batch_dims, *, spinfo, nse):
  """Batching rule: treat the mapped axis as an extra leading batch dimension.

  Raises:
    NotImplementedError: if any batch dimension is something other than 0 or None.
  """
  data, indices = batched_args
  for bdim in batch_dims:
    if bdim not in [0, None]:
      raise NotImplementedError(f"batch_dims={batch_dims}. Only 0 and None are supported.")
  data_unbatched = batch_dims[0] is None
  indices_unbatched = batch_dims[1] is None
  if data_unbatched:
    data = data[None, ...]
  if indices_unbatched:
    indices = indices[None, ...]
  batch_size = max(data.shape[0], indices.shape[0])
  batched_spinfo = BCOOInfo(shape=(batch_size, *spinfo.shape))
  data_out, indices_out = bcoo_sum_duplicates_p.bind(
      data, indices, spinfo=batched_spinfo, nse=nse)
  # Note: if data is unbatched on input, it will be batched on output.
  # However, if indices are unbatched on input, they will be unbatched on output.
  if indices_unbatched:
    return (data_out, indices_out[0]), (0, None)
  return (data_out, indices_out), (0, 0)
def _bcoo_sum_duplicates_jvp(primals, tangents, *, spinfo, nse):
  """JVP rule for bcoo_sum_duplicates.

  The operation is linear in ``data``, so the data tangent is pushed through the
  same index mapping as the primal data. The indices carry no tangent.

  Args:
    primals: the ``(data, indices)`` pair.
    tangents: the matching tangents; only the data tangent is used, and it may be
      a symbolic ``ad.Zero``.
    spinfo: sparse info object; only ``spinfo.shape`` is read here.
    nse: number of specified elements of the output, or None to derive it from
      concrete indices with the same rule the primitive implementation uses.

  Returns:
    ``((data_out, indices_out), (data_dot_out, indices_dot_out))``.

  Raises:
    ValueError: if ``nse`` is None and cannot be derived as a concrete integer.
  """
  props = _validate_bcoo(*primals, spinfo.shape)

  data, indices = primals
  data_dot, _ = tangents
  has_tangent = type(data_dot) is not ad.Zero

  f = functools.partial(_bcoo_sum_duplicates_unbatched, shape=spinfo.shape[props.n_batch:])
  for _ in range(props.n_batch):
    f = broadcasting_vmap(f)
  indices_out, mapping, nse_batched = f(indices)
  if nse is None:
    # Same default as the primitive implementation, so that the primal output of
    # the JVP has the same shape as an undifferentiated call.
    nse = 1 if props.n_sparse == 0 else jnp.max(nse_batched)
  try:
    nse = core.concrete_or_error(operator.index, nse, "nse argument of bcoo_sum_duplicates.")
  except core.ConcretizationTypeError as err:
    raise ValueError("bcoo_sum_duplicates: nse must be specified when using the function within "
                     "jit, vmap, and other transformations requiring abstract evaluation.") from err
  indices_out = _adjust_indices_nse(indices_out, nse=nse, shape=spinfo.shape)
  if props.n_sparse == 0:
    data = data.sum(props.n_batch, keepdims=True)
    # A symbolic zero tangent has no array methods; only reduce a real tangent.
    if has_tangent:
      data_dot = data_dot.sum(props.n_batch, keepdims=True)
  # Accumulation target for the scatter-adds below; it must start from zeros.
  zeros = jnp.zeros((*map(max, indices.shape[:props.n_batch], data.shape[:props.n_batch]),
                     nse, *data.shape[props.n_batch + 1:]), dtype=data.dtype)
  # mode='drop' discards contributions whose target slot lies beyond nse.
  permute = lambda d_out, m, d: d_out.at[m].add(d, mode='drop')
  for _ in range(props.n_batch):
    permute = broadcasting_vmap(permute)
  data_out = permute(zeros, mapping, data)
  indices_dot_out = ad.Zero.from_value(indices_out)
  if has_tangent:
    data_dot_out = permute(zeros, mapping, data_dot)
  else:
    data_dot_out = ad.Zero.from_value(data_out)
  return (data_out, indices_out), (data_dot_out, indices_dot_out)
# The lowering is derived from the Python implementation via mlir.lower_fun.
_bcoo_sum_duplicates_mhlo = mlir.lower_fun(
_bcoo_sum_duplicates_impl, multiple_results=True)
# Register the JVP, batching and lowering rules for the primitive.
ad.primitive_jvps[bcoo_sum_duplicates_p] = _bcoo_sum_duplicates_jvp
batching.primitive_batchers[bcoo_sum_duplicates_p] = _bcoo_sum_duplicates_batching_rule
mlir.register_lowering(bcoo_sum_duplicates_p, _bcoo_sum_duplicates_mhlo)
#----------------------------------------------------------------------
# BCOO functions that maybe should be primitives?
@ -1712,9 +1829,10 @@ class BCOO(JAXSparse):
will remain among the specified elements. Note: remove_zeros=True is incompatible
with autodiff.
"""
data, indices = _bcoo_sum_duplicates(self.data, self.indices, self.shape,
nse=nse, remove_zeros=remove_zeros)
return BCOO((data, indices), shape=self.shape)
if remove_zeros:
return bcoo_eliminate_zeros(self, nse=nse)
else:
return bcoo_sum_duplicates(self, nse=nse)
def sort_indices(self):
"""Return a copy of the matrix with indices sorted."""

View File

@ -80,4 +80,4 @@ def random_bcoo(key, shape, *, dtype=jnp.float_, indices_dtype=jnp.int_,
data = generator(data_key, shape=data_shape, dtype=dtype, **kwds)
indices = _indices(index_keys).reshape(indices_shape).astype(indices_dtype)
mat = sparse.BCOO((data, indices), shape=shape)
return mat.sum_duplicates() if sorted_indices else mat
return mat.sort_indices() if sorted_indices else mat

View File

@ -1539,11 +1539,13 @@ class BCOOTest(jtu.JaxTestCase):
for nse in [None, np.prod(shape) - 1]
for remove_zeros in [True, False]))
def test_bcoo_sum_duplicates(self, shape, dtype, n_batch, n_dense, nse, remove_zeros):
rng = self.rng()
# Create a matrix with duplicate indices
rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
M = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense)
for i, s in enumerate(shape[n_batch:len(shape) - n_dense]):
M.indices = M.indices.at[..., i].set(rng.randint(0, s, size=M.nse))
new_indices = jnp.concatenate([M.indices, M.indices], axis=n_batch)
new_data = jnp.concatenate([M.data, M.data], axis=n_batch)
M = sparse.BCOO((new_data, new_indices), shape=M.shape)
dedupe = partial(M.sum_duplicates, nse=nse, remove_zeros=remove_zeros)
jit_dedupe = jax.jit(dedupe)
@ -1553,13 +1555,48 @@ class BCOOTest(jtu.JaxTestCase):
self.assertEqual(M_dedup.nse, nse)
if not nse:
with self.assertRaisesRegex(ValueError, ".*nse argument"):
with self.assertRaisesRegex(ValueError, ".*nse must be specified.*"):
jit_dedupe()
else:
M_dedup = jit_dedupe()
self.assertAllClose(M.todense(), M_dedup.todense())
self.assertEqual(M_dedup.nse, nse)
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": "_{}_nbatch={}_ndense={}_nse={}".format(
    jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense, nse),
   "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense, "nse": nse}
  for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
  for dtype in jtu.dtypes.floating
  for n_batch in range(len(shape) + 1)
  for n_dense in range(len(shape) + 1 - n_batch)
  for nse in [None, 5, np.prod(shape) - 1]
))
def test_bcoo_sum_duplicates_ad(self, shape, dtype, n_batch, n_dense, nse):
  """Check that forward- and reverse-mode Jacobians of sum_duplicates agree."""
  # Create a matrix with duplicate indices
  rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
  M = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense)
  new_indices = jnp.concatenate([M.indices, M.indices], axis=n_batch)
  new_data = jnp.concatenate([M.data, M.data], axis=n_batch)
  M = sparse.BCOO((new_data, new_indices), shape=M.shape)

  # TODO(jakevdp) address this corner case.
  if M.nse == 0:
    self.skipTest("known failure for nse=0")

  # nse only ever takes the values None, 5 or np.prod(shape) - 1 from the
  # parametrization above, so it is passed through unchanged.
  def dedupe(data, nse=nse):
    mat = sparse.BCOO((data, M.indices), shape=M.shape)
    mat_dedup = mat.sum_duplicates(nse=nse, remove_zeros=False)
    return mat_dedup.data

  data_dot_fwd = jax.jacfwd(dedupe)(M.data)
  data_dot_rev = jax.jacrev(dedupe)(M.data)

  self.assertAllClose(data_dot_fwd, data_dot_rev)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
@ -1598,7 +1635,6 @@ class BCOOTest(jtu.JaxTestCase):
def sort_indices(data):
return sparse.BCOO((data, M.indices), shape=M.shape).sort_indices().data
# Forward-mode
data_dot_fwd = jax.jacfwd(sort_indices)(M.data)
data_dot_rev = jax.jacrev(sort_indices)(M.data)