mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #7070 from jakevdp:bcoo-transpose
PiperOrigin-RevId: 381087809
This commit is contained in:
commit
35557794fb
@ -25,6 +25,8 @@ from .ops import (
|
||||
bcoo_rdot_general,
|
||||
bcoo_todense,
|
||||
bcoo_todense_p,
|
||||
bcoo_transpose,
|
||||
bcoo_transpose_p,
|
||||
coo_fromdense,
|
||||
coo_fromdense_p,
|
||||
coo_matmat,
|
||||
|
@ -32,7 +32,7 @@ Further down are some examples of potential high-level wrappers for sparse objec
|
||||
import functools
|
||||
import operator
|
||||
|
||||
from typing import Any, Tuple
|
||||
from typing import Any, Sequence, Tuple
|
||||
|
||||
from jax import api
|
||||
from jax import core
|
||||
@ -603,11 +603,18 @@ def _validate_bcoo(data, indices, shape):
|
||||
def _compatible(shape1, shape2):
|
||||
return all(s1 in (1, s2) for s1, s2 in safe_zip(shape1, shape2))
|
||||
|
||||
assert _compatible(data.shape[:n_batch], shape[:n_batch])
|
||||
assert data.shape[-(n_dense + 1):] == (nse,) + shape[n_batch + n_sparse:]
|
||||
|
||||
assert _compatible(indices.shape[:n_batch], shape[:n_batch])
|
||||
assert indices.shape[n_batch:] == (n_sparse, nse)
|
||||
if not _compatible(data.shape[:n_batch], shape[:n_batch]):
|
||||
raise ValueError("data batch dimensions not compatible for "
|
||||
f"data.shape={data.shape}, shape={shape}")
|
||||
if data.shape[-(n_dense + 1):] != (nse,) + shape[n_batch + n_sparse:]:
|
||||
raise ValueError(f"Invalid data.shape={data.shape} for "
|
||||
f"nse={nse}, n_batch={n_batch}, n_dense={n_dense}")
|
||||
if not _compatible(indices.shape[:n_batch], shape[:n_batch]):
|
||||
raise ValueError("indices batch dimensions not compatible for "
|
||||
f"indices.shape={indices.shape}, shape={shape}")
|
||||
if indices.shape[n_batch:] != (n_sparse, nse):
|
||||
raise ValueError(f"Invalid indices.shape={indices.shape} for "
|
||||
f"nse={nse}, n_batch={n_batch}, n_dense={n_dense}")
|
||||
|
||||
return n_batch, n_sparse, n_dense
|
||||
|
||||
@ -833,6 +840,98 @@ batching.primitive_batchers[bcoo_extract_p] = _bcoo_extract_batching_rule
|
||||
xla.translations[bcoo_extract_p] = xla.lower_fun(
|
||||
_bcoo_extract_impl, multiple_results=False)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
# bcoo_transpose
|
||||
# transpose of a BCOO array
|
||||
|
||||
bcoo_transpose_p = core.Primitive('bcoo_transpose')
bcoo_transpose_p.multiple_results = True

def bcoo_transpose(data, indices, *, permutation, shape):
  """Transpose a BCOO-format sparse array.

  Args:
    data: data buffer of the BCOO array.
    indices: index buffer of the BCOO array.
    permutation: sequence of integers specifying the new axis order.
    shape: shape of the (dense form of the) array being transposed.

  Returns:
    A ``(data, indices)`` pair representing the transposed array.
  """
  return bcoo_transpose_p.bind(data, indices,
                               permutation=permutation, shape=shape)
|
||||
def _validate_permutation(data, indices, permutation, shape):
  """Check that ``permutation`` is valid for transposing a BCOO array.

  Returns the permutation split into its batch, sparse, and dense parts,
  with the sparse and dense entries re-based so each part permutes
  ``0..len(part)-1`` within its own dimension group.
  """
  if not isinstance(permutation, (tuple, list, np.ndarray)):
    raise TypeError(f"transpose permutation must be a tuple/list/ndarray, got {type(permutation)}.")
  if tuple(sorted(permutation)) != tuple(range(len(shape))):
    raise TypeError("transpose permutation isn't a permutation of operand dimensions, "
                    f"got permutation {permutation} for shape {shape}.")
  n_batch, n_sparse, n_dense = _validate_bcoo(data, indices, shape)
  # Split into the three dimension groups, shifting sparse/dense entries to
  # be relative to their own group.
  batch_perm = permutation[:n_batch]
  sparse_perm = [axis - n_batch for axis in permutation[n_batch: n_batch + n_sparse]]
  dense_perm = [axis - n_sparse - n_batch for axis in permutation[n_batch + n_sparse:]]
  # Axes may be permuted within a group, but mixing batch or dense axes with
  # axes from other groups is not implemented.
  if n_batch and tuple(sorted(batch_perm)) != tuple(range(n_batch)):
    raise NotImplementedError("transpose permutation cannot permute batch axes with non-batch axes; "
                              f"got permutation {permutation}, with n_batch={n_batch}.")
  if n_dense and tuple(sorted(dense_perm)) != tuple(range(n_dense)):
    raise NotImplementedError("transpose permutation cannot permute dense axes with non-dense axes; "
                              f"got permutation {permutation}, with n_dense={n_dense}.")
  return batch_perm, sparse_perm, dense_perm
|
||||
@bcoo_transpose_p.def_impl
def _bcoo_transpose_impl(data, indices, *, permutation: Sequence[int],
                         shape: Tuple[int, ...]):
  """Implementation rule for bcoo_transpose.

  Note: the annotation is ``Tuple[int, ...]`` (variable-length) rather than
  ``Tuple[int]``, which would denote a 1-tuple.
  """
  batch_perm, sparse_perm, dense_perm = _validate_permutation(data, indices, permutation, shape)
  n_batch = len(batch_perm)
  # Reorder the sparse index rows, then permute the batch axes of the index
  # buffer; its two trailing axes (n_sparse, nse) keep their relative order.
  indices = indices[..., sparse_perm, :].transpose(*batch_perm, n_batch, n_batch + 1)
  # Permute batch and dense axes of the data buffer; axis n_batch is nse.
  data = data.transpose(*batch_perm, n_batch, *(d + n_batch + 1 for d in dense_perm))
  return data, indices
|
||||
@bcoo_transpose_p.def_abstract_eval
def _bcoo_transpose_abstract_eval(data, indices, *, permutation: Sequence[int],
                                  shape: Tuple[int, ...]):
  """Abstract evaluation rule: compute output shapes without touching values.

  Note: the annotation is ``Tuple[int, ...]`` (variable-length) rather than
  ``Tuple[int]``, which would denote a 1-tuple.
  """
  batch_perm, _, dense_perm = _validate_permutation(data, indices, permutation, shape)
  n_batch = len(batch_perm)
  # Shapes are permuted exactly as the impl permutes the buffers.
  indices_shape = np.array(indices.shape)[[*batch_perm, n_batch, n_batch + 1]]
  data_shape = np.array(data.shape)[[*batch_perm, n_batch, *(d + n_batch + 1 for d in dense_perm)]]
  return core.ShapedArray(data_shape, data.dtype), core.ShapedArray(indices_shape, indices.dtype)
|
||||
def _bcoo_transpose_jvp(primals, tangents, *, permutation, shape):
  """JVP rule: transpose is linear in ``data``; ``indices`` carries no tangent."""
  data, indices = primals
  data_dot, _ = tangents
  kwds = dict(permutation=permutation, shape=shape)
  primals_out = bcoo_transpose(data, indices, **kwds)
  data_dot_out, _ = bcoo_transpose(data_dot, indices, **kwds)
  tangents_out = (data_dot_out, ad.Zero.from_value(indices))
  return primals_out, tangents_out
|
||||
def _bcoo_transpose_transpose(ct, data, indices, *, permutation, shape):
  """Transposition (VJP) rule: apply the inverse permutation to the cotangent."""
  data_ct, indices_ct = ct
  assert isinstance(indices_ct, ad.Zero)
  if ad.is_undefined_primal(indices):
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert data_ct.dtype == data.aval.dtype
  # The cotangent lives in the permuted space; undo the permutation.
  ct_shape = tuple(shape[axis] for axis in permutation)
  rev_permutation = np.argsort(permutation)
  # TODO(jakevdp) avoid dummy indices?
  dummy_shape = [1] * (indices.ndim - 2) + list(indices.shape[-2:])
  dummy_indices = jnp.zeros(dummy_shape, dtype=int)
  data_trans, _ = bcoo_transpose(data_ct, dummy_indices,
                                 permutation=rev_permutation, shape=ct_shape)
  return data_trans, indices_ct
|
||||
def _bcoo_transpose_batch_rule(batched_args, batch_dims, *, permutation, shape):
  """Batching rule: treat the vmapped axis as an extra leading batch axis."""
  data, indices = batched_args
  batch_dims = list(batch_dims)
  # Size of the mapped axis; 0 only if no operand is actually batched.
  batch_size = max(0 if dim is None else arg.shape[dim]
                   for arg, dim in zip(batched_args, batch_dims))
  # Ensure both operands carry the mapped axis at position 0, adding a
  # length-1 axis for any operand that is not batched.
  if batch_dims[0] is None:
    data = data[None]
  else:
    assert batch_dims[0] == 0
  if batch_dims[1] is None:
    indices = indices[None]
  else:
    assert batch_dims[1] == 0
  # Transpose with the mapped axis as a leading batch dimension that stays
  # in place (every original axis shifts up by one).
  batched_shape = (batch_size, *shape)
  batched_permutation = (0, *(p + 1 for p in permutation))
  data, indices = bcoo_transpose(data, indices, permutation=batched_permutation, shape=batched_shape)
  # Strip the axes added above so unbatched operands keep their rank.
  if batch_dims[0] is None:
    data = data[0]
  if batch_dims[1] is None:
    indices = indices[0]
  return (data, indices), batch_dims
|
||||
# Register autodiff, batching, and lowering rules for the bcoo_transpose
# primitive with their respective interpreters.
ad.primitive_jvps[bcoo_transpose_p] = _bcoo_transpose_jvp
ad.primitive_transposes[bcoo_transpose_p] = _bcoo_transpose_transpose
batching.primitive_batchers[bcoo_transpose_p] = _bcoo_transpose_batch_rule
xla.translations[bcoo_transpose_p] = xla.lower_fun(
    _bcoo_transpose_impl, multiple_results=True)
||||
|
||||
#----------------------------------------------------------------------
|
||||
# bcoo_dot_general
|
||||
# (batched) general dot product of a BCOO sparse ND array and a dense ND array,
|
||||
@ -1080,7 +1179,7 @@ class JAXSparse:
|
||||
  def matmat(self, B):
    # Abstract sparse @ dense-matrix product; concrete subclasses override.
    raise NotImplementedError("matmat")
||||
  def transpose(self, axes=None):
    # Abstract transpose; concrete subclasses override.
    raise NotImplementedError()
|
||||
@property
|
||||
@ -1130,7 +1229,8 @@ class CSR(JAXSparse):
|
||||
  def matmat(self, B):
    # Sparse-dense product self @ B via the csr_matmat primitive.
    return csr_matmat(self.data, self.indices, self.indptr, B, shape=self.shape)
|
||||
def transpose(self):
|
||||
def transpose(self, axes=None):
|
||||
assert axes is None
|
||||
return CSC((self.data, self.indices, self.indptr), shape=self.shape[::-1])
|
||||
|
||||
def tree_flatten(self):
|
||||
@ -1168,7 +1268,8 @@ class CSC(JAXSparse):
|
||||
  def matmat(self, B):
    # CSC matmul reuses the CSR kernel on the transposed representation.
    return csr_matmat(self.data, self.indices, self.indptr, B, shape=self.shape[::-1], transpose=True)
|
||||
def transpose(self):
|
||||
def transpose(self, axes=None):
|
||||
assert axes is None
|
||||
return CSR((self.data, self.indices, self.indptr), shape=self.shape[::-1])
|
||||
|
||||
def tree_flatten(self):
|
||||
@ -1206,7 +1307,8 @@ class COO(JAXSparse):
|
||||
  def matmat(self, B):
    # Sparse-dense product self @ B via the coo_matmat primitive.
    return coo_matmat(self.data, self.row, self.col, B, shape=self.shape)
|
||||
def transpose(self):
|
||||
def transpose(self, axes=None):
|
||||
assert axes is None
|
||||
return COO((self.data, self.col, self.row), shape=self.shape[::-1])
|
||||
|
||||
def tree_flatten(self):
|
||||
@ -1271,10 +1373,11 @@ class BCOO(JAXSparse):
|
||||
rhs_shape=self.shape,
|
||||
dimension_numbers=(([other.ndim - 1], [0]), ([], [])))
|
||||
|
||||
def transpose(self):
|
||||
if self.n_batch or self.n_dense:
|
||||
raise NotImplementedError("BCOO transpose with batch or dense dimensions")
|
||||
return BCOO((self.data, self.indices[::-1]), shape=self.shape[::-1])
|
||||
def transpose(self, axes=None):
|
||||
axes = np.arange(self.ndim)[::-1] if axes is None else axes
|
||||
data_T, indices_T = bcoo_transpose(self.data, self.indices, shape=self.shape, permutation=axes)
|
||||
shape_T = [self.shape[i] for i in axes]
|
||||
return BCOO((data_T, indices_T), shape=shape_T)
|
||||
|
||||
def tree_flatten(self):
|
||||
children = (self.data, self.indices)
|
||||
@ -1291,7 +1394,9 @@ class BCOO(JAXSparse):
|
||||
if _is_dummy(data, indices):
|
||||
shape = sparse_shape
|
||||
else:
|
||||
assert len(sparse_shape) == indices.shape[-2]
|
||||
if np.ndim(indices) < 2 or len(sparse_shape) != np.shape(indices)[-2]:
|
||||
raise ValueError(f"Invalid sparse representation: got indices.shape={np.shape(indices)}, "
|
||||
f"data.shape={np.shape(data)}, sparse_shape={sparse_shape}")
|
||||
n_batch = indices.ndim - 2
|
||||
shape = (
|
||||
tuple(np.maximum(data.shape[:n_batch], indices.shape[:n_batch]))
|
||||
|
@ -501,6 +501,73 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
self.assertEqual(j1.shape, data.shape + M.shape)
|
||||
self.assertEqual(hess.shape, data.shape + 2 * M.shape)
|
||||
|
||||
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
          jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_transpose(self, shape, dtype, n_batch, n_dense):
    # Check bcoo_transpose against dense transpose, eagerly, under jit, and
    # under vmap over the batch dimensions.
    n_sparse = len(shape) - n_batch - n_dense
    rng = self.rng()
    sprng = rand_sparse(rng)
    M = sprng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    # Build a permutation that shuffles axes only within each of the
    # batch/sparse/dense groups: cross-group permutation is unsupported.
    permutation = np.concatenate([
      rng.permutation(range(n_batch)),
      rng.permutation(range(n_batch, n_batch + n_sparse)),
      rng.permutation(range(n_batch + n_sparse, len(shape)))]).astype(int)

    M_T = M.transpose(permutation)
    trans = partial(sparse.bcoo_transpose, shape=shape, permutation=permutation)
    self.assertArraysEqual(M_T, sparse.bcoo_todense(*trans(data, indices), shape=M_T.shape))
    self.assertArraysEqual(M_T, sparse.bcoo_todense(*jit(trans)(data, indices), shape=M_T.shape))

    # test batched
    def trans(M):
      # Permutation relative to the non-batch dimensions only.
      return M.transpose([p - n_batch for p in permutation[n_batch:]])
    for _ in range(n_batch):
      trans = jax.vmap(trans)
    Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)
    self.assertArraysEqual(trans(M), trans(Msp).todense())
