mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #10184 from jakevdp:merge-bcoo-dot-general
PiperOrigin-RevId: 440178509
This commit is contained in:
commit
28cb44e8f6
@ -201,8 +201,6 @@ from jax.experimental.sparse.bcoo import (
|
||||
bcoo_multiply_dense as bcoo_multiply_dense,
|
||||
bcoo_multiply_sparse as bcoo_multiply_sparse,
|
||||
bcoo_reduce_sum as bcoo_reduce_sum,
|
||||
bcoo_rdot_general as bcoo_rdot_general,
|
||||
bcoo_spdot_general as bcoo_spdot_general,
|
||||
bcoo_spdot_general_p as bcoo_spdot_general_p,
|
||||
bcoo_todense as bcoo_todense,
|
||||
bcoo_todense_p as bcoo_todense_p,
|
||||
|
@ -622,17 +622,31 @@ def bcoo_dot_general(lhs, rhs, *, dimension_numbers):
|
||||
"""A general contraction operation.
|
||||
|
||||
Args:
|
||||
lhs: A BCOO-format array.
|
||||
rhs: An ndarray.
|
||||
lhs: An ndarray or BCOO-format sparse array.
|
||||
rhs: An ndarray or BCOO-format sparse array..
|
||||
dimension_numbers: a tuple of tuples of the form
|
||||
`((lhs_contracting_dims, rhs_contracting_dims),
|
||||
(lhs_batch_dims, rhs_batch_dims))`.
|
||||
|
||||
Returns:
|
||||
An ndarray containing the result.
|
||||
An ndarray or BCOO-format sparse array containing the result. If both inputs
|
||||
are sparse, the result will be sparse, of type BCOO. If either input is dense,
|
||||
the result will be dense, of type ndarray.
|
||||
"""
|
||||
return _bcoo_dot_general(*lhs._bufs, rhs, dimension_numbers=dimension_numbers,
|
||||
lhs_spinfo=lhs._info)
|
||||
if isinstance(lhs, BCOO) and isinstance(rhs, BCOO):
|
||||
shape = _dot_general_validated_shape(lhs.shape, rhs.shape, dimension_numbers)
|
||||
bufs = _bcoo_spdot_general(lhs.data, lhs.indices, rhs.data, rhs.indices,
|
||||
lhs_spinfo=lhs._info, rhs_spinfo=rhs._info,
|
||||
dimension_numbers=dimension_numbers)
|
||||
return BCOO(bufs, shape=shape)
|
||||
elif isinstance(lhs, BCOO):
|
||||
return _bcoo_dot_general(*lhs._bufs, rhs, dimension_numbers=dimension_numbers,
|
||||
lhs_spinfo=lhs._info)
|
||||
elif isinstance(rhs, BCOO):
|
||||
return _bcoo_rdot_general(lhs, *rhs._bufs, dimension_numbers=dimension_numbers,
|
||||
rhs_spinfo=rhs._info)
|
||||
else:
|
||||
return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
|
||||
|
||||
def _bcoo_dot_general(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: BCOOInfo):
|
||||
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
@ -644,23 +658,6 @@ def _bcoo_dot_general(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spin
|
||||
dimension_numbers=(cdims, bdims),
|
||||
lhs_spinfo=lhs_spinfo)
|
||||
|
||||
def bcoo_rdot_general(lhs, rhs, *, dimension_numbers: DotDimensionNumbers):
|
||||
"""A general contraction operation.
|
||||
|
||||
Args:
|
||||
lhs: An ndarray.
|
||||
rhs: A BCOO-format array.
|
||||
dimension_numbers: a tuple of tuples of the form
|
||||
`((lhs_contracting_dims, rhs_contracting_dims),
|
||||
(lhs_batch_dims, rhs_batch_dims))`.
|
||||
|
||||
Returns:
|
||||
An ndarray containing the result.
|
||||
"""
|
||||
return _bcoo_rdot_general(lhs, rhs.data, rhs.indices,
|
||||
dimension_numbers=dimension_numbers,
|
||||
rhs_spinfo=rhs._info)
|
||||
|
||||
def _bcoo_rdot_general(lhs, rhs_data, rhs_indices, *, dimension_numbers: DotDimensionNumbers, rhs_spinfo: BCOOInfo):
|
||||
# TODO(jakevdp): perhaps this should be part of the bcoo_dot_general primitive?
|
||||
result = _bcoo_dot_general(rhs_data, rhs_indices, lhs, lhs_spinfo=rhs_spinfo,
|
||||
@ -1017,25 +1014,6 @@ xla.register_translation(bcoo_dot_general_sampled_p, xla.lower_fun(
|
||||
bcoo_spdot_general_p = core.Primitive('bcoo_spdot_general')
|
||||
bcoo_spdot_general_p.multiple_results = True
|
||||
|
||||
def bcoo_spdot_general(lhs, rhs, *, dimension_numbers: DotDimensionNumbers):
|
||||
"""A general contraction operation.
|
||||
|
||||
Args:
|
||||
lhs: A BCOO-format array.
|
||||
rhs: A BCOO-format array.
|
||||
dimension_numbers: a tuple of tuples of the form
|
||||
`((lhs_contracting_dims, rhs_contracting_dims),
|
||||
(lhs_batch_dims, rhs_batch_dims))`.
|
||||
|
||||
Returns:
|
||||
A BCOO array containing the result.
|
||||
"""
|
||||
shape = _dot_general_validated_shape(lhs.shape, rhs.shape, dimension_numbers)
|
||||
data, indices = _bcoo_spdot_general(lhs.data, lhs.indices, rhs.data, rhs.indices,
|
||||
lhs_spinfo=lhs._info, rhs_spinfo=rhs._info,
|
||||
dimension_numbers=dimension_numbers)
|
||||
return BCOO((data, indices), shape=shape)
|
||||
|
||||
def _bcoo_spdot_general(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: BCOOInfo, rhs_spinfo: BCOOInfo, dimension_numbers: DotDimensionNumbers):
|
||||
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
cdims = (api_util._ensure_index_tuple(lhs_contract),
|
||||
|
@ -434,14 +434,8 @@ for _prim in [
|
||||
def _dot_general_sparse(spenv, *spvalues, dimension_numbers, precision, preferred_element_type):
|
||||
# TODO(jakevdp): pass along these unused configurations?
|
||||
del precision, preferred_element_type # unused
|
||||
if spvalues[0].is_sparse() and spvalues[1].is_sparse():
|
||||
func = sparse.bcoo_spdot_general
|
||||
elif spvalues[0].is_sparse():
|
||||
func = sparse.bcoo_dot_general
|
||||
else:
|
||||
func = sparse.bcoo_rdot_general
|
||||
A, B = spvalues_to_arrays(spenv, spvalues)
|
||||
result = func(A, B, dimension_numbers=dimension_numbers)
|
||||
result = sparse.bcoo_dot_general(*spvalues_to_arrays(spenv, spvalues),
|
||||
dimension_numbers=dimension_numbers)
|
||||
return arrays_to_spvalues(spenv, [result])
|
||||
|
||||
sparse_rules[lax.dot_general_p] = _dot_general_sparse
|
||||
|
@ -1815,6 +1815,9 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
# bcoo_dot_general
|
||||
self.assertArraysEqual(x_sp @ y_de, x_de @ y_de)
|
||||
|
||||
# bcoo_rdot_general
|
||||
self.assertArraysEqual(x_de @ y_sp, x_de @ y_de)
|
||||
|
||||
# bcoo_spdot_general
|
||||
self.assertArraysEqual((x_sp @ y_sp).todense(), x_de @ y_de)
|
||||
self.assertArraysEqual((y_sp @ x_sp).todense(), y_de @ x_de)
|
||||
|
Loading…
x
Reference in New Issue
Block a user