Merge pull request #7070 from jakevdp:bcoo-transpose

PiperOrigin-RevId: 381087809
This commit is contained in:
jax authors 2021-06-23 12:15:34 -07:00
commit 35557794fb
3 changed files with 189 additions and 15 deletions

View File

@ -25,6 +25,8 @@ from .ops import (
bcoo_rdot_general,
bcoo_todense,
bcoo_todense_p,
bcoo_transpose,
bcoo_transpose_p,
coo_fromdense,
coo_fromdense_p,
coo_matmat,

View File

@ -32,7 +32,7 @@ Further down are some examples of potential high-level wrappers for sparse objec
import functools
import operator
from typing import Any, Tuple
from typing import Any, Sequence, Tuple
from jax import api
from jax import core
@ -603,11 +603,18 @@ def _validate_bcoo(data, indices, shape):
def _compatible(shape1, shape2):
return all(s1 in (1, s2) for s1, s2 in safe_zip(shape1, shape2))
assert _compatible(data.shape[:n_batch], shape[:n_batch])
assert data.shape[-(n_dense + 1):] == (nse,) + shape[n_batch + n_sparse:]
assert _compatible(indices.shape[:n_batch], shape[:n_batch])
assert indices.shape[n_batch:] == (n_sparse, nse)
if not _compatible(data.shape[:n_batch], shape[:n_batch]):
raise ValueError("data batch dimensions not compatible for "
f"data.shape={data.shape}, shape={shape}")
if data.shape[-(n_dense + 1):] != (nse,) + shape[n_batch + n_sparse:]:
raise ValueError(f"Invalid data.shape={data.shape} for "
f"nse={nse}, n_batch={n_batch}, n_dense={n_dense}")
if not _compatible(indices.shape[:n_batch], shape[:n_batch]):
raise ValueError("indices batch dimensions not compatible for "
f"indices.shape={indices.shape}, shape={shape}")
if indices.shape[n_batch:] != (n_sparse, nse):
raise ValueError(f"Invalid indices.shape={indices.shape} for "
f"nse={nse}, n_batch={n_batch}, n_dense={n_dense}")
return n_batch, n_sparse, n_dense
@ -833,6 +840,98 @@ batching.primitive_batchers[bcoo_extract_p] = _bcoo_extract_batching_rule
xla.translations[bcoo_extract_p] = xla.lower_fun(
_bcoo_extract_impl, multiple_results=False)
#----------------------------------------------------------------------
# bcoo_transpose
# transpose of a BCOO array
bcoo_transpose_p = core.Primitive('bcoo_transpose')
bcoo_transpose_p.multiple_results = True
def bcoo_transpose(data, indices, *, permutation, shape):
return bcoo_transpose_p.bind(data, indices, permutation=permutation, shape=shape)
def _validate_permutation(data, indices, permutation, shape):
if not isinstance(permutation, (tuple, list, np.ndarray)):
raise TypeError(f"transpose permutation must be a tuple/list/ndarray, got {type(permutation)}.")
if tuple(sorted(permutation)) != tuple(range(len(shape))):
raise TypeError("transpose permutation isn't a permutation of operand dimensions, "
f"got permutation {permutation} for shape {shape}.")
n_batch, n_sparse, n_dense = _validate_bcoo(data, indices, shape)
batch_perm = permutation[:n_batch]
sparse_perm = [p - n_batch for p in permutation[n_batch: n_batch + n_sparse]]
dense_perm = [p - n_sparse - n_batch for p in permutation[n_batch + n_sparse:]]
if n_batch and tuple(sorted(batch_perm)) != tuple(range(n_batch)):
raise NotImplementedError("transpose permutation cannot permute batch axes with non-batch axes; "
f"got permutation {permutation}, with n_batch={n_batch}.")
if n_dense and tuple(sorted(dense_perm)) != tuple(range(n_dense)):
raise NotImplementedError("transpose permutation cannot permute dense axes with non-dense axes; "
f"got permutation {permutation}, with n_dense={n_dense}.")
return batch_perm, sparse_perm, dense_perm
@bcoo_transpose_p.def_impl
def _bcoo_transpose_impl(data, indices, *, permutation: Sequence[int], shape: Tuple[int]):
batch_perm, sparse_perm, dense_perm = _validate_permutation(data, indices, permutation, shape)
n_batch = len(batch_perm)
indices = indices[..., sparse_perm, :].transpose(*batch_perm, n_batch, n_batch + 1)
data = data.transpose(*batch_perm, n_batch, *(d + n_batch + 1 for d in dense_perm))
return data, indices
@bcoo_transpose_p.def_abstract_eval
def _bcoo_transpose_abstract_eval(data, indices, *, permutation: Sequence[int], shape: Tuple[int]):
batch_perm, _, dense_perm = _validate_permutation(data, indices, permutation, shape)
n_batch = len(batch_perm)
indices_shape = np.array(indices.shape)[[*batch_perm, n_batch, n_batch + 1]]
data_shape = np.array(data.shape)[[*batch_perm, n_batch, *(d + n_batch + 1 for d in dense_perm)]]
return core.ShapedArray(data_shape, data.dtype), core.ShapedArray(indices_shape, indices.dtype)
def _bcoo_transpose_jvp(primals, tangents, *, permutation, shape):
data, indices = primals
data_dot, _ = tangents
primals_out = bcoo_transpose(data, indices, permutation=permutation, shape=shape)
data_dot_out, _ = bcoo_transpose(data_dot, indices, permutation=permutation, shape=shape)
return primals_out, (data_dot_out, ad.Zero.from_value(indices))
def _bcoo_transpose_transpose(ct, data, indices, *, permutation, shape):
data_ct, indices_ct = ct
assert isinstance(indices_ct, ad.Zero)
if ad.is_undefined_primal(indices):
raise ValueError("Cannot transpose with respect to sparse indices")
assert data_ct.dtype == data.aval.dtype
ct_shape = tuple(shape[p] for p in permutation)
rev_permutation = np.argsort(permutation)
# TODO(jakevdp) avoid dummy indices?
dummy_indices = jnp.zeros([1 for i in range(indices.ndim - 2)] + list(indices.shape[-2:]), dtype=int)
data_trans, _ = bcoo_transpose(data_ct, dummy_indices, permutation=rev_permutation, shape=ct_shape)
return data_trans, indices_ct
def _bcoo_transpose_batch_rule(batched_args, batch_dims, *, permutation, shape):
data, indices = batched_args
batch_dims = list(batch_dims)
batch_size = max(0 if dim is None else arg.shape[dim]
for arg, dim in zip(batched_args, batch_dims))
if batch_dims[0] is None:
data = data[None]
else:
assert batch_dims[0] == 0
if batch_dims[1] is None:
indices = indices[None]
else:
assert batch_dims[1] == 0
batched_shape = (batch_size, *shape)
batched_permutation = (0, *(p + 1 for p in permutation))
data, indices = bcoo_transpose(data, indices, permutation=batched_permutation, shape=batched_shape)
if batch_dims[0] is None:
data = data[0]
if batch_dims[1] is None:
indices = indices[0]
return (data, indices), batch_dims
ad.primitive_jvps[bcoo_transpose_p] = _bcoo_transpose_jvp
ad.primitive_transposes[bcoo_transpose_p] = _bcoo_transpose_transpose
batching.primitive_batchers[bcoo_transpose_p] = _bcoo_transpose_batch_rule
xla.translations[bcoo_transpose_p] = xla.lower_fun(
_bcoo_transpose_impl, multiple_results=True)
#----------------------------------------------------------------------
# bcoo_dot_general
# (batched) general dot product of a BCOO sparse ND array and a dense ND array,
@ -1080,7 +1179,7 @@ class JAXSparse:
def matmat(self, B):
raise NotImplementedError("matmat")
def transpose(self):
def transpose(self, axes=None):
raise NotImplementedError()
@property
@ -1130,7 +1229,8 @@ class CSR(JAXSparse):
def matmat(self, B):
return csr_matmat(self.data, self.indices, self.indptr, B, shape=self.shape)
def transpose(self):
def transpose(self, axes=None):
assert axes is None
return CSC((self.data, self.indices, self.indptr), shape=self.shape[::-1])
def tree_flatten(self):
@ -1168,7 +1268,8 @@ class CSC(JAXSparse):
def matmat(self, B):
return csr_matmat(self.data, self.indices, self.indptr, B, shape=self.shape[::-1], transpose=True)
def transpose(self):
def transpose(self, axes=None):
assert axes is None
return CSR((self.data, self.indices, self.indptr), shape=self.shape[::-1])
def tree_flatten(self):
@ -1206,7 +1307,8 @@ class COO(JAXSparse):
def matmat(self, B):
return coo_matmat(self.data, self.row, self.col, B, shape=self.shape)
def transpose(self):
def transpose(self, axes=None):
assert axes is None
return COO((self.data, self.col, self.row), shape=self.shape[::-1])
def tree_flatten(self):
@ -1271,10 +1373,11 @@ class BCOO(JAXSparse):
rhs_shape=self.shape,
dimension_numbers=(([other.ndim - 1], [0]), ([], [])))
def transpose(self):
if self.n_batch or self.n_dense:
raise NotImplementedError("BCOO transpose with batch or dense dimensions")
return BCOO((self.data, self.indices[::-1]), shape=self.shape[::-1])
def transpose(self, axes=None):
axes = np.arange(self.ndim)[::-1] if axes is None else axes
data_T, indices_T = bcoo_transpose(self.data, self.indices, shape=self.shape, permutation=axes)
shape_T = [self.shape[i] for i in axes]
return BCOO((data_T, indices_T), shape=shape_T)
def tree_flatten(self):
children = (self.data, self.indices)
@ -1291,7 +1394,9 @@ class BCOO(JAXSparse):
if _is_dummy(data, indices):
shape = sparse_shape
else:
assert len(sparse_shape) == indices.shape[-2]
if np.ndim(indices) < 2 or len(sparse_shape) != np.shape(indices)[-2]:
raise ValueError(f"Invalid sparse representation: got indices.shape={np.shape(indices)}, "
f"data.shape={np.shape(data)}, sparse_shape={sparse_shape}")
n_batch = indices.ndim - 2
shape = (
tuple(np.maximum(data.shape[:n_batch], indices.shape[:n_batch]))

View File

@ -501,6 +501,73 @@ class BCOOTest(jtu.JaxTestCase):
self.assertEqual(j1.shape, data.shape + M.shape)
self.assertEqual(hess.shape, data.shape + 2 * M.shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for dtype in jtu.dtypes.floating
for n_batch in range(len(shape) + 1)
for n_dense in range(len(shape) + 1 - n_batch)))
def test_bcoo_transpose(self, shape, dtype, n_batch, n_dense):
n_sparse = len(shape) - n_batch - n_dense
rng = self.rng()
sprng = rand_sparse(rng)
M = sprng(shape, dtype)
data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)
permutation = np.concatenate([
rng.permutation(range(n_batch)),
rng.permutation(range(n_batch, n_batch + n_sparse)),
rng.permutation(range(n_batch + n_sparse, len(shape)))]).astype(int)
M_T = M.transpose(permutation)
trans = partial(sparse.bcoo_transpose, shape=shape, permutation=permutation)
self.assertArraysEqual(M_T, sparse.bcoo_todense(*trans(data, indices), shape=M_T.shape))
self.assertArraysEqual(M_T, sparse.bcoo_todense(*jit(trans)(data, indices), shape=M_T.shape))
# test batched
def trans(M):
return M.transpose([p - n_batch for p in permutation[n_batch:]])
for _ in range(n_batch):
trans = jax.vmap(trans)
Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)
self.assertArraysEqual(trans(M), trans(Msp).todense())
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for dtype in jtu.dtypes.floating
for n_batch in range(len(shape) + 1)
for n_dense in range(len(shape) + 1 - n_batch)))
def test_bcoo_transpose_ad(self, shape, dtype, n_batch, n_dense):
n_sparse = len(shape) - n_batch - n_dense
rng = self.rng()
sprng = rand_sparse(self.rng())
M = sprng(shape, dtype)
data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)
permutation = np.concatenate([
rng.permutation(range(n_batch)),
rng.permutation(range(n_batch, n_batch + n_sparse)),
rng.permutation(range(n_batch + n_sparse, len(shape)))]).astype(int)
def f_sparse(data):
return sparse.bcoo_transpose(data, indices, shape=shape, permutation=permutation)[0]
jf_sparse = jax.jacfwd(f_sparse)(data)
jr_sparse = jax.jacrev(f_sparse)(data)
tol = {}
if jtu.device_under_test() == "tpu":
tol = {np.float32: 5E-3}
# TODO(jakevdp) also test against dense version?
self.assertAllClose(jf_sparse, jr_sparse, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),