mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sparse] make bcoo_sum_duplicates a primitive
This commit is contained in:
parent
27d32e4a74
commit
edae0ac31f
@ -204,6 +204,8 @@ from jax.experimental.sparse.bcoo import (
|
||||
bcoo_sort_indices as bcoo_sort_indices,
|
||||
bcoo_sort_indices_p as bcoo_sort_indices_p,
|
||||
bcoo_spdot_general_p as bcoo_spdot_general_p,
|
||||
bcoo_sum_duplicates as bcoo_sum_duplicates,
|
||||
bcoo_sum_duplicates_p as bcoo_sum_duplicates_p,
|
||||
bcoo_todense as bcoo_todense,
|
||||
bcoo_todense_p as bcoo_todense_p,
|
||||
bcoo_transpose as bcoo_transpose,
|
||||
|
@ -79,53 +79,21 @@ def _bcoo_nse(mat, n_batch=0, n_dense=0):
|
||||
mask = mask.sum(list(range(n_batch, mask.ndim)))
|
||||
return mask.max()
|
||||
|
||||
# TODO(jakevdp): add a custom autodiff rule that errors if remove_zeros=True, because
|
||||
# it produces wrong values. See https://github.com/google/jax/issues/10163
|
||||
def _bcoo_sum_duplicates(data, indices, shape, nse=None, remove_zeros=True):
|
||||
if nse is None and isinstance(jnp.array(0), core.Tracer):
|
||||
raise ValueError("When used with JIT, vmap, or another transform, sum_duplicates() "
|
||||
"requires passing a non-None value for the nse argument.")
|
||||
# TODO(jakevdp) this can be problematic when used with autodiff; see
|
||||
# https://github.com/google/jax/issues/10163. Should this be a primitive?
|
||||
# Alternatively, maybe roll this into bcoo_sum_duplicates as an optional argument.
|
||||
def bcoo_eliminate_zeros(mat, nse=None):
|
||||
data, indices, shape = mat.data, mat.indices, mat.shape
|
||||
props = _validate_bcoo(data, indices, shape)
|
||||
f = functools.partial(_bcoo_sum_duplicates_unbatched, shape=shape[props.n_batch:],
|
||||
nse=nse, remove_zeros=remove_zeros)
|
||||
mask = (data == 0).all(tuple(range(props.n_batch + 1, data.ndim)))
|
||||
dims_to_contract = tuple(i for i, s in enumerate(indices.shape[:props.n_batch]) if s == 1)
|
||||
mask = mask.all(dims_to_contract, keepdims=True)
|
||||
fill_value = jnp.array(shape[props.n_batch:props.n_batch + props.n_sparse], dtype=indices.dtype)
|
||||
f = lambda i, m: jnp.where(m[:, None], fill_value[None, :], i)
|
||||
for _ in range(props.n_batch):
|
||||
f = broadcasting_vmap(f)
|
||||
data_unique, indices_unique, nse_out = f(data, indices)
|
||||
if nse is None:
|
||||
nse = jnp.max(nse_out)
|
||||
data_unique = lax.slice_in_dim(data_unique, 0, nse, axis=props.n_batch)
|
||||
indices_unique = lax.slice_in_dim(indices_unique, 0, nse, axis=props.n_batch)
|
||||
return data_unique, indices_unique
|
||||
|
||||
def _bcoo_sum_duplicates_unbatched(data, indices, *, shape, nse, remove_zeros):
|
||||
props = _validate_bcoo(data, indices, shape)
|
||||
assert props.n_batch == 0
|
||||
if not props.n_sparse:
|
||||
nse = 1 if nse is None else nse
|
||||
data_unique = jnp.zeros_like(data, shape=(nse, *data.shape[1:])).at[0].set(data.sum(0))
|
||||
indices_unique = jnp.zeros_like(indices, shape=(nse, 0))
|
||||
return data_unique, indices_unique, nse
|
||||
fill_value = jnp.expand_dims(jnp.array(shape[:props.n_sparse], dtype=indices.dtype),
|
||||
range(indices.ndim - 1))
|
||||
out_of_bounds = (indices >= fill_value).any(-1, keepdims=True)
|
||||
if remove_zeros:
|
||||
data_all_zero = (data == 0).all(range(props.n_batch + 1, data.ndim))[:, None]
|
||||
out_of_bounds = out_of_bounds | data_all_zero
|
||||
indices = jnp.where(out_of_bounds, fill_value, indices)
|
||||
if nse is None:
|
||||
indices_unique, inv_idx, nse = _unique(
|
||||
indices, axis=0, return_inverse=True, return_true_size=True,
|
||||
size=props.nse, fill_value=fill_value)
|
||||
nse = nse - (indices == fill_value).any()
|
||||
else:
|
||||
indices_unique, inv_idx = jnp.unique(
|
||||
indices, axis=0, return_inverse=True,
|
||||
size=nse, fill_value=fill_value)
|
||||
data_shape = [indices_unique.shape[0], *data.shape[1:]]
|
||||
data_unique = jnp.zeros(data_shape, data.dtype).at[inv_idx].add(data)
|
||||
oob_mask = jnp.all(indices_unique == fill_value, 1)
|
||||
data_unique = jnp.where(oob_mask[(...,) + props.n_dense * (None,)], 0, data_unique)
|
||||
return data_unique, indices_unique, nse
|
||||
f = vmap(f)
|
||||
indices = f(indices, mask)
|
||||
return bcoo_sum_duplicates(BCOO((data, indices), shape=shape), nse=nse)
|
||||
|
||||
def _unbatch_bcoo(data, indices, shape):
|
||||
n_batch = _validate_bcoo(data, indices, shape).n_batch
|
||||
@ -1043,8 +1011,8 @@ def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices,
|
||||
assert max(rhs_contracting, default=-1) < rhs.n_sparse
|
||||
|
||||
out_shape = (
|
||||
[s for i, s in enumerate(lhs_shape) if i not in lhs_contracting] +
|
||||
[s for i, s in enumerate(rhs_shape) if i not in rhs_contracting])
|
||||
*(s for i, s in enumerate(lhs_shape) if i not in lhs_contracting),
|
||||
*(s for i, s in enumerate(rhs_shape) if i not in rhs_contracting))
|
||||
|
||||
lhs_i = lhs_indices[:, jnp.array(lhs_contracting, dtype=int)]
|
||||
rhs_i = rhs_indices[:, jnp.array(rhs_contracting, dtype=int)]
|
||||
@ -1070,8 +1038,9 @@ def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices,
|
||||
out_indices = out_indices.at[:, :, lhs_j.shape[-1]:].set(rhs_j[None, :])
|
||||
out_indices = out_indices.reshape(len(out_data), out_indices.shape[-1])
|
||||
out_nse = (lhs.nse if lhs_j.shape[1] else 1) * (rhs.nse if rhs_j.shape[1] else 1)
|
||||
# Note: remove_zeros=True is incompatible with autodiff.
|
||||
return _bcoo_sum_duplicates(out_data, out_indices, out_shape, nse=out_nse, remove_zeros=False)
|
||||
# Note: we do not eliminate zeros here, because it can cause issues with autodiff.
|
||||
# See https://github.com/google/jax/issues/10163.
|
||||
return _bcoo_sum_duplicates(out_data, out_indices, spinfo=BCOOInfo(shape=out_shape), nse=out_nse)
|
||||
|
||||
@bcoo_spdot_general_p.def_impl
|
||||
def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: BCOOInfo, rhs_spinfo: BCOOInfo, dimension_numbers):
|
||||
@ -1283,7 +1252,6 @@ def _bcoo_sort_indices_batching_rule(batched_args, batch_dims, *, spinfo):
|
||||
out_axes = (0, None)
|
||||
return (data_out, indices_out), out_axes
|
||||
|
||||
|
||||
def _bcoo_sort_indices_jvp(primals, tangents, *, spinfo):
|
||||
props = _validate_bcoo(*primals, spinfo.shape)
|
||||
if props.n_sparse == 0:
|
||||
@ -1311,6 +1279,155 @@ ad.primitive_jvps[bcoo_sort_indices_p] = _bcoo_sort_indices_jvp
|
||||
batching.primitive_batchers[bcoo_sort_indices_p] = _bcoo_sort_indices_batching_rule
|
||||
mlir.register_lowering(bcoo_sort_indices_p, _bcoo_sort_indices_mhlo)
|
||||
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
# bcoo_sum_duplicates
|
||||
# Utility to sum duplicate indices in a BCOO array representation.
|
||||
|
||||
# Primitive backing bcoo_sum_duplicates; it is flagged as producing multiple results.
bcoo_sum_duplicates_p = core.Primitive("bcoo_sum_duplicates")
|
||||
bcoo_sum_duplicates_p.multiple_results = True
|
||||
|
||||
def bcoo_sum_duplicates(mat, nse=None):
  """Sum entries of a BCOO array that share an index; the result has sorted indices.

  Args:
    mat : BCOO array
    nse : integer (optional). Number of specified elements in the output matrix. A value
      must be given for bcoo_sum_duplicates to be usable under JIT and other JAX
      transformations. When omitted, the optimal nse is computed from the contents of the
      data and index arrays. When the given nse is larger than necessary, the data and
      index arrays are padded with standard fill values; when it is smaller than
      necessary, data elements are dropped from the output matrix.

  Returns:
    mat_out : BCOO array with sorted indices and no duplicate indices.
  """
  summed = _bcoo_sum_duplicates(mat.data, mat.indices, spinfo=mat._info, nse=nse)
  data_out, indices_out = summed
  return BCOO((data_out, indices_out), shape=mat.shape)
