Merge pull request #9964 from jakevdp:bcoo-mul

PiperOrigin-RevId: 435905671
This commit is contained in:
jax authors 2022-03-19 12:16:15 -07:00
commit 1ffa285bd6
2 changed files with 26 additions and 16 deletions

View File

@ -1209,13 +1209,14 @@ def bcoo_multiply_sparse(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_sp
# Similar requirement as lax.mul:
raise TypeError("bcoo_multiply_sparse: arrays must have same number of dimensions, "
f"got {lhs_shape}, {rhs_shape}")
if (lhs.n_batch, lhs.n_sparse, lhs.n_dense) != (rhs.n_batch, rhs.n_sparse, rhs.n_dense):
if lhs.n_dense != rhs.n_dense:
raise NotImplementedError("bcoo_multiply_sparse: arrays with differing numbers of "
f"batch & dense dimensions: {lhs}, {rhs}")
f"dense dimensions: {lhs}, {rhs}")
n_batch = min(lhs.n_batch, rhs.n_batch)
_mul = functools.partial(_bcoo_multiply_sparse_unbatched,
lhs_shape=lhs_shape[lhs.n_batch:],
rhs_shape=rhs_shape[rhs.n_batch:])
for _ in range(lhs.n_batch):
lhs_shape=lhs_shape[n_batch:],
rhs_shape=rhs_shape[n_batch:])
for _ in range(n_batch):
_mul = broadcasting_vmap(_mul)
data, indices = _mul(lhs_data, lhs_indices, rhs_data, rhs_indices)
return data, indices, jnp.broadcast_shapes(lhs_shape, rhs_shape)
@ -1223,7 +1224,15 @@ def bcoo_multiply_sparse(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_sp
def _bcoo_multiply_sparse_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_shape, rhs_shape):
lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)
assert lhs.n_batch == rhs.n_batch == 0
assert (lhs.n_batch == 0) or (rhs.n_batch == 0) # Ensured at call site above
# TODO(jakevdp): this can be made more efficient by utilizing batch structure.
if lhs.n_batch:
lhs_data, lhs_indices = _unbatch_bcoo(lhs_data, lhs_indices, lhs_shape)
lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
elif rhs.n_batch:
rhs_data, rhs_indices = _unbatch_bcoo(rhs_data, rhs_indices, rhs_shape)
rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)
dims = jnp.array([i for i, (s1, s2) in enumerate(safe_zip(lhs_shape[:lhs.n_sparse], rhs_shape[:rhs.n_sparse]))
if s1 != 1 and s2 != 1], dtype=int)

View File

@ -1659,13 +1659,12 @@ class BCOOTest(jtu.JaxTestCase):
self.assertAllClose(out1, out2, rtol=tol)
self.assertAllClose(out1, out3, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_{}_n_batch={}_n_dense={}".format(
jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
n_batch, n_dense),
{"testcase_name": "_{}_n_batch={}_{}_n_batch={}_n_dense={}".format(
jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), lhs_n_batch,
jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), rhs_n_batch, n_dense),
"lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
"rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
"n_batch": n_batch, "n_dense": n_dense,
"lhs_n_batch": lhs_n_batch, "rhs_n_batch": rhs_n_batch, "n_dense": n_dense,
}
# TODO(jakevdp): add broadcasted shapes (from bcoo_mul_dense) once sparse-sparse mul
# supports inputs of differing rank.
@ -1673,19 +1672,21 @@ class BCOOTest(jtu.JaxTestCase):
[(3, 4), (1, 1)], [(3, 4), (1, 4)], [(3, 4), (3, 1)], [(3, 4), (3, 4)],
[(3, 4, 5), (1, 4, 5)], [(3, 4, 5), (3, 1, 1)], [(3, 4, 5), (1, 4, 1)]]
# TODO(jakevdp): add tests for batch & dense dimensions.
for n_batch in range(len(lhs_shape) + 1)
for n_dense in range(len(lhs_shape) + 1 - n_batch)
for lhs_n_batch in range(len(lhs_shape) + 1)
for rhs_n_batch in range(len(lhs_shape) + 1)
for n_dense in range(len(lhs_shape) + 1 - max(lhs_n_batch, rhs_n_batch))
for lhs_dtype in all_dtypes
for rhs_dtype in all_dtypes))
def test_bcoo_mul_sparse(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, n_batch, n_dense):
def test_bcoo_mul_sparse(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, lhs_n_batch, rhs_n_batch, n_dense):
rng = rand_sparse(self.rng())
lhs = jnp.array(rng(lhs_shape, lhs_dtype))
rhs = jnp.array(rng(rhs_shape, rhs_dtype))
sp = lambda x: sparse.BCOO.fromdense(x, n_batch=n_batch, n_dense=n_dense)
lhs_sp = sparse.BCOO.fromdense(lhs, n_batch=lhs_n_batch, n_dense=n_dense)
rhs_sp = sparse.BCOO.fromdense(rhs, n_batch=rhs_n_batch, n_dense=n_dense)
out1 = lhs * rhs
out2 = (sp(lhs) * sp(rhs)).todense()
out2 = (lhs_sp * rhs_sp).todense()
tol = {np.float64: 1E-13, np.complex128: 1E-13,
np.float32: 1E-6, np.complex64: 1E-6}