Merge pull request #10184 from jakevdp:merge-bcoo-dot-general

PiperOrigin-RevId: 440178509
This commit is contained in:
jax authors 2022-04-07 12:57:19 -07:00
commit 28cb44e8f6
4 changed files with 24 additions and 51 deletions

View File

@ -201,8 +201,6 @@ from jax.experimental.sparse.bcoo import (
bcoo_multiply_dense as bcoo_multiply_dense,
bcoo_multiply_sparse as bcoo_multiply_sparse,
bcoo_reduce_sum as bcoo_reduce_sum,
bcoo_rdot_general as bcoo_rdot_general,
bcoo_spdot_general as bcoo_spdot_general,
bcoo_spdot_general_p as bcoo_spdot_general_p,
bcoo_todense as bcoo_todense,
bcoo_todense_p as bcoo_todense_p,

View File

@ -622,17 +622,31 @@ def bcoo_dot_general(lhs, rhs, *, dimension_numbers):
"""A general contraction operation.
Args:
lhs: A BCOO-format array.
rhs: An ndarray.
lhs: An ndarray or BCOO-format sparse array.
rhs: An ndarray or BCOO-format sparse array..
dimension_numbers: a tuple of tuples of the form
`((lhs_contracting_dims, rhs_contracting_dims),
(lhs_batch_dims, rhs_batch_dims))`.
Returns:
An ndarray containing the result.
An ndarray or BCOO-format sparse array containing the result. If both inputs
are sparse, the result will be sparse, of type BCOO. If either input is dense,
the result will be dense, of type ndarray.
"""
return _bcoo_dot_general(*lhs._bufs, rhs, dimension_numbers=dimension_numbers,
lhs_spinfo=lhs._info)
if isinstance(lhs, BCOO) and isinstance(rhs, BCOO):
shape = _dot_general_validated_shape(lhs.shape, rhs.shape, dimension_numbers)
bufs = _bcoo_spdot_general(lhs.data, lhs.indices, rhs.data, rhs.indices,
lhs_spinfo=lhs._info, rhs_spinfo=rhs._info,
dimension_numbers=dimension_numbers)
return BCOO(bufs, shape=shape)
elif isinstance(lhs, BCOO):
return _bcoo_dot_general(*lhs._bufs, rhs, dimension_numbers=dimension_numbers,
lhs_spinfo=lhs._info)
elif isinstance(rhs, BCOO):
return _bcoo_rdot_general(lhs, *rhs._bufs, dimension_numbers=dimension_numbers,
rhs_spinfo=rhs._info)
else:
return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
def _bcoo_dot_general(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: BCOOInfo):
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
@ -644,23 +658,6 @@ def _bcoo_dot_general(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spin
dimension_numbers=(cdims, bdims),
lhs_spinfo=lhs_spinfo)
def bcoo_rdot_general(lhs, rhs, *, dimension_numbers: DotDimensionNumbers):
"""A general contraction operation.
Args:
lhs: An ndarray.
rhs: A BCOO-format array.
dimension_numbers: a tuple of tuples of the form
`((lhs_contracting_dims, rhs_contracting_dims),
(lhs_batch_dims, rhs_batch_dims))`.
Returns:
An ndarray containing the result.
"""
return _bcoo_rdot_general(lhs, rhs.data, rhs.indices,
dimension_numbers=dimension_numbers,
rhs_spinfo=rhs._info)
def _bcoo_rdot_general(lhs, rhs_data, rhs_indices, *, dimension_numbers: DotDimensionNumbers, rhs_spinfo: BCOOInfo):
# TODO(jakevdp): perhaps this should be part of the bcoo_dot_general primitive?
result = _bcoo_dot_general(rhs_data, rhs_indices, lhs, lhs_spinfo=rhs_spinfo,
@ -1017,25 +1014,6 @@ xla.register_translation(bcoo_dot_general_sampled_p, xla.lower_fun(
bcoo_spdot_general_p = core.Primitive('bcoo_spdot_general')
bcoo_spdot_general_p.multiple_results = True
def bcoo_spdot_general(lhs, rhs, *, dimension_numbers: DotDimensionNumbers):
"""A general contraction operation.
Args:
lhs: A BCOO-format array.
rhs: A BCOO-format array.
dimension_numbers: a tuple of tuples of the form
`((lhs_contracting_dims, rhs_contracting_dims),
(lhs_batch_dims, rhs_batch_dims))`.
Returns:
A BCOO array containing the result.
"""
shape = _dot_general_validated_shape(lhs.shape, rhs.shape, dimension_numbers)
data, indices = _bcoo_spdot_general(lhs.data, lhs.indices, rhs.data, rhs.indices,
lhs_spinfo=lhs._info, rhs_spinfo=rhs._info,
dimension_numbers=dimension_numbers)
return BCOO((data, indices), shape=shape)
def _bcoo_spdot_general(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: BCOOInfo, rhs_spinfo: BCOOInfo, dimension_numbers: DotDimensionNumbers):
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
cdims = (api_util._ensure_index_tuple(lhs_contract),

View File

@ -434,14 +434,8 @@ for _prim in [
def _dot_general_sparse(spenv, *spvalues, dimension_numbers, precision, preferred_element_type):
# TODO(jakevdp): pass along these unused configurations?
del precision, preferred_element_type # unused
if spvalues[0].is_sparse() and spvalues[1].is_sparse():
func = sparse.bcoo_spdot_general
elif spvalues[0].is_sparse():
func = sparse.bcoo_dot_general
else:
func = sparse.bcoo_rdot_general
A, B = spvalues_to_arrays(spenv, spvalues)
result = func(A, B, dimension_numbers=dimension_numbers)
result = sparse.bcoo_dot_general(*spvalues_to_arrays(spenv, spvalues),
dimension_numbers=dimension_numbers)
return arrays_to_spvalues(spenv, [result])
sparse_rules[lax.dot_general_p] = _dot_general_sparse

View File

@ -1815,6 +1815,9 @@ class BCOOTest(jtu.JaxTestCase):
# bcoo_dot_general
self.assertArraysEqual(x_sp @ y_de, x_de @ y_de)
# bcoo_rdot_general
self.assertArraysEqual(x_de @ y_sp, x_de @ y_de)
# bcoo_spdot_general
self.assertArraysEqual((x_sp @ y_sp).todense(), x_de @ y_de)
self.assertArraysEqual((y_sp @ x_sp).todense(), y_de @ x_de)