|
||||
|
||||
def _bcoo_sum_duplicates(data, indices, *, spinfo, nse):
  """Bind ``bcoo_sum_duplicates_p``, first checking that ``nse`` is a concrete index.

  A non-None ``nse`` is run through ``core.concrete_or_error`` with ``operator.index``;
  ``None`` is forwarded to the primitive unchanged.
  """
  static_nse = nse
  if static_nse is not None:
    static_nse = core.concrete_or_error(
        operator.index, static_nse, "nse argument of bcoo_sum_duplicates.")
  return bcoo_sum_duplicates_p.bind(data, indices, spinfo=spinfo, nse=static_nse)
|
||||
|
||||
@bcoo_sum_duplicates_p.def_impl
def _bcoo_sum_duplicates_impl(data, indices, *, spinfo, nse):
  """Implementation rule for ``bcoo_sum_duplicates_p``.

  The indices are deduplicated per batch element by ``_bcoo_sum_duplicates_unbatched``,
  which also yields, for every input entry, the output slot it maps to. The index buffer
  is then resized to ``nse`` entries by ``_adjust_indices_nse`` and the data values are
  scatter-added into a zero-initialized buffer according to that mapping.

  Args:
    data: data buffer of the BCOO representation.
    indices: index buffer of the BCOO representation.
    spinfo: object whose ``shape`` attribute is the shape of the represented array.
    nse: number of entries to keep in the output, or None. When None, 1 is used if the
      array has no sparse dimensions, otherwise the maximum of the per-batch counts
      returned by the unbatched helper.

  Returns:
    A ``(data_out, indices_out)`` pair.
  """
  props = _validate_bcoo(data, indices, spinfo.shape)
  n_batch = props.n_batch

  # Build the per-batch-element deduplication and lift it over every batch dimension.
  unique_fn = functools.partial(_bcoo_sum_duplicates_unbatched, shape=spinfo.shape[n_batch:])
  for _ in range(n_batch):
    unique_fn = vmap(unique_fn)
  indices_out, mapping, nse_batched = unique_fn(indices)

  if nse is None:
    nse = 1 if props.n_sparse == 0 else nse_batched.max()
  indices_out = _adjust_indices_nse(indices_out, nse=nse, shape=spinfo.shape)

  if props.n_sparse == 0:
    # Without sparse dimensions every entry addresses the same element: collapse them.
    data = data.sum(n_batch, keepdims=True)

  # The accumulator has to start at zero because values are *added* into it. Use
  # jnp.zeros explicitly instead of relying on jnp.empty happening to be zero-filled.
  batch_shape = tuple(map(max, indices.shape[:n_batch], data.shape[:n_batch]))
  data_out = jnp.zeros((*batch_shape, nse, *data.shape[n_batch + 1:]), dtype=data.dtype)

  # mode='drop' discards contributions whose target slot lies outside the output buffer.
  permute = lambda d_out, m, d: d_out.at[m].add(d, mode='drop')
  for _ in range(n_batch):
    permute = broadcasting_vmap(permute)
  data_out = permute(data_out, mapping, data)
  return data_out, indices_out
