[sparse]: change bcoo pad values to use OOB indices

This commit is contained in:
Jake VanderPlas 2021-10-15 10:50:05 -07:00
parent 46b9653e28
commit f424a90c71
2 changed files with 61 additions and 28 deletions

View File

@ -55,36 +55,24 @@ def _dedupe_bcoo(data, indices, shape):
if indices.shape[:props.n_batch] != data.shape[:props.n_batch]:
# TODO: handle broadcasted dimensions.
raise NotImplementedError("dedupe_bcoo for broadcasted dimensions.")
f = _dedupe_bcoo_one
f = functools.partial(_dedupe_bcoo_one,
shape=shape[props.n_batch:props.n_batch + props.n_sparse])
for _ in range(props.n_batch):
f = vmap(f)
return f(data, indices)
def _dedupe_bcoo_one(data, indices):
assert indices.ndim == 2
assert data.shape[:1] == indices.shape[:1]
def _dedupe_bcoo_one(data, indices, *, shape):
nse, = data.shape
assert indices.shape == (nse, len(shape))
if indices.shape[1] == 0:
return data, indices
# The following is similar to
# indices_unique, inv_idx = jnp.unique(indices, axis=0, return_inverse=True,
# size=indices.shape[0], fill_value=0)
# but modified to keep padding at the end of the resulting arrays.
is_padding = (indices == 0).all(1) & (data == 0)
perm = jnp.lexsort(indices[:, ::-1].T)
aux = indices[perm]
mask = jnp.ones(indices.shape[0], dtype=bool)
mask = mask.at[1:].set(jnp.any(aux[1:] != aux[:-1], 1))
mask = mask & ~is_padding[perm] # this is the padding modification.
imask = jnp.cumsum(mask) - 1
indices_unique = jnp.where(mask[:, None], aux, 0)[jnp.argsort(~mask)]
inv_idx = jnp.zeros_like(imask).at[perm].set(imask)
indices_unique, inv_idx = jnp.unique(indices, axis=0, return_inverse=True,
size=nse, fill_value=jnp.array(shape))
data_unique = jnp.zeros_like(data).at[inv_idx].add(data)
oob_mask = jnp.all(indices_unique == jnp.array(shape), 1)
data_unique = jnp.where(oob_mask, 0, data_unique)
return data_unique, indices_unique
def _unbatch_bcoo(data, indices, shape):
n_batch = _validate_bcoo(data, indices, shape).n_batch
if n_batch == 0:
@ -234,13 +222,17 @@ def bcoo_fromdense(mat, *, nse=None, n_batch=0, n_dense=0, index_dtype=jnp.int32
@bcoo_fromdense_p.def_impl
def _bcoo_fromdense_impl(mat, *, nse, n_batch, n_dense, index_dtype):
mat = jnp.asarray(mat)
n_sparse = mat.ndim - n_dense - n_batch
mask = (mat != 0)
if n_dense > 0:
mask = mask.any([-(i + 1) for i in range(n_dense)])
nonzero = lambda a: jnp.nonzero(a, size=nse) if a.ndim else ()
def _nonzero(a):
if a.ndim:
return jnp.nonzero(a, size=nse, fill_value=a.shape[:n_sparse])
return ()
for _ in range(n_batch):
nonzero = vmap(nonzero, 0)
indices = nonzero(mask)
_nonzero = vmap(_nonzero, 0)
indices = _nonzero(mask)
if not indices:
indices = jnp.zeros(mask.shape[:n_batch] + (nse, 0), index_dtype)
else:
@ -310,6 +302,7 @@ def bcoo_extract(indices, mat):
@bcoo_extract_p.def_impl
def _bcoo_extract_impl(indices, mat):
mat = jnp.asarray(mat)
n_batch, n_sparse, _, _ = _validate_bcoo(None, indices, mat.shape)
ind_slices = tuple(np.zeros(s, int) if i_s == 1 else np.arange(s)
@ -323,7 +316,7 @@ def _bcoo_extract_impl(indices, mat):
if not sparse_ind + batch_ind:
return mat[None]
return mat[batch_ind + sparse_ind]
return mat.at[batch_ind + sparse_ind].get(mode='fill', fill_value=0)
@bcoo_extract_p.def_abstract_eval
def _bcoo_extract_abstract_eval(indices, mat):
@ -524,7 +517,8 @@ def _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs
idx = tuple(lhs_indices.T)
idx_right, idx_out = idx[:n_contracting], idx[n_contracting:]
ctc = [0] if n_contracting else []
prod = lax.dot_general(lhs_data, rhs[idx_right], (([], []), (ctc, ctc)))
prod = lax.dot_general(lhs_data, rhs.at[idx_right].get(mode='fill', fill_value=0),
(([], []), (ctc, ctc)))
return out_array.at[idx_out].add(prod) if idx_out else prod.sum(0, dtype=out_array.dtype)
for i in range(n_batch)[::-1]:
axes_in = [0, 0, 0, 0]
@ -708,7 +702,9 @@ def _bcoo_Mv(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_shape, rhs_sha
rhs_i = rhs_indices[:, 0]
mask = jnp.isin(lhs_i, rhs_i, assume_unique=True)
lhs_i_inv = (lhs_i[None, :] == rhs_i[:, None]).argmax(0)
out_data = lhs_data.at[jnp.arange(lhs.nse)].mul(jnp.where(mask, rhs_data[lhs_i_inv], 0))
lhs_i_inv = jnp.where(lhs_i < rhs_shape[0], lhs_i_inv, rhs_shape[0])
rhs_data_at_lhs_indices = jnp.where(mask, rhs_data.at[lhs_i_inv].get(mode='fill', fill_value=0), 0)
out_data = lhs_data.at[jnp.arange(lhs.nse)].mul(rhs_data_at_lhs_indices)
out_indices = jnp.concatenate([lhs_indices[:, :lhs_contract], lhs_indices[:, lhs_contract + 1:]], axis=1)
return out_data, out_indices
@ -837,6 +833,13 @@ def bcoo_reduce_sum(data, indices, *, shape, axes):
# Sum over dense dimensions -> sum over data
dense_axes = tuple(ax - n_sparse + 1 for ax in axes if ax >= n_batch + n_sparse)
data = data.sum(dense_axes)
if n_sparse:
# zero-out data corresponding to invalid indices.
sparse_shape = jnp.array(shape[n_batch: n_batch + n_sparse])
mask = jnp.all(indices < sparse_shape, -1)
if data.ndim > mask.ndim:
mask = lax.expand_dims(mask, tuple(range(mask.ndim, data.ndim)))
data = jnp.where(mask, data, 0)
# Sum over sparse dimensions -> drop index; sum is implicit
sparse_idx = [i for i in range(n_sparse) if i + n_batch not in axes]

View File

@ -1063,8 +1063,9 @@ class BCOOTest(jtu.JaxTestCase):
def test_bcoo_dedupe_padding(self):
  # Regression test for https://github.com/google/jax/issues/8163
  # The two trailing entries are padding: zero data stored at the
  # out-of-bounds index `size`. Deduplicating must leave the dense value of
  # the array unchanged.
  size = 3
  data = jnp.array([1, 0, 0])
  indices = jnp.array([1, size, size])[:, None]
  # Derive the shape from `size` so the padding index stays out of bounds.
  x = sparse.BCOO((data, indices), shape=(size,))
  y = x._dedupe()
  self.assertArraysEqual(x.todense(), y.todense())
@ -1151,6 +1152,35 @@ class BCOOTest(jtu.JaxTestCase):
self.assertEqual(M1.dtype, M2.dtype)
self.assertArraysEqual(M1.todense(), M2.todense())
def test_bcoo_bad_fillvals(self):
  # The padding entries below store the value 100 (rather than zero) at the
  # out-of-bounds index 5, so any operation that fails to ignore them yields a
  # visibly wrong result.
  data = jnp.array([1, 2, 3, 100, 100])
  indices = jnp.array([1, 2, 3, 5, 5])[:, None]
  x_sp = sparse.BCOO((data, indices), shape=(5,))
  x_de = x_sp.todense()

  data = jnp.array([3, 2, 100, 100])
  indices = jnp.array([2, 3, 5, 5])[:, None]
  y_sp = sparse.BCOO((data, indices), shape=(5,))
  y_de = y_sp.todense()

  # todense: out-of-bounds entries are dropped.
  expected_x = jnp.array([0, 1, 2, 3, 0])
  expected_y = jnp.array([0, 0, 3, 2, 0])
  self.assertArraysEqual(x_de, expected_x)
  self.assertArraysEqual(y_de, expected_y)

  # _dedupe: the dense value is unchanged.
  self.assertArraysEqual(x_sp._dedupe().todense(), x_de)
  self.assertArraysEqual(y_sp._dedupe().todense(), y_de)

  # reduce_sum:
  self.assertArraysEqual(x_sp.sum(), x_de.sum())

  # bcoo_dot_general (sparse @ dense):
  dense_product = x_de @ y_de
  self.assertArraysEqual(x_sp @ y_de, dense_product)

  # bcoo_spdot_general (sparse @ sparse), in both operand orders:
  self.assertArraysEqual((x_sp @ y_sp).todense(), x_de @ y_de)
  self.assertArraysEqual((y_sp @ x_sp).todense(), y_de @ x_de)
class SparseGradTest(jtu.JaxTestCase):
def test_sparse_grad(self):