mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sparse] add bcoo_add_batchdim
This commit is contained in:
parent
4012267a01
commit
93a24f3b83
@ -188,6 +188,7 @@ from jax.experimental.sparse.ad import (
|
||||
value_and_grad as value_and_grad,
|
||||
)
|
||||
from jax.experimental.sparse.bcoo import (
|
||||
bcoo_add_batch_dim as bcoo_add_batch_dim,
|
||||
bcoo_broadcast_in_dim as bcoo_broadcast_in_dim,
|
||||
bcoo_dot_general as bcoo_dot_general,
|
||||
bcoo_dot_general_p as bcoo_dot_general_p,
|
||||
|
@ -1228,6 +1228,45 @@ xla.register_translation(bcoo_spdot_general_p, xla.lower_fun(
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
# BCOO functions that maybe should be primitives?
|
||||
|
||||
|
||||
def bcoo_add_batch_dim(M):
|
||||
"""Convert a sparse dimension to a batch dimension
|
||||
|
||||
Please note that this function may result in a far less efficient storage scheme
|
||||
for the matrix (storage required will increase by a factor of `M.shape[0] * M.nse`).
|
||||
This utility is provided for convenience, e.g. to allow vmapping over non-batched
|
||||
matrices.
|
||||
|
||||
Args:
|
||||
M: BCOO matrix
|
||||
|
||||
Returns:
|
||||
M2: BCOO matrix with n_batch = M.n_batch + 1 and n_sparse = M.n_sparse - 1
|
||||
"""
|
||||
# TODO(jakevdp): allow user-specified nse?
|
||||
if M.n_sparse == 0:
|
||||
raise ValueError("Cannot add a batch dimension to a matrix with n_sparse=0")
|
||||
f = _add_batch_dim
|
||||
for _ in range(M.n_batch):
|
||||
f = vmap(f)
|
||||
return f(M)
|
||||
|
||||
def _add_batch_dim(M):
|
||||
assert M.n_batch == 0
|
||||
assert M.n_sparse > 0
|
||||
data = jnp.zeros_like(M.data, shape=(M.shape[0], *M.data.shape))
|
||||
data = data.at[M.indices[:, 0], jnp.arange(M.nse)].set(M.data)
|
||||
indices_shape = (M.shape[0], M.nse, M.n_sparse - 1)
|
||||
if M.n_sparse > 1:
|
||||
fill_value = jnp.array(M.shape[M.n_batch + 1: M.n_batch + M.n_sparse])
|
||||
indices = jnp.full_like(M.indices, shape=indices_shape, fill_value=fill_value)
|
||||
indices = indices.at[M.indices[:, 0], jnp.arange(M.nse)].set(M.indices[:, 1:])
|
||||
else:
|
||||
indices = jnp.empty_like(M.indices, shape=indices_shape)
|
||||
return BCOO((data, indices), shape=M.shape)
|
||||
|
||||
|
||||
def bcoo_broadcast_in_dim(mat, *, shape, broadcast_dimensions):
|
||||
"""Expand the size and rank of a BCOO array by duplicating the data.
|
||||
|
||||
|
@ -1772,6 +1772,24 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
self.assertEqual(M1.dtype, M2.dtype)
|
||||
self.assertArraysEqual(M1.todense(), M2.todense())
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
|
||||
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
|
||||
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
|
||||
for dtype in jtu.dtypes.floating + jtu.dtypes.complex
|
||||
for n_batch in range(len(shape))
|
||||
for n_dense in range(len(shape) - n_batch)))
|
||||
def test_bcoo_add_batch_dim(self, shape, dtype, n_batch, n_dense):
|
||||
rng_sparse = rand_sparse(self.rng())
|
||||
M1 = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense)
|
||||
M2 = sparse.bcoo_add_batch_dim(M1)
|
||||
self.assertEqual(M2.n_batch, M1.n_batch + 1)
|
||||
self.assertEqual(M1.n_dense, M2.n_dense)
|
||||
self.assertEqual(M1.shape, M2.shape)
|
||||
self.assertEqual(M1.dtype, M2.dtype)
|
||||
self.assertArraysEqual(M1.todense(), M2.todense())
|
||||
|
||||
def test_bcoo_bad_fillvals(self):
|
||||
# Extra values have 100 rather than zero. This lets us check that logic is
|
||||
# properly ignoring these indices.
|
||||
|
Loading…
x
Reference in New Issue
Block a user