|
||||
|
||||
def _adjust_indices_nse(indices, *, nse, shape):
  """Truncate or pad ``indices`` along its second-to-last axis so it holds ``nse`` rows.

  The current row count is taken from ``props.nse`` as reported by
  ``_validate_bcoo_indices``. If ``nse`` does not exceed it, the leading ``nse`` rows are
  kept. Otherwise rows equal to the sizes of the sparse dimensions of ``shape`` are
  appended.
  """
  props = _validate_bcoo_indices(indices, shape)
  if nse <= props.nse:
    return indices[..., :nse, :]

  # Padding rows hold the sparse-dimension sizes, i.e. one past the last valid index.
  sparse_sizes = shape[props.n_batch:props.n_batch + props.n_sparse]
  n_pad = nse - props.nse
  padding = lax.broadcast_in_dim(
      operand=jnp.array(sparse_sizes, dtype=indices.dtype),
      shape=(*indices.shape[:-2], n_pad, indices.shape[-1]),
      broadcast_dimensions=(indices.ndim - 1,))
  return lax.concatenate([indices, padding], dimension=indices.ndim - 2)
|
||||
|
||||
def _bcoo_sum_duplicates_unbatched(indices, *, shape):
  """Deduplicate the rows of an unbatched BCOO index array.

  Returns a triple ``(indices_unique, mapping, nse)``: the deduplicated index rows, the
  inverse mapping produced by ``_unique`` (for each input row, its position in the
  deduplicated rows), and the count of unique rows minus one if any fill row is present.
  When there are no sparse dimensions a single empty index row, a zero mapping of
  length 1 and ``nse == 1`` are returned.
  """
  props = _validate_bcoo_indices(indices, shape)
  if props.n_sparse == 0:
    nse = 1
    return (jnp.zeros_like(indices, shape=(nse, props.n_sparse)),
            jnp.zeros(nse, dtype='int32'),
            nse)

  # Row vector of sparse-dimension sizes; doubles as the marker for invalid entries.
  fill_value = jnp.expand_dims(jnp.array(shape[:props.n_sparse], dtype=indices.dtype), (0,))

  # Any row with a coordinate at or beyond its dimension size is replaced by the marker.
  is_oob = (indices >= fill_value).any(-1, keepdims=True)
  clipped = jnp.where(is_oob, fill_value, indices)

  unique_rows, inverse, n_unique = _unique(
      clipped, axis=0, return_inverse=True, return_true_size=True,
      size=props.nse, fill_value=fill_value)

  # The marker row, when present, is not counted as a specified element.
  has_fill = (clipped == fill_value).any()
  return unique_rows, inverse, n_unique - has_fill
