mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #9964 from jakevdp:bcoo-mul
PiperOrigin-RevId: 435905671
This commit is contained in:
commit
1ffa285bd6
@ -1209,13 +1209,14 @@ def bcoo_multiply_sparse(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_sp
|
||||
# Similar requirement as lax.mul:
|
||||
raise TypeError("bcoo_multiply_sparse: arrays must have same number of dimensions, "
|
||||
f"got {lhs_shape}, {rhs_shape}")
|
||||
if (lhs.n_batch, lhs.n_sparse, lhs.n_dense) != (rhs.n_batch, rhs.n_sparse, rhs.n_dense):
|
||||
if lhs.n_dense != rhs.n_dense:
|
||||
raise NotImplementedError("bcoo_multiply_sparse: arrays with differing numbers of "
|
||||
f"batch & dense dimensions: {lhs}, {rhs}")
|
||||
f"dense dimensions: {lhs}, {rhs}")
|
||||
n_batch = min(lhs.n_batch, rhs.n_batch)
|
||||
_mul = functools.partial(_bcoo_multiply_sparse_unbatched,
|
||||
lhs_shape=lhs_shape[lhs.n_batch:],
|
||||
rhs_shape=rhs_shape[rhs.n_batch:])
|
||||
for _ in range(lhs.n_batch):
|
||||
lhs_shape=lhs_shape[n_batch:],
|
||||
rhs_shape=rhs_shape[n_batch:])
|
||||
for _ in range(n_batch):
|
||||
_mul = broadcasting_vmap(_mul)
|
||||
data, indices = _mul(lhs_data, lhs_indices, rhs_data, rhs_indices)
|
||||
return data, indices, jnp.broadcast_shapes(lhs_shape, rhs_shape)
|
||||
@ -1223,7 +1224,15 @@ def bcoo_multiply_sparse(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_sp
|
||||
def _bcoo_multiply_sparse_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_shape, rhs_shape):
|
||||
lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
|
||||
rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)
|
||||
assert lhs.n_batch == rhs.n_batch == 0
|
||||
assert (lhs.n_batch == 0) or (rhs.n_batch == 0) # Ensured at call site above
|
||||
|
||||
# TODO(jakevdp): this can be made more efficient by utilizing batch structure.
|
||||
if lhs.n_batch:
|
||||
lhs_data, lhs_indices = _unbatch_bcoo(lhs_data, lhs_indices, lhs_shape)
|
||||
lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
|
||||
elif rhs.n_batch:
|
||||
rhs_data, rhs_indices = _unbatch_bcoo(rhs_data, rhs_indices, rhs_shape)
|
||||
rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)
|
||||
dims = jnp.array([i for i, (s1, s2) in enumerate(safe_zip(lhs_shape[:lhs.n_sparse], rhs_shape[:rhs.n_sparse]))
|
||||
if s1 != 1 and s2 != 1], dtype=int)
|
||||
|
||||
|
@ -1659,13 +1659,12 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(out1, out2, rtol=tol)
|
||||
self.assertAllClose(out1, out3, rtol=tol)
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_{}_n_batch={}_n_dense={}".format(
|
||||
jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
|
||||
jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
|
||||
n_batch, n_dense),
|
||||
{"testcase_name": "_{}_n_batch={}_{}_n_batch={}_n_dense={}".format(
|
||||
jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), lhs_n_batch,
|
||||
jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), rhs_n_batch, n_dense),
|
||||
"lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
|
||||
"rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
|
||||
"n_batch": n_batch, "n_dense": n_dense,
|
||||
"lhs_n_batch": lhs_n_batch, "rhs_n_batch": rhs_n_batch, "n_dense": n_dense,
|
||||
}
|
||||
# TODO(jakevdp): add broadcasted shapes (from bcoo_mul_dense) once sparse-sparse mul
|
||||
# supports inputs of differing rank.
|
||||
@ -1673,19 +1672,21 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
[(3, 4), (1, 1)], [(3, 4), (1, 4)], [(3, 4), (3, 1)], [(3, 4), (3, 4)],
|
||||
[(3, 4, 5), (1, 4, 5)], [(3, 4, 5), (3, 1, 1)], [(3, 4, 5), (1, 4, 1)]]
|
||||
# TODO(jakevdp): add tests for batch & dense dimensions.
|
||||
for n_batch in range(len(lhs_shape) + 1)
|
||||
for n_dense in range(len(lhs_shape) + 1 - n_batch)
|
||||
for lhs_n_batch in range(len(lhs_shape) + 1)
|
||||
for rhs_n_batch in range(len(lhs_shape) + 1)
|
||||
for n_dense in range(len(lhs_shape) + 1 - max(lhs_n_batch, rhs_n_batch))
|
||||
for lhs_dtype in all_dtypes
|
||||
for rhs_dtype in all_dtypes))
|
||||
def test_bcoo_mul_sparse(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, n_batch, n_dense):
|
||||
def test_bcoo_mul_sparse(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, lhs_n_batch, rhs_n_batch, n_dense):
|
||||
rng = rand_sparse(self.rng())
|
||||
lhs = jnp.array(rng(lhs_shape, lhs_dtype))
|
||||
rhs = jnp.array(rng(rhs_shape, rhs_dtype))
|
||||
|
||||
sp = lambda x: sparse.BCOO.fromdense(x, n_batch=n_batch, n_dense=n_dense)
|
||||
lhs_sp = sparse.BCOO.fromdense(lhs, n_batch=lhs_n_batch, n_dense=n_dense)
|
||||
rhs_sp = sparse.BCOO.fromdense(rhs, n_batch=rhs_n_batch, n_dense=n_dense)
|
||||
|
||||
out1 = lhs * rhs
|
||||
out2 = (sp(lhs) * sp(rhs)).todense()
|
||||
out2 = (lhs_sp * rhs_sp).todense()
|
||||
|
||||
tol = {np.float64: 1E-13, np.complex128: 1E-13,
|
||||
np.float32: 1E-6, np.complex64: 1E-6}
|
||||
|
Loading…
x
Reference in New Issue
Block a user