Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
[sparse]: change bcoo pad values to use OOB indices

commit f424a90c71, parent 46b9653e28
@@ -55,36 +55,24 @@ def _dedupe_bcoo(data, indices, shape):
   if indices.shape[:props.n_batch] != data.shape[:props.n_batch]:
     # TODO: handle broadcasted dimensions.
     raise NotImplementedError("dedupe_bcoo for broadcasted dimensions.")
-  f = _dedupe_bcoo_one
+  f = functools.partial(_dedupe_bcoo_one,
+                        shape=shape[props.n_batch:props.n_batch + props.n_sparse])
   for _ in range(props.n_batch):
     f = vmap(f)
   return f(data, indices)
 
-def _dedupe_bcoo_one(data, indices):
-  assert indices.ndim == 2
-  assert data.shape[:1] == indices.shape[:1]
-
+def _dedupe_bcoo_one(data, indices, *, shape):
+  nse, = data.shape
+  assert indices.shape == (nse, len(shape))
   if indices.shape[1] == 0:
     return data, indices
-
-  # The following is similar to
-  # indices_unique, inv_idx = jnp.unique(indices, axis=0, return_inverse=True,
-  #                                      size=indices.shape[0], fill_value=0)
-  # but modified to keep padding at the end of the resulting arrays.
-  is_padding = (indices == 0).all(1) & (data == 0)
-  perm = jnp.lexsort(indices[:, ::-1].T)
-  aux = indices[perm]
-  mask = jnp.ones(indices.shape[0], dtype=bool)
-  mask = mask.at[1:].set(jnp.any(aux[1:] != aux[:-1], 1))
-  mask = mask & ~is_padding[perm]  # this is the padding modification.
-  imask = jnp.cumsum(mask) - 1
-  indices_unique = jnp.where(mask[:, None], aux, 0)[jnp.argsort(~mask)]
-  inv_idx = jnp.zeros_like(imask).at[perm].set(imask)
-
+  indices_unique, inv_idx = jnp.unique(indices, axis=0, return_inverse=True,
+                                       size=nse, fill_value=jnp.array(shape))
   data_unique = jnp.zeros_like(data).at[inv_idx].add(data)
+  oob_mask = jnp.all(indices_unique == jnp.array(shape), 1)
+  data_unique = jnp.where(oob_mask, 0, data_unique)
   return data_unique, indices_unique
 
 
 def _unbatch_bcoo(data, indices, shape):
   n_batch = _validate_bcoo(data, indices, shape).n_batch
   if n_batch == 0:
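The rewrite leans on jnp.unique accepting a static size and a fill_value: padding rows now receive the index "shape" itself, one past every valid coordinate, instead of index 0. The following is a minimal standalone sketch of that convention; the values are toy inputs for illustration, not from the commit:

import jax.numpy as jnp

shape = (5,)
data = jnp.array([1., 2., 3.])
indices = jnp.array([[2], [2], [4]])  # nse=3, with one duplicated index

# Duplicates are merged; padding rows take the out-of-bounds index `shape`.
indices_unique, inv_idx = jnp.unique(indices, axis=0, return_inverse=True,
                                     size=len(data), fill_value=jnp.array(shape))
data_unique = jnp.zeros_like(data).at[inv_idx.ravel()].add(data)
# Zero out any data that accumulated on out-of-bounds (padding) rows.
oob_mask = jnp.all(indices_unique == jnp.array(shape), 1)
data_unique = jnp.where(oob_mask, 0, data_unique)
# indices_unique -> [[2], [4], [5]]; data_unique -> [3., 3., 0.]

Compared with the deleted lexsort-based dedupe, this drops the special-casing needed to keep zero-index padding distinguishable from a genuine entry at index 0.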
@@ -234,13 +222,17 @@ def bcoo_fromdense(mat, *, nse=None, n_batch=0, n_dense=0, index_dtype=jnp.int32
 @bcoo_fromdense_p.def_impl
 def _bcoo_fromdense_impl(mat, *, nse, n_batch, n_dense, index_dtype):
   mat = jnp.asarray(mat)
+  n_sparse = mat.ndim - n_dense - n_batch
   mask = (mat != 0)
   if n_dense > 0:
     mask = mask.any([-(i + 1) for i in range(n_dense)])
-  nonzero = lambda a: jnp.nonzero(a, size=nse) if a.ndim else ()
+  def _nonzero(a):
+    if a.ndim:
+      return jnp.nonzero(a, size=nse, fill_value=a.shape[:n_sparse])
+    return ()
   for _ in range(n_batch):
-    nonzero = vmap(nonzero, 0)
-  indices = nonzero(mask)
+    _nonzero = vmap(_nonzero, 0)
+  indices = _nonzero(mask)
   if not indices:
     indices = jnp.zeros(mask.shape[:n_batch] + (nse, 0), index_dtype)
   else:
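The helper relies on the size and fill_value parameters of jnp.nonzero. A small sketch of just that behavior, with toy values not taken from the commit:

import jax.numpy as jnp

mask = jnp.array([True, False, True, False])
# With a static `size`, jnp.nonzero pads its output to that length; passing
# `fill_value` makes the padding the array's shape (one past the last valid
# index) rather than the default of 0.
idx, = jnp.nonzero(mask, size=4, fill_value=mask.shape)
# idx -> [0, 2, 4, 4]: two real indices, two out-of-bounds padding slots.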
@@ -310,6 +302,7 @@ def bcoo_extract(indices, mat):
 
 @bcoo_extract_p.def_impl
 def _bcoo_extract_impl(indices, mat):
   mat = jnp.asarray(mat)
+  n_batch, n_sparse, _, _ = _validate_bcoo(None, indices, mat.shape)
 
   ind_slices = tuple(np.zeros(s, int) if i_s == 1 else np.arange(s)
@@ -323,7 +316,7 @@ def _bcoo_extract_impl(indices, mat):
 
   if not sparse_ind + batch_ind:
     return mat[None]
-  return mat[batch_ind + sparse_ind]
+  return mat.at[batch_ind + sparse_ind].get(mode='fill', fill_value=0)
 
 @bcoo_extract_p.def_abstract_eval
 def _bcoo_extract_abstract_eval(indices, mat):
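Plain indexing in JAX clamps out-of-bounds gather indices, so padded OOB indices would silently read the last element; .at[...].get(mode='fill') returns the fill value instead, which is what makes the new padding convention safe to extract. A sketch with toy values:

import jax.numpy as jnp

mat = jnp.array([10., 20., 30.])
idx = jnp.array([0, 2, 3])  # 3 is out of bounds for shape (3,)
clamped = mat[idx]                                   # [10., 30., 30.]
filled = mat.at[idx].get(mode='fill', fill_value=0)  # [10., 30., 0.]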
@@ -524,7 +517,8 @@ def _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs
     idx = tuple(lhs_indices.T)
     idx_right, idx_out = idx[:n_contracting], idx[n_contracting:]
     ctc = [0] if n_contracting else []
-    prod = lax.dot_general(lhs_data, rhs[idx_right], (([], []), (ctc, ctc)))
+    prod = lax.dot_general(lhs_data, rhs.at[idx_right].get(mode='fill', fill_value=0),
+                           (([], []), (ctc, ctc)))
     return out_array.at[idx_out].add(prod) if idx_out else prod.sum(0, dtype=out_array.dtype)
   for i in range(n_batch)[::-1]:
     axes_in = [0, 0, 0, 0]
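In the dot_general path, padded nse slots now carry out-of-bounds contracting indices, and the fill-gather guarantees the corresponding rhs rows read as zero, so the padding contributes nothing to the contraction. A toy illustration of just that gather, with assumed values:

import jax.numpy as jnp

rhs = jnp.array([[1., 2.], [3., 4.]])
idx_right = (jnp.array([0, 1, 2]),)  # last entry is OOB padding
rows = rhs.at[idx_right].get(mode='fill', fill_value=0)
# rows -> [[1., 2.], [3., 4.], [0., 0.]]: the padded slot adds zero.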
@@ -708,7 +702,9 @@ def _bcoo_Mv(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_shape, rhs_sha
   rhs_i = rhs_indices[:, 0]
   mask = jnp.isin(lhs_i, rhs_i, assume_unique=True)
   lhs_i_inv = (lhs_i[None, :] == rhs_i[:, None]).argmax(0)
-  out_data = lhs_data.at[jnp.arange(lhs.nse)].mul(jnp.where(mask, rhs_data[lhs_i_inv], 0))
+  lhs_i_inv = jnp.where(lhs_i < rhs_shape[0], lhs_i_inv, rhs_shape[0])
+  rhs_data_at_lhs_indices = jnp.where(mask, rhs_data.at[lhs_i_inv].get(mode='fill', fill_value=0), 0)
+  out_data = lhs_data.at[jnp.arange(lhs.nse)].mul(rhs_data_at_lhs_indices)
   out_indices = jnp.concatenate([lhs_indices[:, :lhs_contract], lhs_indices[:, lhs_contract + 1:]], axis=1)
   return out_data, out_indices
 
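The sparse-sparse lookup first locates each lhs index inside rhs_i, then routes out-of-bounds lhs padding to an out-of-bounds slot so the fill-gather returns zero for it. A standalone sketch of the pattern, where dim stands in for rhs_shape[0] and all values are assumed:

import jax.numpy as jnp

dim = 4                        # stands in for rhs_shape[0]
lhs_i = jnp.array([0, 2, 4])   # 4 is OOB padding
rhs_i = jnp.array([2, 3])
rhs_data = jnp.array([10., 20.])

mask = jnp.isin(lhs_i, rhs_i, assume_unique=True)   # [False, True, False]
inv = (lhs_i[None, :] == rhs_i[:, None]).argmax(0)  # position of each match
inv = jnp.where(lhs_i < dim, inv, dim)              # send OOB slots OOB
vals = jnp.where(mask, rhs_data.at[inv].get(mode='fill', fill_value=0), 0)
# vals -> [0., 10., 0.]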
@@ -837,6 +833,13 @@ def bcoo_reduce_sum(data, indices, *, shape, axes):
   # Sum over dense dimensions -> sum over data
   dense_axes = tuple(ax - n_sparse + 1 for ax in axes if ax >= n_batch + n_sparse)
   data = data.sum(dense_axes)
+  if n_sparse:
+    # zero-out data corresponding to invalid indices.
+    sparse_shape = jnp.array(shape[n_batch: n_batch + n_sparse])
+    mask = jnp.all(indices < sparse_shape, -1)
+    if data.ndim > mask.ndim:
+      mask = lax.expand_dims(mask, tuple(range(mask.ndim, data.ndim)))
+    data = jnp.where(mask, data, 0)
 
   # Sum over sparse dimensions -> drop index; sum is implicit
   sparse_idx = [i for i in range(n_sparse) if i + n_batch not in axes]
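Since padding indices can now equal the sparse shape, reduce_sum must drop that data before the implicit sum over dropped dimensions. A minimal sketch of the masking step, with toy values:

import jax.numpy as jnp

shape = (3,)
indices = jnp.array([[0], [2], [3]])  # [3] is OOB padding
data = jnp.array([1., 2., 5.])
mask = jnp.all(indices < jnp.array(shape), -1)  # [True, True, False]
total = jnp.where(mask, data, 0).sum()          # 3.0; the padding is ignored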
@@ -1063,8 +1063,9 @@ class BCOOTest(jtu.JaxTestCase):
 
   def test_bcoo_dedupe_padding(self):
     # Regression test for https://github.com/google/jax/issues/8163
+    size = 3
     data = jnp.array([1, 0, 0])
-    indices = jnp.array([1, 0, 0])[:, None]
+    indices = jnp.array([1, size, size])[:, None]
     x = sparse.BCOO((data, indices), shape=(3,))
     y = x._dedupe()
     self.assertArraysEqual(x.todense(), y.todense())
@@ -1151,6 +1152,35 @@ class BCOOTest(jtu.JaxTestCase):
     self.assertEqual(M1.dtype, M2.dtype)
     self.assertArraysEqual(M1.todense(), M2.todense())
 
+  def test_bcoo_bad_fillvals(self):
+    # Extra values have 100 rather than zero. This lets us check that logic is
+    # properly ignoring these indices.
+    data = jnp.array([1, 2, 3, 100, 100])
+    indices = jnp.array([1, 2, 3, 5, 5])[:, None]
+    x_sp = sparse.BCOO((data, indices), shape=(5,))
+    x_de = x_sp.todense()
+
+    data = jnp.array([3, 2, 100, 100])
+    indices = jnp.array([2, 3, 5, 5])[:, None]
+    y_sp = sparse.BCOO((data, indices), shape=(5,))
+    y_de = y_sp.todense()
+
+    self.assertArraysEqual(x_de, jnp.array([0, 1, 2, 3, 0]))
+    self.assertArraysEqual(y_de, jnp.array([0, 0, 3, 2, 0]))
+
+    self.assertArraysEqual(x_sp._dedupe().todense(), x_de)
+    self.assertArraysEqual(y_sp._dedupe().todense(), y_de)
+
+    # reduce_sum:
+    self.assertArraysEqual(x_sp.sum(), x_de.sum())
+
+    # bcoo_dot_general
+    self.assertArraysEqual(x_sp @ y_de, x_de @ y_de)
+
+    # bcoo_spdot_general
+    self.assertArraysEqual((x_sp @ y_sp).todense(), x_de @ y_de)
+    self.assertArraysEqual((y_sp @ x_sp).todense(), y_de @ x_de)
+
 
 class SparseGradTest(jtu.JaxTestCase):
   def test_sparse_grad(self):