|
||||
|
||||
@bcoo_sum_duplicates_p.def_abstract_eval
def _bcoo_sum_duplicates_abstract_eval(data, indices, *, spinfo, nse):
  """Abstract evaluation rule: output shapes and dtypes for a statically known ``nse``.

  Raises:
    ValueError: if ``nse`` is None, since the output shape cannot then be determined.
  """
  if nse is None:
    raise ValueError("bcoo_sum_duplicates: nse must be specified when using the function within "
                     "jit, vmap, and other transformations requiring abstract evaluation.")
  props = _validate_bcoo(data, indices, spinfo.shape)
  n_batch = props.n_batch

  # Each batch dimension of the data output is the larger of the data and index sizes.
  batch_shape = tuple(map(max, indices.shape[:n_batch], data.shape[:n_batch]))
  data_aval = core.ShapedArray(
      (*batch_shape, nse, *data.shape[n_batch + 1:]),
      data.dtype, weak_type=data.weak_type)
  indices_aval = core.ShapedArray(
      (*indices.shape[:n_batch], nse, props.n_sparse),
      dtype=indices.dtype, weak_type=indices.weak_type)
  return data_aval, indices_aval
|
||||
|
||||
def _bcoo_sum_duplicates_batching_rule(batched_args, batch_dims, *, spinfo, nse):
  """Batching rule: treat the mapped axis as an additional leading batch dimension.

  Only batch dimensions of 0 or None are supported. An unbatched operand receives a
  leading axis of size 1 before the primitive is re-bound with an ``spinfo`` whose shape
  has the batch size prepended.

  Raises:
    NotImplementedError: if a batch dimension other than 0 or None is given.
  """
  data, indices = batched_args
  if any(b not in [0, None] for b in batch_dims):
    raise NotImplementedError(f"batch_dims={batch_dims}. Only 0 and None are supported.")

  indices_unbatched = batch_dims[1] is None
  if batch_dims[0] is None:
    data = data[None, ...]
  if indices_unbatched:
    indices = indices[None, ...]

  batch_size = max(data.shape[0], indices.shape[0])
  batched_spinfo = BCOOInfo(shape=(batch_size, *spinfo.shape))
  data_out, indices_out = bcoo_sum_duplicates_p.bind(
      data, indices, spinfo=batched_spinfo, nse=nse)

  # Note: if data is unbatched on input, it will be batched on output.
  # However, if indices are unbatched on input, they will be unbatched on output.
  if indices_unbatched:
    return (data_out, indices_out[0]), (0, None)
  return (data_out, indices_out), (0, 0)
