add bcoo_reduce_sum() function

This commit is contained in:
Jake VanderPlas 2021-06-11 13:19:54 -07:00
parent 31e9c65f2a
commit 2113d9c34d
2 changed files with 75 additions and 2 deletions

View File

@ -979,6 +979,59 @@ batching.primitive_batchers[bcoo_dot_general_p] = _bcoo_dot_general_batch_rule
xla.translations[bcoo_dot_general_p] = xla.lower_fun(
_bcoo_dot_general_impl, multiple_results=False)
#----------------------------------------------------------------------
# BCOO functions that maybe should be primitives?
def _tuple_replace(tup, ind, val):
return tuple(val if i == ind else t for i, t in enumerate(tup))
def bcoo_reduce_sum(data, indices, *, shape, axes):
assert all(0 <= a < len(shape) for a in axes)
axes = sorted(set(axes))
n_sparse, nse = indices.shape[-2:]
n_batch = indices.ndim - 2
# Sum over dense dimensions -> sum over data
dense_axes = tuple(ax - n_sparse + 1 for ax in axes if ax >= n_batch + n_sparse)
data = data.sum(dense_axes)
# Sum over sparse dimensions -> drop index; sum is implicit
sparse_idx = [i for i in range(n_sparse) if i + n_batch not in axes]
if not sparse_idx:
indices = jnp.zeros(_tuple_replace(indices.shape, n_batch, 0), indices.dtype)
else:
indices = indices[..., np.array(sparse_idx), :]
# Sum over batch dimensions -> reshape into nse
batch_axes = {ax for ax in axes if ax < n_batch}
# First handle broadcasted batch dimensions
for ax in batch_axes:
if data.shape[ax] == 1:
if indices.shape[ax] == 1:
data = data * shape[ax]
else:
data = lax.broadcast_in_dim(data, _tuple_replace(data.shape, ax, shape[ax]), tuple(range(data.ndim)))
else:
if indices.shape[ax] == 1:
data = data.sum(ax)
assert data.shape[ax] == indices.shape[ax]
new_batch_dims = tuple(sorted(set(range(n_batch)) - batch_axes))
new_batch_shape = tuple(data.shape[i] for i in new_batch_dims)
new_nse = int(nse * np.prod([data.shape[i] for i in batch_axes]))
data = lax.reshape(data,
new_batch_shape + (new_nse,) + data.shape[n_batch + 1:],
new_batch_dims + tuple(batch_axes) + tuple(range(n_batch, data.ndim)))
indices = lax.reshape(indices,
new_batch_shape + (indices.shape[n_batch], new_nse),
new_batch_dims + (n_batch,) + tuple(batch_axes) + tuple(range(n_batch + 1, indices.ndim)))
out_shape = tuple(shape[i] for i in range(len(shape)) if i not in axes)
return data, indices, out_shape
#----------------------------------------------------------------------
# Sparse objects (APIs subject to change)
class JAXSparse:
@ -1155,8 +1208,8 @@ class BCOO(JAXSparse):
super().__init__(args, shape=shape)
@classmethod
def fromdense(cls, mat, *, nnz=None, index_dtype=np.int32):
return cls(bcoo_fromdense(mat, nse=nnz, index_dtype=index_dtype), shape=mat.shape)
def fromdense(cls, mat, *, nnz=None, index_dtype=np.int32, n_dense=0, n_batch=0):
return cls(bcoo_fromdense(mat, nse=nnz, index_dtype=index_dtype, n_dense=n_dense, n_batch=n_batch), shape=mat.shape)
@api.jit
def todense(self):

View File

@ -715,6 +715,26 @@ class BCOOTest(jtu.JaxTestCase):
self.assertAllClose(M1, M2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}_axes={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense, axes),
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense, "axes": axes}
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex
for n_batch in range(len(shape) + 1)
for n_dense in range(len(shape) + 1 - n_batch)
for naxes in range(len(shape))
for axes in itertools.combinations(range(len(shape)), naxes)))
def test_bcoo_reduce_sum(self, shape, dtype, n_batch, n_dense, axes):
rng = rand_sparse(self.rng())
M = rng(shape, dtype)
data, indices = sparse_ops.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)
data_out, indices_out, shape_out = sparse_ops.bcoo_reduce_sum(data, indices, shape=shape, axes=axes)
result_dense = M.sum(axes)
result_sparse = sparse_ops.bcoo_todense(data_out, indices_out, shape=shape_out)
tol = {np.float32: 1E-6, np.float64: 1E-14}
self.assertAllClose(result_dense, result_sparse, atol=tol, rtol=tol)
class SparseObjectTest(jtu.JaxTestCase):
@parameterized.named_parameters(