[sparse] implement autodiff rules for bcsr_dot_general

Jake VanderPlas 2023-02-10 10:18:46 -08:00
parent 7a864d73bc
commit ac647b9459
4 changed files with 50 additions and 37 deletions

@@ -86,13 +86,17 @@ def _todense_transpose(ct, *bufs, tree):
   standin = object()
   obj = tree_util.tree_unflatten(tree, [standin] * len(bufs))
-  from jax.experimental.sparse import BCOO
+  from jax.experimental.sparse import BCOO, BCSR
   from jax.experimental.sparse.bcoo import _bcoo_extract
+  from jax.experimental.sparse.bcsr import bcsr_extract
   if obj is standin:
     return (ct,)
   elif isinstance(obj, BCOO):
     _, indices = bufs
     return _bcoo_extract(indices, ct), indices
+  elif isinstance(obj, BCSR):
+    _, indices, indptr = bufs
+    return bcsr_extract(indices, indptr, ct), indices, indptr
   elif isinstance(obj, COO):
     _, row, col = bufs
     return _coo_extract(row, col, ct), row, col
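
Editor's note: a minimal usage sketch (not part of this commit) of what the new BCSR branch above enables. It assumes the public jax.experimental.sparse APIs BCSR.fromdense, the BCSR((data, indices, indptr), shape=...) constructor, and sparse.todense; reverse-mode differentiation with respect to the BCSR data buffer is routed through this transpose rule.

    import jax
    import jax.numpy as jnp
    from jax.experimental import sparse

    M = sparse.BCSR.fromdense(jnp.array([[1., 0., 2.],
                                         [0., 3., 0.]]))

    def loss(data):
      # Rebuild a BCSR matrix with traced data but a fixed sparsity structure.
      mat = sparse.BCSR((data, M.indices, M.indptr), shape=M.shape)
      return sparse.todense(mat).sum()

    print(jax.grad(loss)(M.data))  # one gradient entry per stored nonzero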

@@ -972,8 +972,6 @@ def _bcoo_dot_general_jvp_rhs(rhs_dot, lhs_data, lhs_indices, rhs, *, dimension_
 def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: SparseInfo):
   assert not ad.is_undefined_primal(lhs_indices)
-  if type(ct) is ad.Zero:
-    return ad.Zero
   (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
   lhs_ndim = len(lhs_spinfo.shape)
   rhs_ndim = rhs.aval.ndim if ad.is_undefined_primal(rhs) else rhs.ndim

@@ -37,7 +37,7 @@ from jax.experimental.sparse.util import (
 from jax.util import split_list, safe_zip
 from jax._src import api_util
-from jax._src.lax.lax import DotDimensionNumbers
+from jax._src.lax.lax import DotDimensionNumbers, _dot_general_batch_dim_nums
 from jax._src.lib import gpu_sparse
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.interpreters import ad
@@ -46,7 +46,6 @@ from jax._src.interpreters import mlir
 from jax._src.typing import Array, ArrayLike, DTypeLike
 def bcsr_eliminate_zeros(mat: BCSR, nse: Optional[int] = None) -> BCSR:
   """Eliminate zeros in BCSR representation."""
   return BCSR.from_bcoo(bcoo.bcoo_eliminate_zeros(mat.to_bcoo(), nse=nse))
@@ -88,7 +87,6 @@ def _bcsr_batch_dims_to_front(batched_args, batch_dims, spinfo, batch_size=None)
   return batched_data, batched_indices, batched_indptr, batched_spinfo
 class BCSRProperties(NamedTuple):
   n_batch: int
   n_dense: int
@@ -549,45 +547,51 @@ def _bcsr_dot_general_abstract_eval(lhs_data, lhs_indices, lhs_indptr, rhs, *,
   return core.ShapedArray(out_shape, lhs_data.dtype)
-# def _bcsr_dot_general_jvp_lhs(lhs_data_dot, lhs_data, lhs_indices, lhs_indptr,
-#                               rhs, *, dimension_numbers, lhs_spinfo):
-#   del lhs_data
-#   return _bcsr_dot_general(lhs_data_dot, lhs_indices, lhs_indptr, rhs,
-#                            dimension_numbers=dimension_numbers,
-#                            lhs_spinfo=lhs_spinfo)
+def _bcsr_dot_general_jvp_lhs(lhs_data_dot, lhs_data, lhs_indices, lhs_indptr, rhs, *,
+                              dimension_numbers, lhs_spinfo):
+  del lhs_data
+  return _bcsr_dot_general(lhs_data_dot, lhs_indices, lhs_indptr, rhs,
+                           dimension_numbers=dimension_numbers,
+                           lhs_spinfo=lhs_spinfo)
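
Editor's sketch (not from the diff) of the lhs JVP rule at user level, assuming the public sparse.bcsr_dot_general wrapper: forward-mode tangents flow through the stored values of the BCSR operand while its sparsity pattern stays fixed.

    import jax
    import jax.numpy as jnp
    from jax.experimental import sparse

    M = sparse.BCSR.fromdense(jnp.array([[1., 0., 2.],
                                         [0., 3., 0.]]))
    rhs = jnp.ones((3, 4))
    dnums = (((1,), (0,)), ((), ()))  # ordinary matrix multiplication

    def f(data):
      lhs = sparse.BCSR((data, M.indices, M.indptr), shape=M.shape)
      return sparse.bcsr_dot_general(lhs, rhs, dimension_numbers=dnums)

    # Tangent of the product with respect to the stored nonzero values.
    out, tangent = jax.jvp(f, (M.data,), (jnp.ones_like(M.data),))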
-# def _bcsr_dot_general_jvp_rhs(rhs_dot, lhs_data, lhs_indices, lhs_indptr, rhs,
-#                               *, dimension_numbers, lhs_spinfo):
-#   del rhs
-#   return _bcsr_dot_general(lhs_data, lhs_indices, lhs_indptr, rhs_dot,
-#                            dimension_numbers=dimension_numbers,
-#                            lhs_spinfo=lhs_spinfo)
+def _bcsr_dot_general_jvp_rhs(rhs_dot, lhs_data, lhs_indices, lhs_indptr, rhs, *,
+                              dimension_numbers, lhs_spinfo):
+  del rhs
+  return _bcsr_dot_general(lhs_data, lhs_indices, lhs_indptr, rhs_dot,
+                           dimension_numbers=dimension_numbers,
+                           lhs_spinfo=lhs_spinfo)
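
Likewise, a hedged sketch for the rhs JVP rule (same assumptions as above): forward-mode tangents with respect to the dense operand.

    import jax
    import jax.numpy as jnp
    from jax.experimental import sparse

    M = sparse.BCSR.fromdense(jnp.array([[1., 0., 2.],
                                         [0., 3., 0.]]))
    dnums = (((1,), (0,)), ((), ()))
    f = lambda rhs: sparse.bcsr_dot_general(M, rhs, dimension_numbers=dnums)

    out, tangent = jax.jvp(f, (jnp.ones((3, 4)),), (jnp.ones((3, 4)),))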
-# def _bcsr_dot_general_transpose(ct, lhs_data, lhs_indices, lhs_inptr, rhs, *,
-#                                 dimension_numbers, lhs_spinfo):
-#   lhs_bcoo_indices = _bcsr_to_bcoo(
-#       lhs_indices, lhs_inptr, shape=lhs_spinfo.shape)
-#   return bcoo._bcoo_dot_general_transpose(
-#       ct, lhs_data, lhs_bcoo_indices, rhs, dimension_numbers=dimension_numbers,
-#       lhs_spinfo=lhs_spinfo)
+def _bcsr_dot_general_transpose(ct, lhs_data, lhs_indices, lhs_indptr, rhs, *,
+                                dimension_numbers, lhs_spinfo):
+  # TODO(jakevdp): implement this in terms of bcsr_dot_general
+  lhs_bcoo_indices = _bcsr_to_bcoo(
+      lhs_indices, lhs_indptr, shape=lhs_spinfo.shape)
+  data_out, _, rhs_out = bcoo._bcoo_dot_general_transpose(
+      ct, lhs_data, lhs_bcoo_indices, rhs, dimension_numbers=dimension_numbers,
+      lhs_spinfo=lhs_spinfo)
+  return data_out, lhs_indices, lhs_indptr, rhs_out
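
The transpose rule above is what reverse mode requires; a minimal editorial grad sketch, again assuming the public sparse.bcsr_dot_general wrapper:

    import jax
    import jax.numpy as jnp
    from jax.experimental import sparse

    M = sparse.BCSR.fromdense(jnp.array([[1., 0., 2.],
                                         [0., 3., 0.]]))
    dnums = (((1,), (0,)), ((), ()))

    def loss(rhs):
      return sparse.bcsr_dot_general(M, rhs, dimension_numbers=dnums).sum()

    g = jax.grad(loss)(jnp.ones((3, 4)))  # gradient w.r.t. the dense operand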
-# def _bcsr_dot_general_batch_rule(batched_args, batch_dims, *,
-#                                  dimension_numbers, lhs_spinfo):
-#   lhs_data, lhs_indices, lhs_indptr, rhs = batched_args
-#   lhs_bcoo_indices = _bcsr_to_bcoo(
-#       lhs_indices, lhs_indptr, shape=lhs_spinfo.shape)
-#   return bcoo._bcoo_dot_general_batch_rule(
-#       (lhs_data, lhs_bcoo_indices, rhs), batch_dims,
-#       dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
+def _bcsr_dot_general_batch_rule(batched_args, batch_dims, *,
+                                 dimension_numbers, lhs_spinfo):
+  *lhs_args, rhs = batched_args
+  *lhs_dims, rhs_bdim = batch_dims
+  *new_lhs_args, new_lhs_spinfo = _bcsr_batch_dims_to_front(
+      lhs_args, lhs_dims, lhs_spinfo,
+      batch_size=None if rhs_bdim is None else rhs.shape[rhs_bdim])
+  new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
+      (len(lhs_spinfo.shape), rhs.ndim), (0, rhs_bdim), dimension_numbers)
+  batched_out = _bcsr_dot_general(*new_lhs_args, rhs, lhs_spinfo=new_lhs_spinfo,
+                                  dimension_numbers=new_dimension_numbers)
+  return batched_out, result_batch_dim
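
Editor's sketch of the behavior this batching rule targets (not from the commit, and subject to the batching caveats noted in the test TODO further below): jax.vmap over a stacked batch of dense right-hand sides with a fixed BCSR left operand.

    import jax
    import jax.numpy as jnp
    from jax.experimental import sparse

    M = sparse.BCSR.fromdense(jnp.array([[1., 0., 2.],
                                         [0., 3., 0.]]))
    dnums = (((1,), (0,)), ((), ()))
    matmul = lambda rhs: sparse.bcsr_dot_general(M, rhs, dimension_numbers=dnums)

    batched_rhs = jnp.ones((5, 3, 4))    # leading batch axis of size 5
    out = jax.vmap(matmul)(batched_rhs)  # expected shape (5, 2, 4)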
-# ad.defjvp(bcsr_dot_general_p, _bcsr_dot_general_jvp_lhs, None,
-#           _bcsr_dot_general_jvp_rhs)
-# ad.primitive_transposes[bcsr_dot_general_p] = _bcsr_dot_general_transpose
-# batching.primitive_batchers[bcsr_dot_general_p] = _bcsr_dot_general_batch_rule
+ad.defjvp(bcsr_dot_general_p, _bcsr_dot_general_jvp_lhs, None, None,
+          _bcsr_dot_general_jvp_rhs)
+ad.primitive_transposes[bcsr_dot_general_p] = _bcsr_dot_general_transpose
+batching.primitive_batchers[bcsr_dot_general_p] = _bcsr_dot_general_batch_rule
 def _bcsr_dot_general_gpu_lowering(
     csr_matvec_lowering, csr_matmat_lowering,

@@ -2184,6 +2184,13 @@ class BCSRTest(sptu.SparseTestCase):
            np.float32: 1E-5, np.complex64: 1E-5}
     self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol)
+    if jnp.issubdtype(dtype, jnp.floating) and props.n_dense == 0:
+      # Dense dimensions not yet fully supported in reverse mode.
+      modes = ['fwd'] if props.n_dense != 0 else ['fwd', 'rev']
+      self._CheckGradsSparse(dense_fun, sparse_fun, args_maker, modes=modes, atol=tol, rtol=tol)
+    # TODO: add this once bcsr_broadcast_in_dim & bcsr_concatenate are implemented
+    # self._CheckBatchingSparse(dense_fun, sparse_fun, args_maker, atol=tol, rtol=tol,
+    #                           bdims=self._random_bdims(props.n_batch, len(props.rhs_shape)))
 class SparseGradTest(sptu.SparseTestCase):
   @jtu.sample_product(has_aux=[True, False])
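
Editor's note on the modes argument in the test hunk above: 'fwd' exercises only the JVP rules, while 'rev' additionally exercises the transpose rule. A hedged standalone sketch of the same kind of check, using the public jax.test_util.check_grads helper rather than the test-class wrappers from the diff:

    import jax.numpy as jnp
    from jax import test_util
    from jax.experimental import sparse

    M = sparse.BCSR.fromdense(jnp.arange(6, dtype=jnp.float32).reshape(2, 3))
    dnums = (((1,), (0,)), ((), ()))
    f = lambda rhs: sparse.bcsr_dot_general(M, rhs, dimension_numbers=dnums)

    # Compare autodiff against numerical derivatives in both modes.
    test_util.check_grads(f, (jnp.ones((3, 4)),), order=1, modes=['fwd', 'rev'])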