|
||||
|
||||
def _bcoo_sum_duplicates_jvp(primals, tangents, *, spinfo, nse):
  """JVP rule for ``bcoo_sum_duplicates_p``.

  The primal output is computed in the same way as in ``_bcoo_sum_duplicates_impl``.
  The data tangent is scatter-added with the same index mapping as the primal data. The
  indices always receive a symbolic zero tangent.

  Args:
    primals: ``(data, indices)`` buffers of the BCOO representation.
    tangents: tangents of ``data`` and ``indices``; only the data tangent is used.
    spinfo: object whose ``shape`` attribute is the shape of the represented array.
    nse: number of entries to keep in the output, or None to derive it from the indices.

  Returns:
    ``((data_out, indices_out), (data_dot_out, indices_dot_out))``.

  Raises:
    ValueError: if ``nse`` is None and cannot be derived as a concrete integer.
  """
  props = _validate_bcoo(*primals, spinfo.shape)

  data, indices = primals
  data_dot, _ = tangents
  # A symbolic zero tangent is not an array, so no array method may be called on it.
  zero_tangent = type(data_dot) is ad.Zero

  f = functools.partial(_bcoo_sum_duplicates_unbatched, shape=spinfo.shape[props.n_batch:])
  for _ in range(props.n_batch):
    f = broadcasting_vmap(f)
  indices_out, mapping, nse_batched = f(indices)

  if nse is None:
    # Same rule as _bcoo_sum_duplicates_impl, so that the primal output computed here
    # has the same shape as the output of the primitive itself.
    nse = 1 if props.n_sparse == 0 else nse_batched.max()
  try:
    nse = core.concrete_or_error(operator.index, nse, "nse argument of bcoo_sum_duplicates.")
  except core.ConcretizationTypeError:
    raise ValueError("bcoo_sum_duplicates: nse must be specified when using the function within "
                     "jit, vmap, and other transformations requiring abstract evaluation.")
  indices_out = _adjust_indices_nse(indices_out, nse=nse, shape=spinfo.shape)

  if props.n_sparse == 0:
    # Without sparse dimensions every entry addresses the same element: collapse them.
    data = data.sum(props.n_batch, keepdims=True)
    if not zero_tangent:
      data_dot = data_dot.sum(props.n_batch, keepdims=True)

  # Values are *added* into this buffer, so it has to start out as zeros.
  batch_shape = tuple(map(max, indices.shape[:props.n_batch], data.shape[:props.n_batch]))
  zeros = jnp.zeros((*batch_shape, nse, *data.shape[props.n_batch + 1:]), dtype=data.dtype)

  # mode='drop' discards contributions whose target slot lies outside the output buffer.
  permute = lambda d_out, m, d: d_out.at[m].add(d, mode='drop')
  for _ in range(props.n_batch):
    permute = broadcasting_vmap(permute)

  data_out = permute(zeros, mapping, data)
  indices_dot_out = ad.Zero.from_value(indices_out)
  if zero_tangent:
    data_dot_out = ad.Zero.from_value(data_out)
  else:
    data_dot_out = permute(zeros, mapping, data_dot)
  return (data_out, indices_out), (data_dot_out, indices_dot_out)
