Merge pull request #13175 from jakevdp:bcoo-transpose

PiperOrigin-RevId: 487333413
This commit is contained in:
jax authors 2022-11-09 13:34:55 -08:00
commit fa0217bdd3
2 changed files with 12 additions and 10 deletions

View File

@ -46,6 +46,7 @@ from jax._src.lib import xla_bridge
from jax._src.lib import gpu_sparse
from jax._src.lib.mlir.dialects import mhlo
from jax._src.numpy.setops import _unique
from jax._src.typing import Array
from jax._src.util import canonicalize_axis
@ -459,7 +460,7 @@ mlir.register_lowering(bcoo_extract_p, mlir.lower_fun(
bcoo_transpose_p = core.Primitive('bcoo_transpose')
bcoo_transpose_p.multiple_results = True
def bcoo_transpose(mat, *, permutation: Sequence[int]):
def bcoo_transpose(mat: BCOO, *, permutation: Sequence[int]) -> BCOO:
"""Transpose a BCOO-format array.
Args:
@ -474,10 +475,12 @@ def bcoo_transpose(mat, *, permutation: Sequence[int]):
Returns:
A BCOO-format array.
"""
return BCOO(_bcoo_transpose(mat.data, mat.indices, permutation=permutation, spinfo=mat._info),
shape=mat._info.shape, unique_indices=mat.unique_indices)
buffers = _bcoo_transpose(mat.data, mat.indices, permutation=permutation, spinfo=mat._info)
out_shape = tuple(mat.shape[p] for p in permutation)
return BCOO(buffers, shape=out_shape, unique_indices=mat.unique_indices)
def _bcoo_transpose(data, indices, *, permutation: Sequence[int], spinfo: BCOOInfo):
def _bcoo_transpose(data: Array, indices: Array, *,
permutation: Sequence[int], spinfo: BCOOInfo) -> Tuple[Array, Array]:
permutation = tuple(permutation)
if permutation == tuple(range(len(spinfo.shape))):
return data, indices

View File

@ -920,9 +920,7 @@ class BCOOTest(jtu.JaxTestCase):
rng = self.rng()
sprng = rand_sparse(rng)
M = sprng(shape, dtype)
nse = sparse.util._count_stored_elements(M, n_batch=n_batch,
n_dense=n_dense)
data, indices = sparse_bcoo._bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense)
M_bcoo = sparse.BCOO.fromdense(M)
permutation = np.concatenate([
rng.permutation(range(n_batch)),
@ -930,9 +928,10 @@ class BCOOTest(jtu.JaxTestCase):
rng.permutation(range(n_batch + n_sparse, len(shape)))]).astype(int)
M_T = M.transpose(permutation)
trans = partial(sparse_bcoo._bcoo_transpose, spinfo=BCOOInfo(shape), permutation=permutation)
self.assertArraysEqual(M_T, sparse_bcoo._bcoo_todense(*trans(data, indices), spinfo=BCOOInfo(M_T.shape)))
self.assertArraysEqual(M_T, sparse_bcoo._bcoo_todense(*jit(trans)(data, indices), spinfo=BCOOInfo(M_T.shape)))
M_T_bcoo = sparse.bcoo_transpose(M_bcoo, permutation=permutation)
M_T_bcoo_jit = jit(partial(sparse.bcoo_transpose, permutation=permutation))(M_bcoo)
self.assertArraysEqual(M_T, M_T_bcoo.todense())
self.assertArraysEqual(M_T, M_T_bcoo_jit.todense())
# test batched
def trans(M):