mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sparse] bcoo_dot_general_sampled: faster special case
This commit is contained in:
parent
7e001d842e
commit
54bd631c1a
@ -1086,11 +1086,43 @@ def bcoo_dot_general_sampled(A: Array, B: Array, indices: Array, *, dimension_nu
|
||||
return bcoo_dot_general_sampled_p.bind(A, B, indices,
|
||||
dimension_numbers=(cdims, bdims))
|
||||
|
||||
def _bcoo_dot_general_sampled_slow(A, B, indices, *, dimension_numbers):
  """General fallback: form the full dense product, then sample it.

  Computes ``lax.dot_general(A, B)`` with the given ``dimension_numbers`` and
  passes the dense result to ``_bcoo_extract`` together with ``indices``.
  """
  dense_product = lax.dot_general(A, B, dimension_numbers=dimension_numbers)
  return _bcoo_extract(indices, dense_product)
|
||||
|
||||
def _bcoo_dot_general_sampled_simple(A, B, indices, *, dimension_numbers):
  """Sampled outer product of two 1-D operands without a dense intermediate.

  Only valid when ``dimension_numbers`` has no contracting or batch dimensions,
  ``A`` and ``B`` are both 1-D, and the batch and sparse dimension counts
  implied by ``indices`` sum to 2. These preconditions are asserted.

  The trailing axis of ``indices`` holds the sparse coordinates, the
  second-to-last axis enumerates the stored entries, and any leading axes are
  batch dimensions.

  Raises:
    ValueError: if none of the supported batch counts (0, 1, 2) applies.
  """
  (contract_lhs, contract_rhs), (batch_lhs, batch_rhs) = dimension_numbers
  assert not contract_lhs and not contract_rhs and not batch_lhs and not batch_rhs
  assert A.ndim == B.ndim and B.ndim == 1

  num_batch = indices.ndim - 2
  num_sparse = indices.shape[-1]
  num_entries = indices.shape[-2]
  assert num_batch + num_sparse == 2

  if num_batch == 0:
    # Both coordinates are sparse: gather one element from each operand.
    rows = indices[:, 0]
    cols = indices[:, 1]
    return A[rows] * B[cols]

  if num_batch == 1:
    # The A axis is a batch dimension; only the B coordinate is gathered.
    cols = indices[..., 0]
    return A[:, None] * B[cols]

  if num_batch == 2:
    # Both axes are batch dimensions: the full outer product, with a trailing
    # axis broadcast to the number of stored entries.
    outer = A[:, None, None] * B[None, :, None]
    return lax.broadcast_in_dim(outer, (len(A), len(B), num_entries), (0, 1, 2))

  raise ValueError("too many batch dimensions.")
|
||||
|
||||
@bcoo_dot_general_sampled_p.def_impl
def _bcoo_dot_general_sampled_impl(A, B, indices, *, dimension_numbers):
  """Implementation rule registered for ``bcoo_dot_general_sampled_p``.

  Converts the operands with ``jnp.asarray`` and then picks a strategy:

  * ``_bcoo_dot_general_sampled_simple`` when ``dimension_numbers`` has no
    contracting or batch dimensions, ``A`` and ``B`` are both 1-D, and the
    batch and sparse dimension counts implied by ``indices`` sum to 2;
  * ``_bcoo_dot_general_sampled_slow`` (full dense product followed by
    extraction) in every other case.

  The unconditional dense computation and early ``return`` that previously
  preceded this dispatch have been removed; they made the fast path
  unreachable.
  """
  A = jnp.asarray(A)
  B = jnp.asarray(B)
  indices = jnp.asarray(indices)
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  n_batch = indices.ndim - 2
  n_sparse = indices.shape[-1]

  # TODO(jakevdp): add fast approach for more general cases.
  if (not (lhs_contract or rhs_contract or lhs_batch or rhs_batch)
      and A.ndim == B.ndim == 1 and n_sparse + n_batch == 2):
    return _bcoo_dot_general_sampled_simple(A, B, indices, dimension_numbers=dimension_numbers)

  return _bcoo_dot_general_sampled_slow(A, B, indices, dimension_numbers=dimension_numbers)
|
||||
|
||||
|
||||
@bcoo_dot_general_sampled_p.def_abstract_eval
|
||||
def _bcoo_dot_general_sampled_abstract_eval(A, B, indices, *, dimension_numbers):
|
||||
|
@ -1294,6 +1294,30 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
# TODO(jakevdp) fix forward-mode autodiff & enable tests here.
|
||||
self._CheckGradsSparse(dense_fun, sparse_fun, args_maker, modes=['rev'], argnums=[0, 1])
|
||||
|
||||
@jtu.sample_product(
  xshape=[(3,), (5,)],
  yshape=[(3,), (5,)],
  dtype=jtu.dtypes.floating + jtu.dtypes.complex,
  n_batch=[0, 1, 2],
)
def test_bcoo_dot_general_sampled_fast(self, xshape, yshape, n_batch, dtype):
  """Compare bcoo_dot_general_sampled with extraction from the dense product.

  Operands are 1-D and ``dimension_numbers`` has no contracting or batch
  dimensions. The sampling indices come from a random BCOO matrix of shape
  ``xshape + yshape`` generated with the requested ``n_batch``.
  """
  dense_rng = jtu.rand_default(self.rng())
  bcoo_rng = sptu.rand_bcoo(self.rng(), n_batch=n_batch)
  dimension_numbers = (([], []), ([], []))

  def args_maker():
    lhs = dense_rng(xshape, dtype)
    rhs = dense_rng(yshape, dtype)
    sample_indices = bcoo_rng(xshape + yshape, dtype).indices
    return [lhs, rhs, sample_indices]

  def extract_from_dense(x, y, indices):
    # Reference: materialize the whole product, then pick out the entries.
    full_product = lax.dot_general(x, y, dimension_numbers=dimension_numbers)
    return sparse_bcoo._bcoo_extract(indices, full_product)

  def sampled(x, y, indices):
    return sparse.bcoo_dot_general_sampled(x, y, indices, dimension_numbers=dimension_numbers)

  self._CheckAgainstNumpy(extract_from_dense, sampled, args_maker)
  self._CompileAndCheck(sampled, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(n_batch=n_batch, n_dense=n_dense, lhs_shape=lhs_shape,
|
||||
rhs_shape=rhs_shape, dimension_numbers=dimension_numbers)
|
||||
|
Loading…
x
Reference in New Issue
Block a user