mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sparse] spdot_general: implement many more cases
This commit is contained in:
parent
d42255486b
commit
848675df45
@ -47,12 +47,13 @@ def broadcasting_vmap(fun, in_axes=0, out_axes=0):
|
||||
def batched_fun(*args):
|
||||
args_flat, in_tree = tree_util.tree_flatten(args)
|
||||
in_axes_flat = flatten_axes("vmap in_axes", in_tree, in_axes, kws=False)
|
||||
size = max(arg.shape[i] for arg, i in safe_zip(args_flat, in_axes_flat))
|
||||
size = max(arg.shape[i] for arg, i in safe_zip(args_flat, in_axes_flat) if i is not None)
|
||||
if size > 1:
|
||||
if any(arg.shape[i] not in (1, size) for arg, i in safe_zip(args_flat, in_axes_flat)):
|
||||
if any(i is not None and arg.shape[i] not in (1, size)
|
||||
for arg, i in safe_zip(args_flat, in_axes_flat)):
|
||||
raise ValueError("broadcasting_vmap: mismatched input shapes")
|
||||
args_flat, in_axes_flat = zip(*(
|
||||
(lax.squeeze(arg, (i,)), None) if arg.shape[i] == 1 else (arg, i)
|
||||
(arg, None) if i is None else (lax.squeeze(arg, (i,)), None) if arg.shape[i] == 1 else (arg, i)
|
||||
for arg, i in zip(args_flat, in_axes_flat)
|
||||
))
|
||||
new_args = tree_util.tree_unflatten(in_tree, args_flat)
|
||||
@ -728,65 +729,75 @@ def bcoo_spdot_general(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_shap
|
||||
return bcoo_spdot_general_p.bind(lhs_data, lhs_indices, rhs_data, rhs_indices,
|
||||
lhs_shape=lhs_shape, rhs_shape=rhs_shape, dimension_numbers=dimension_numbers)
|
||||
|
||||
def _bcoo_Mv(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_shape, rhs_shape, dtype, lhs_contract):
|
||||
"""Helper function to compute the dot product of a sparse array and a sparse vector."""
|
||||
def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_shape, rhs_shape, lhs_contracting, rhs_contracting):
|
||||
lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
|
||||
rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)
|
||||
# Inputs should be unbatched; batching is handled by vmapping at the call site.
|
||||
|
||||
assert lhs.n_batch == rhs.n_batch == 0
|
||||
assert lhs.n_dense == rhs.n_dense == 0
|
||||
assert lhs.n_sparse >= 1
|
||||
assert rhs.n_sparse == 1
|
||||
assert (lhs_shape[lhs_contract],) == rhs_shape
|
||||
rhs_data, rhs_indices = _bcoo_sum_duplicates(rhs_data, rhs_indices, rhs_shape, nse=rhs.nse)
|
||||
lhs_i = lhs_indices[:, lhs_contract]
|
||||
rhs_i = rhs_indices[:, 0]
|
||||
mask = jnp.isin(lhs_i, rhs_i, assume_unique=True)
|
||||
lhs_i_inv = (lhs_i[None, :] == rhs_i[:, None]).argmax(0)
|
||||
lhs_i_inv = jnp.where(lhs_i < rhs_shape[0], lhs_i_inv, rhs_shape[0])
|
||||
rhs_data_at_lhs_indices = jnp.where(mask, rhs_data.at[lhs_i_inv].get(mode='fill', fill_value=0), 0)
|
||||
out_data = lhs_data.at[jnp.arange(lhs.nse)].mul(rhs_data_at_lhs_indices)
|
||||
out_indices = jnp.concatenate([lhs_indices[:, :lhs_contract], lhs_indices[:, lhs_contract + 1:]], axis=1)
|
||||
return out_data, out_indices
|
||||
assert [lhs_shape[d] for d in lhs_contracting] == [rhs_shape[d] for d in rhs_contracting]
|
||||
assert max(lhs_contracting, default=-1) < lhs.n_sparse
|
||||
assert max(rhs_contracting, default=-1) < rhs.n_sparse
|
||||
|
||||
out_shape = (
|
||||
[s for i, s in enumerate(lhs_shape) if i not in lhs_contracting] +
|
||||
[s for i, s in enumerate(rhs_shape) if i not in rhs_contracting])
|
||||
|
||||
lhs_i = lhs_indices[:, jnp.array(lhs_contracting, dtype=int)]
|
||||
rhs_i = rhs_indices[:, jnp.array(rhs_contracting, dtype=int)]
|
||||
lhs_j = lhs_indices[:, jnp.array(remaining(range(lhs.n_sparse), lhs_contracting), dtype=int)]
|
||||
rhs_j = rhs_indices[:, jnp.array(remaining(range(rhs.n_sparse), rhs_contracting), dtype=int)]
|
||||
|
||||
# TODO(jakevdp): can we do this more efficiently than using an outer product? Note that
|
||||
# jnp.isin() currently doesn't help much, because it also does all() over an outer
|
||||
# comparison.
|
||||
overlap = (lhs_i[:, None] == rhs_i[None, :]).all(-1)
|
||||
lhs_valid = (lhs_i < jnp.array([lhs_shape[d] for d in lhs_contracting])).all(-1)
|
||||
rhs_valid = (rhs_i < jnp.array([rhs_shape[d] for d in rhs_contracting])).all(-1)
|
||||
out_data = jnp.where(overlap & lhs_valid[:, None] & rhs_valid,
|
||||
lhs_data[:, None] * rhs_data[None, :], 0).ravel()
|
||||
|
||||
out_indices = jnp.empty([lhs.nse, rhs.nse, lhs_j.shape[-1] + rhs_j.shape[-1]],
|
||||
dtype=jnp.result_type(lhs_indices, rhs_indices))
|
||||
out_indices = out_indices.at[:, :, :lhs_j.shape[-1]].set(lhs_j[:, None])
|
||||
out_indices = out_indices.at[:, :, lhs_j.shape[-1]:].set(rhs_j[None, :])
|
||||
out_indices = out_indices.reshape(len(out_data), out_indices.shape[-1])
|
||||
out_nse = (lhs.nse if lhs_j.shape[1] else 1) * (rhs.nse if rhs_j.shape[1] else 1)
|
||||
return _bcoo_sum_duplicates(out_data, out_indices, out_shape, nse=out_nse)
|
||||
|
||||
@bcoo_spdot_general_p.def_impl
|
||||
def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_shape, rhs_shape, dimension_numbers):
|
||||
out_shape = _dot_general_validated_shape(lhs_shape, rhs_shape, dimension_numbers)
|
||||
lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
|
||||
rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)
|
||||
assert lhs.n_dense == rhs.n_dense == 0
|
||||
data_aval, indices_aval = _bcoo_spdot_general_abstract_eval(
|
||||
lhs_data.aval, lhs_indices.aval, rhs_data.aval, rhs_indices.aval,
|
||||
lhs_shape=lhs_shape, rhs_shape=rhs_shape, dimension_numbers=dimension_numbers)
|
||||
out_shape = _dot_general_validated_shape(lhs_shape, rhs_shape, dimension_numbers)
|
||||
_validate_bcoo(data_aval, indices_aval, out_shape)
|
||||
lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
|
||||
rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)
|
||||
|
||||
# Move batch dimension to front
|
||||
(lhs_contracting, _), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
|
||||
lhs_perm = tuple(lhs_batch) + tuple(i for i in range(lhs.n_batch) if i not in lhs_batch)
|
||||
rhs_perm = tuple(rhs_batch) + tuple(i for i in range(rhs.n_batch) if i not in rhs_batch)
|
||||
lhs_indices = lhs_indices.transpose(lhs_perm + (lhs.n_batch, lhs.n_batch + 1))
|
||||
rhs_indices = rhs_indices.transpose(rhs_perm + (rhs.n_batch, rhs.n_batch + 1))
|
||||
lhs_data = lhs_data.transpose(lhs_perm + (lhs.n_batch,))
|
||||
rhs_data = rhs_data.transpose(rhs_perm + (rhs.n_batch,))
|
||||
# Move batch dimensions to front of each array.
|
||||
lhs_batch_perm = [*lhs_batch, *remaining(range(lhs.n_batch), lhs_batch)]
|
||||
rhs_batch_perm = [*rhs_batch, *remaining(range(rhs.n_batch), rhs_batch)]
|
||||
lhs_data = lhs_data.transpose([*lhs_batch_perm, *range(lhs.n_batch, lhs_data.ndim)])
|
||||
rhs_data = rhs_data.transpose([*rhs_batch_perm, *range(rhs.n_batch, rhs_data.ndim)])
|
||||
lhs_indices = lhs_indices.transpose([*lhs_batch_perm, *range(lhs.n_batch, lhs_indices.ndim)])
|
||||
rhs_indices = rhs_indices.transpose([*rhs_batch_perm, *range(rhs.n_batch, rhs_indices.ndim)])
|
||||
|
||||
# Implement batched dot product via vmap
|
||||
func = functools.partial(_bcoo_Mv,
|
||||
func = functools.partial(_bcoo_spdot_general_unbatched,
|
||||
lhs_shape=lhs_shape[lhs.n_batch:], rhs_shape=rhs_shape[rhs.n_batch:],
|
||||
dtype=data_aval.dtype, lhs_contract=lhs_contracting[0] - lhs.n_batch)
|
||||
lhs_contracting=[d - lhs.n_batch for d in lhs_contracting],
|
||||
rhs_contracting=[d - rhs.n_batch for d in rhs_contracting])
|
||||
|
||||
if rhs_data.shape[:rhs.n_batch] != rhs_indices.shape[:rhs.n_batch]:
|
||||
raise NotImplementedError("unequal batches in rhs")
|
||||
if lhs_data.shape[:lhs.n_batch] != lhs_indices.shape[:lhs.n_batch]:
|
||||
raise NotImplementedError("unequal batches in lhs")
|
||||
|
||||
for dim in reversed(range(len(rhs_batch), rhs.n_batch)):
|
||||
func = vmap(func, in_axes=(None, None, 0, 0))
|
||||
for dim in reversed(range(len(lhs_batch), lhs.n_batch)):
|
||||
func = vmap(func, in_axes=(0, 0, None, None))
|
||||
for dim in range(len(lhs_batch)):
|
||||
if lhs_data.shape[dim] != rhs_data.shape[dim]:
|
||||
raise NotImplementedError("unequal batches in batched dims")
|
||||
func = vmap(func, in_axes=0)
|
||||
for _ in reversed(range(len(rhs_batch), rhs.n_batch)):
|
||||
func = broadcasting_vmap(func, in_axes=(None, None, 0, 0))
|
||||
for _ in reversed(range(len(lhs_batch), lhs.n_batch)):
|
||||
func = broadcasting_vmap(func, in_axes=(0, 0, None, None))
|
||||
for _ in range(len(lhs_batch)):
|
||||
func = broadcasting_vmap(func, in_axes=0)
|
||||
return func(lhs_data, lhs_indices, rhs_data, rhs_indices)
|
||||
|
||||
@bcoo_spdot_general_p.def_abstract_eval
|
||||
@ -796,34 +807,40 @@ def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indic
|
||||
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
_ = _dot_general_validated_shape(lhs_shape, rhs_shape, dimension_numbers)
|
||||
|
||||
if not (lhs.n_dense == rhs.n_dense == 0):
|
||||
if lhs.n_dense or rhs.n_dense:
|
||||
# TODO(jakevdp): handle dense dimensions
|
||||
raise NotImplementedError("bcoo_spdot_general with dense dimensions.")
|
||||
|
||||
if not (rhs.n_sparse == 1):
|
||||
raise NotImplementedError("bcoo_spdot_general with n_sparse != 1 on the rhs.")
|
||||
|
||||
if max(lhs_batch, default=-1) >= lhs.n_batch or max(rhs_batch, default=-1) >= rhs.n_batch:
|
||||
if (lhs_batch and max(lhs_batch) >= lhs.n_batch) or (rhs_batch and max(rhs_batch) >= rhs.n_batch):
|
||||
raise NotImplementedError("bcoo_spdot_general: batch_dims must correspond to batch dimensions of the sparse representation.")
|
||||
|
||||
if tuple(rhs_contracting) != (rhs.n_batch,) or lhs_contracting[0] not in range(lhs.n_batch, lhs.n_batch + lhs.n_sparse):
|
||||
if lhs_contracting and (min(lhs_contracting) < lhs.n_batch or max(lhs_contracting) >= lhs.n_batch + lhs.n_sparse):
|
||||
raise NotImplementedError("bcoo_spdot_general only supports contraction of sparse indices.")
|
||||
|
||||
if rhs_contracting and (min(rhs_contracting) < rhs.n_batch or max(rhs_contracting) >= rhs.n_batch + rhs.n_sparse):
|
||||
raise NotImplementedError("bcoo_spdot_general only supports contraction of sparse indices.")
|
||||
|
||||
if rhs.n_batch > len(rhs_batch) and lhs.n_sparse > len(lhs_contracting):
|
||||
raise ValueError("Cannot have unused batch dims on rhs with unused sparse dims on lhs.")
|
||||
raise ValueError("bcoo_spdot_general: cannot have unused batch dims on rhs with unused sparse dims on lhs.")
|
||||
|
||||
out_nse = (
|
||||
(lhs.nse if lhs.n_sparse > len(lhs_contracting) else 1) *
|
||||
(rhs.nse if rhs.n_sparse > len(rhs_contracting) else 1)
|
||||
)
|
||||
|
||||
data_shape = (
|
||||
*(lhs_shape[dim] for dim in lhs_batch),
|
||||
*(lhs_data.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch),
|
||||
*(rhs_data.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch),
|
||||
lhs.nse)
|
||||
out_nse)
|
||||
indices_shape = (
|
||||
*(lhs_shape[dim] for dim in lhs_batch),
|
||||
*(lhs_indices.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch),
|
||||
*(rhs_indices.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch),
|
||||
lhs.nse, lhs.n_sparse - len(lhs_contracting))
|
||||
out_dtype = jnp.promote_types(lhs_data.dtype, rhs_data.dtype)
|
||||
return core.ShapedArray(data_shape, out_dtype), core.ShapedArray(indices_shape, lhs_indices.dtype)
|
||||
out_nse, lhs.n_sparse + rhs.n_sparse - 2 * len(lhs_contracting))
|
||||
data_dtype = jnp.result_type(lhs_data, rhs_data)
|
||||
indices_dtype = jnp.result_type(lhs_indices, rhs_indices)
|
||||
return core.ShapedArray(data_shape, data_dtype), core.ShapedArray(indices_shape, indices_dtype)
|
||||
|
||||
def _bcoo_spdot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, lhs_shape, rhs_shape):
|
||||
lhs_data, lhs_indices, rhs_data, rhs_indices = batched_args
|
||||
|
@ -1059,14 +1059,18 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}[n_batch={}]_{}[n_batch={}]_dims={}".format(
|
||||
{"testcase_name": "_{}[n_batch={}]_{}[n_batch={}]_swap={}_dims={}".format(
|
||||
jtu.format_shape_dtype_string(lhs_shape, dtype), lhs_n_batch,
|
||||
jtu.format_shape_dtype_string(rhs_shape, dtype), rhs_n_batch,
|
||||
dimension_numbers),
|
||||
swap, dimension_numbers),
|
||||
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape,
|
||||
"lhs_n_batch": lhs_n_batch, "rhs_n_batch": rhs_n_batch,
|
||||
"dimension_numbers": dimension_numbers, "dtype": dtype}
|
||||
"dimension_numbers": dimension_numbers, "swap": swap, "dtype": dtype}
|
||||
for lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, dimension_numbers in [
|
||||
# (batched) outer products (no contraction)
|
||||
((5,), 0, (6,), 0, (([], []), ([], []))),
|
||||
((3, 5), 0, (2, 4), 0, (([], []), ([], []))),
|
||||
((3, 5), 1, (3, 4), 1, (([], []), ([0], [0]))),
|
||||
# (batched) vector-vector products
|
||||
((5,), 0, (5,), 0, (([0], [0]), ([], []))),
|
||||
((7,), 0, (7,), 0, (([0], [0]), ([], []))),
|
||||
@ -1081,9 +1085,27 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
((2, 3, 4), 1, (2, 4), 1, (([2], [1]), ([0], [0]))),
|
||||
((3, 2, 4), 1, (3, 4), 1, (([2], [1]), ([0], [0]))),
|
||||
((2, 3, 4), 0, (2,), 0, (([0], [0]), ([], []))),
|
||||
# (batched) matrix-matrix products
|
||||
((5, 7), 0, (7, 3), 0, (([1], [0]), ([], []))),
|
||||
((2, 3, 4), 1, (4, 3), 0, (([2], [0]), ([], []))),
|
||||
((2, 3, 4), 1, (2, 4, 3), 1, (([2], [1]), ([0], [0]))),
|
||||
# more general operations
|
||||
((2, 3, 4, 3), 1, (2, 4, 3, 4), 1, (([2, 3], [1, 2]), ([0], [0]))),
|
||||
((2, 3, 4, 3, 1), 2, (3, 2, 3, 4), 2, (([2, 3], [3, 2]), ([0, 1], [1, 0]))),
|
||||
]
|
||||
for swap in [True, False]
|
||||
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
|
||||
def test_bcoo_spdot_general(self, lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, dtype, dimension_numbers):
|
||||
def test_bcoo_spdot_general(self, lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, dtype, swap, dimension_numbers):
|
||||
if swap:
|
||||
dimension_numbers = tuple(d[::-1] for d in dimension_numbers)
|
||||
lhs_shape, rhs_shape = rhs_shape, lhs_shape
|
||||
lhs_n_batch, rhs_n_batch = rhs_n_batch, lhs_n_batch
|
||||
|
||||
lhs_n_sparse = len(lhs_shape) - lhs_n_batch
|
||||
rhs_batch = dimension_numbers[1][1]
|
||||
lhs_contracting = dimension_numbers[0][0]
|
||||
should_error = (rhs_n_batch > len(rhs_batch) and lhs_n_sparse > len(lhs_contracting))
|
||||
|
||||
sprng = rand_sparse(self.rng())
|
||||
def args_maker():
|
||||
x = sprng(lhs_shape, dtype)
|
||||
@ -1096,16 +1118,34 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
return lax.dot_general(x, y, dimension_numbers=dimension_numbers)
|
||||
|
||||
def f_sparse(x, y, xsp, ysp):
|
||||
shape = sparse.bcoo._dot_general_validated_shape(x.shape, y.shape, dimension_numbers)
|
||||
shape = sparse.bcoo._dot_general_validated_shape(xsp.shape, ysp.shape, dimension_numbers)
|
||||
data, indices = sparse.bcoo_spdot_general(xsp.data, xsp.indices, ysp.data, ysp.indices,
|
||||
lhs_shape=x.shape, rhs_shape=y.shape,
|
||||
dimension_numbers=dimension_numbers)
|
||||
return sparse.bcoo_todense(data, indices, shape=shape)
|
||||
|
||||
self._CheckAgainstNumpy(f_dense, f_sparse, args_maker)
|
||||
# TODO(jakevdp): This occasionally fails python_should_be_executing check. Why?
|
||||
# self._CompileAndCheck(f_sparse, args_maker)
|
||||
self._CheckAgainstNumpy(jit(f_dense), jit(f_sparse), args_maker)
|
||||
tol = {"complex128": 1E-14}
|
||||
if should_error:
|
||||
with self.assertRaisesRegex(ValueError, ".*cannot have unused batch dims on rhs with unused sparse dims on lhs."):
|
||||
f_sparse(*args_maker())
|
||||
else:
|
||||
self._CheckAgainstNumpy(f_dense, f_sparse, args_maker, tol=tol)
|
||||
self._CheckAgainstNumpy(jit(f_dense), jit(f_sparse), args_maker, tol=tol)
|
||||
# TODO(jakevdp): This occasionally fails python_should_be_executing check. Why?
|
||||
# self._CompileAndCheck(f_sparse, args_maker)
|
||||
|
||||
def test_bcoo_spdot_general_nse(self):
|
||||
# vector-vector product -> nse=1
|
||||
x = sparse.BCOO.fromdense(jnp.arange(3))
|
||||
self.assertEqual((x @ x).nse, 1)
|
||||
|
||||
# matrix-vector product -> nse matches matrix
|
||||
M = sparse.BCOO.fromdense(jnp.arange(6).reshape(2, 3))
|
||||
self.assertEqual((M @ x).nse, M.nse)
|
||||
|
||||
# matrix-matrix product -> product of nse
|
||||
N = sparse.BCOO.fromdense(jnp.arange(12).reshape(3, 4))
|
||||
self.assertEqual((M @ N).nse, M.nse * N.nse)
|
||||
|
||||
@unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
|
Loading…
x
Reference in New Issue
Block a user