|
||||
|
||||
# Lowering rule built from _bcoo_sum_duplicates_impl via mlir.lower_fun.
_bcoo_sum_duplicates_mhlo = mlir.lower_fun(
|
||||
_bcoo_sum_duplicates_impl, multiple_results=True)
|
||||
|
||||
# Register the JVP rule, the batching rule and the lowering for bcoo_sum_duplicates_p.
ad.primitive_jvps[bcoo_sum_duplicates_p] = _bcoo_sum_duplicates_jvp
|
||||
batching.primitive_batchers[bcoo_sum_duplicates_p] = _bcoo_sum_duplicates_batching_rule
|
||||
mlir.register_lowering(bcoo_sum_duplicates_p, _bcoo_sum_duplicates_mhlo)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
# BCOO functions that maybe should be primitives?
|
||||
|
||||
@ -1712,9 +1829,10 @@ class BCOO(JAXSparse):
|
||||
will remain among the specified elements. Note: remove_zeros=True is incompatible
|
||||
with autodiff.
|
||||
"""
|
||||
data, indices = _bcoo_sum_duplicates(self.data, self.indices, self.shape,
|
||||
nse=nse, remove_zeros=remove_zeros)
|
||||
return BCOO((data, indices), shape=self.shape)
|
||||
if remove_zeros:
|
||||
return bcoo_eliminate_zeros(self, nse=nse)
|
||||
else:
|
||||
return bcoo_sum_duplicates(self, nse=nse)
|
||||
|
||||
def sort_indices(self):
|
||||
"""Return a copy of the matrix with indices sorted."""
|
||||
|
@ -80,4 +80,4 @@ def random_bcoo(key, shape, *, dtype=jnp.float_, indices_dtype=jnp.int_,
|
||||
data = generator(data_key, shape=data_shape, dtype=dtype, **kwds)
|
||||
indices = _indices(index_keys).reshape(indices_shape).astype(indices_dtype)
|
||||
mat = sparse.BCOO((data, indices), shape=shape)
|
||||
return mat.sum_duplicates() if sorted_indices else mat
|
||||
return mat.sort_indices() if sorted_indices else mat
|
||||
|
@ -1539,11 +1539,13 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
for nse in [None, np.prod(shape) - 1]
|
||||
for remove_zeros in [True, False]))
|
||||
def test_bcoo_sum_duplicates(self, shape, dtype, n_batch, n_dense, nse, remove_zeros):
|
||||
rng = self.rng()
|
||||
# Create a matrix with duplicate indices
|
||||
rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
|
||||
M = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense)
|
||||
for i, s in enumerate(shape[n_batch:len(shape) - n_dense]):
|
||||
M.indices = M.indices.at[..., i].set(rng.randint(0, s, size=M.nse))
|
||||
new_indices = jnp.concatenate([M.indices, M.indices], axis=n_batch)
|
||||
new_data = jnp.concatenate([M.data, M.data], axis=n_batch)
|
||||
M = sparse.BCOO((new_data, new_indices), shape=M.shape)
|
||||
|
||||
dedupe = partial(M.sum_duplicates, nse=nse, remove_zeros=remove_zeros)
|
||||
jit_dedupe = jax.jit(dedupe)
|
||||
|
||||
@ -1553,13 +1555,48 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
self.assertEqual(M_dedup.nse, nse)
|
||||
|
||||
if not nse:
|
||||
with self.assertRaisesRegex(ValueError, ".*nse argument"):
|
||||
with self.assertRaisesRegex(ValueError, ".*nse must be specified.*"):
|
||||
jit_dedupe()
|
||||
else:
|
||||
M_dedup = jit_dedupe()
|
||||
self.assertAllClose(M.todense(), M_dedup.todense())
|
||||
self.assertEqual(M_dedup.nse, nse)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": "_{}_nbatch={}_ndense={}_nse={}".format(
    jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense, nse),
   "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense, "nse": nse}
  for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
  for dtype in jtu.dtypes.floating
  for n_batch in range(len(shape) + 1)
  for n_dense in range(len(shape) + 1 - n_batch)
  for nse in [None, 5, np.prod(shape) - 1]
))
def test_bcoo_sum_duplicates_ad(self, shape, dtype, n_batch, n_dense, nse):
  """Forward- and reverse-mode Jacobians of sum_duplicates w.r.t. the data must agree."""
  # Build a matrix in which every index occurs twice by concatenating the buffers
  # of a random sparse matrix with themselves.
  make_dense = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
  base = sparse.BCOO.fromdense(make_dense(shape, dtype), n_batch=n_batch, n_dense=n_dense)
  doubled_data = jnp.concatenate([base.data, base.data], axis=n_batch)
  doubled_indices = jnp.concatenate([base.indices, base.indices], axis=n_batch)
  M = sparse.BCOO((doubled_data, doubled_indices), shape=base.shape)

  # TODO(jakevdp) address this corner case.
  if M.nse == 0:
    self.skipTest("known failure for nse=0")

  # NOTE(review): 'all' is not among the parameterized nse values above, so this
  # branch is never taken at present.
  if nse == 'all':
    nse = M.nse

  def dedupe(data, nse=nse):
    mat = sparse.BCOO((data, M.indices), shape=M.shape)
    return mat.sum_duplicates(nse=nse, remove_zeros=False).data

  jac_fwd = jax.jacfwd(dedupe)(M.data)
  jac_rev = jax.jacrev(dedupe)(M.data)

  self.assertAllClose(jac_fwd, jac_rev)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
|
||||
@ -1598,7 +1635,6 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
def sort_indices(data):
|
||||
return sparse.BCOO((data, M.indices), shape=M.shape).sort_indices().data
|
||||
|
||||
# Forward-mode
|
||||
data_dot_fwd = jax.jacfwd(sort_indices)(M.data)
|
||||
data_dot_rev = jax.jacrev(sort_indices)(M.data)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user