||||
|
||||
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
          jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_transpose_ad(self, shape, dtype, n_batch, n_dense):
    # Check forward- and reverse-mode autodiff of bcoo_transpose agree.
    n_sparse = len(shape) - n_batch - n_dense
    rng = self.rng()
    sprng = rand_sparse(self.rng())

    M = sprng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    # Permute axes only within each dimension group (cross-group permutation
    # is unsupported by bcoo_transpose).
    permutation = np.concatenate([
      rng.permutation(range(n_batch)),
      rng.permutation(range(n_batch, n_batch + n_sparse)),
      rng.permutation(range(n_batch + n_sparse, len(shape)))]).astype(int)

    def f_sparse(data):
      # Differentiate only through the data buffer; indices are constant.
      return sparse.bcoo_transpose(data, indices, shape=shape, permutation=permutation)[0]

    jf_sparse = jax.jacfwd(f_sparse)(data)
    jr_sparse = jax.jacrev(f_sparse)(data)

    tol = {}
    if jtu.device_under_test() == "tpu":
      tol = {np.float32: 5E-3}

    # TODO(jakevdp) also test against dense version?
    self.assertAllClose(jf_sparse, jr_sparse, rtol=tol)
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
|
||||
|
Loading…
x
Reference in New Issue
Block a user