mirror of https://github.com/ROCm/jax.git
add bcoo_reduce_sum() function

commit 2113d9c34d (parent 31e9c65f2a)
@@ -979,6 +979,59 @@ batching.primitive_batchers[bcoo_dot_general_p] = _bcoo_dot_general_batch_rule
 xla.translations[bcoo_dot_general_p] = xla.lower_fun(
     _bcoo_dot_general_impl, multiple_results=False)
 
+#----------------------------------------------------------------------
+# BCOO functions that maybe should be primitives?
+
+def _tuple_replace(tup, ind, val):
+  return tuple(val if i == ind else t for i, t in enumerate(tup))
+
+def bcoo_reduce_sum(data, indices, *, shape, axes):
+  assert all(0 <= a < len(shape) for a in axes)
+  axes = sorted(set(axes))
+  n_sparse, nse = indices.shape[-2:]
+  n_batch = indices.ndim - 2
+
+  # Sum over dense dimensions -> sum over data
+  dense_axes = tuple(ax - n_sparse + 1 for ax in axes if ax >= n_batch + n_sparse)
+  data = data.sum(dense_axes)
+
+  # Sum over sparse dimensions -> drop index; sum is implicit
+  sparse_idx = [i for i in range(n_sparse) if i + n_batch not in axes]
+  if not sparse_idx:
+    indices = jnp.zeros(_tuple_replace(indices.shape, n_batch, 0), indices.dtype)
+  else:
+    indices = indices[..., np.array(sparse_idx), :]
+
+  # Sum over batch dimensions -> reshape into nse
+  batch_axes = {ax for ax in axes if ax < n_batch}
+
+  # First handle broadcasted batch dimensions
+  for ax in batch_axes:
+    if data.shape[ax] == 1:
+      if indices.shape[ax] == 1:
+        data = data * shape[ax]
+      else:
+        data = lax.broadcast_in_dim(data, _tuple_replace(data.shape, ax, shape[ax]), tuple(range(data.ndim)))
+    else:
+      if indices.shape[ax] == 1:
+        data = data.sum(ax)
+    assert data.shape[ax] == indices.shape[ax]
+
+  new_batch_dims = tuple(sorted(set(range(n_batch)) - batch_axes))
+  new_batch_shape = tuple(data.shape[i] for i in new_batch_dims)
+  new_nse = int(nse * np.prod([data.shape[i] for i in batch_axes]))
+
+  data = lax.reshape(data,
+                     new_batch_shape + (new_nse,) + data.shape[n_batch + 1:],
+                     new_batch_dims + tuple(batch_axes) + tuple(range(n_batch, data.ndim)))
+  indices = lax.reshape(indices,
+                        new_batch_shape + (indices.shape[n_batch], new_nse),
+                        new_batch_dims + (n_batch,) + tuple(batch_axes) + tuple(range(n_batch + 1, indices.ndim)))
+
+  out_shape = tuple(shape[i] for i in range(len(shape)) if i not in axes)
+  return data, indices, out_shape
+
+
 #----------------------------------------------------------------------
 # Sparse objects (APIs subject to change)
 class JAXSparse:
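A minimal usage sketch of the new function, mirroring what test_bcoo_reduce_sum below exercises: build the (data, indices) pair with bcoo_fromdense, reduce over some axes, and densify the result with bcoo_todense. The import path jax.experimental.sparse_ops, the example array, and the chosen axes are illustrative assumptions, not part of this commit.

import numpy as np
from jax.experimental import sparse_ops  # assumed module path, as used in the tests below

M = np.arange(1.0, 13.0).reshape(3, 4)        # dense input, all entries nonzero
data, indices = sparse_ops.bcoo_fromdense(M)  # BCOO representation, all dims sparse

# Sum over axis 1; the result stays in BCOO form, together with its reduced shape.
data_out, indices_out, shape_out = sparse_ops.bcoo_reduce_sum(
    data, indices, shape=M.shape, axes=(1,))

# Dropping the summed sparse index leaves duplicate coordinates; bcoo_todense
# adds them up, so the round trip matches the dense reduction.
np.testing.assert_allclose(
    sparse_ops.bcoo_todense(data_out, indices_out, shape=shape_out), M.sum(1))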
@@ -1155,8 +1208,8 @@ class BCOO(JAXSparse):
     super().__init__(args, shape=shape)
 
   @classmethod
-  def fromdense(cls, mat, *, nnz=None, index_dtype=np.int32):
-    return cls(bcoo_fromdense(mat, nse=nnz, index_dtype=index_dtype), shape=mat.shape)
+  def fromdense(cls, mat, *, nnz=None, index_dtype=np.int32, n_dense=0, n_batch=0):
+    return cls(bcoo_fromdense(mat, nse=nnz, index_dtype=index_dtype, n_dense=n_dense, n_batch=n_batch), shape=mat.shape)
 
   @api.jit
   def todense(self):
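A short sketch of the BCOO.fromdense change above: the new n_dense and n_batch keywords are now forwarded to bcoo_fromdense, so batch and dense dimensions can be requested through the constructor. The array and module path are assumptions for illustration.

import numpy as np
from jax.experimental import sparse_ops  # assumed module path, as used in the tests below

M = np.ones((2, 3, 4))
# One leading batch dimension and one trailing dense dimension; only the middle
# dimension is stored as a sparse index. Before this commit these keywords had to
# be passed to bcoo_fromdense directly rather than through the BCOO constructor.
mat = sparse_ops.BCOO.fromdense(M, n_batch=1, n_dense=1)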
@@ -715,6 +715,26 @@ class BCOOTest(jtu.JaxTestCase):
 
     self.assertAllClose(M1, M2)
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_{}_nbatch={}_ndense={}_axes={}".format(
+        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense, axes),
+       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense, "axes": axes}
+      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
+      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
+      for n_batch in range(len(shape) + 1)
+      for n_dense in range(len(shape) + 1 - n_batch)
+      for naxes in range(len(shape))
+      for axes in itertools.combinations(range(len(shape)), naxes)))
+  def test_bcoo_reduce_sum(self, shape, dtype, n_batch, n_dense, axes):
+    rng = rand_sparse(self.rng())
+    M = rng(shape, dtype)
+    data, indices = sparse_ops.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)
+    data_out, indices_out, shape_out = sparse_ops.bcoo_reduce_sum(data, indices, shape=shape, axes=axes)
+    result_dense = M.sum(axes)
+    result_sparse = sparse_ops.bcoo_todense(data_out, indices_out, shape=shape_out)
+    tol = {np.float32: 1E-6, np.float64: 1E-14}
+    self.assertAllClose(result_dense, result_sparse, atol=tol, rtol=tol)
+
 
 class SparseObjectTest(jtu.JaxTestCase):
   @parameterized.named_parameters(
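The parameterized test above also covers reductions over batch dimensions, the subtlest branch of bcoo_reduce_sum: the summed batch axis is folded into nse, and the implicit summation of duplicates in bcoo_todense finishes the job. A standalone sketch of that case, with the array values and module path as illustrative assumptions:

import numpy as np
from jax.experimental import sparse_ops  # assumed module path, as used in the test

M = np.arange(1.0, 25.0).reshape(2, 3, 4)
data, indices = sparse_ops.bcoo_fromdense(M, n_batch=1)  # axis 0 stored as a batch dimension

# Reducing over the batch axis reshapes it into the nse dimension, so the result is
# an unbatched BCOO array whose duplicate entries add up to M.sum(0).
data_out, indices_out, shape_out = sparse_ops.bcoo_reduce_sum(
    data, indices, shape=M.shape, axes=(0,))
assert shape_out == (3, 4)
np.testing.assert_allclose(
    sparse_ops.bcoo_todense(data_out, indices_out, shape=shape_out), M.sum(0))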