mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #13175 from jakevdp:bcoo-transpose
PiperOrigin-RevId: 487333413
This commit is contained in:
commit
fa0217bdd3
@ -46,6 +46,7 @@ from jax._src.lib import xla_bridge
|
||||
from jax._src.lib import gpu_sparse
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
from jax._src.numpy.setops import _unique
|
||||
from jax._src.typing import Array
|
||||
from jax._src.util import canonicalize_axis
|
||||
|
||||
|
||||
@ -459,7 +460,7 @@ mlir.register_lowering(bcoo_extract_p, mlir.lower_fun(
|
||||
bcoo_transpose_p = core.Primitive('bcoo_transpose')
|
||||
bcoo_transpose_p.multiple_results = True
|
||||
|
||||
def bcoo_transpose(mat, *, permutation: Sequence[int]):
|
||||
def bcoo_transpose(mat: BCOO, *, permutation: Sequence[int]) -> BCOO:
|
||||
"""Transpose a BCOO-format array.
|
||||
|
||||
Args:
|
||||
@ -474,10 +475,12 @@ def bcoo_transpose(mat, *, permutation: Sequence[int]):
|
||||
Returns:
|
||||
A BCOO-format array.
|
||||
"""
|
||||
return BCOO(_bcoo_transpose(mat.data, mat.indices, permutation=permutation, spinfo=mat._info),
|
||||
shape=mat._info.shape, unique_indices=mat.unique_indices)
|
||||
buffers = _bcoo_transpose(mat.data, mat.indices, permutation=permutation, spinfo=mat._info)
|
||||
out_shape = tuple(mat.shape[p] for p in permutation)
|
||||
return BCOO(buffers, shape=out_shape, unique_indices=mat.unique_indices)
|
||||
|
||||
def _bcoo_transpose(data, indices, *, permutation: Sequence[int], spinfo: BCOOInfo):
|
||||
def _bcoo_transpose(data: Array, indices: Array, *,
|
||||
permutation: Sequence[int], spinfo: BCOOInfo) -> Tuple[Array, Array]:
|
||||
permutation = tuple(permutation)
|
||||
if permutation == tuple(range(len(spinfo.shape))):
|
||||
return data, indices
|
||||
|
@ -920,9 +920,7 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
rng = self.rng()
|
||||
sprng = rand_sparse(rng)
|
||||
M = sprng(shape, dtype)
|
||||
nse = sparse.util._count_stored_elements(M, n_batch=n_batch,
|
||||
n_dense=n_dense)
|
||||
data, indices = sparse_bcoo._bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense)
|
||||
M_bcoo = sparse.BCOO.fromdense(M)
|
||||
|
||||
permutation = np.concatenate([
|
||||
rng.permutation(range(n_batch)),
|
||||
@ -930,9 +928,10 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
rng.permutation(range(n_batch + n_sparse, len(shape)))]).astype(int)
|
||||
|
||||
M_T = M.transpose(permutation)
|
||||
trans = partial(sparse_bcoo._bcoo_transpose, spinfo=BCOOInfo(shape), permutation=permutation)
|
||||
self.assertArraysEqual(M_T, sparse_bcoo._bcoo_todense(*trans(data, indices), spinfo=BCOOInfo(M_T.shape)))
|
||||
self.assertArraysEqual(M_T, sparse_bcoo._bcoo_todense(*jit(trans)(data, indices), spinfo=BCOOInfo(M_T.shape)))
|
||||
M_T_bcoo = sparse.bcoo_transpose(M_bcoo, permutation=permutation)
|
||||
M_T_bcoo_jit = jit(partial(sparse.bcoo_transpose, permutation=permutation))(M_bcoo)
|
||||
self.assertArraysEqual(M_T, M_T_bcoo.todense())
|
||||
self.assertArraysEqual(M_T, M_T_bcoo_jit.todense())
|
||||
|
||||
# test batched
|
||||
def trans(M):
|
||||
|
Loading…
x
Reference in New Issue
Block a user