Mirror of https://github.com/ROCm/jax.git
[sparse] implement autodiff rules for bcsr_dot_general
commit ac647b9459 (parent 7a864d73bc)
@@ -86,13 +86,17 @@ def _todense_transpose(ct, *bufs, tree):
   standin = object()
   obj = tree_util.tree_unflatten(tree, [standin] * len(bufs))
-  from jax.experimental.sparse import BCOO
+  from jax.experimental.sparse import BCOO, BCSR
   from jax.experimental.sparse.bcoo import _bcoo_extract
+  from jax.experimental.sparse.bcsr import bcsr_extract
   if obj is standin:
     return (ct,)
   elif isinstance(obj, BCOO):
     _, indices = bufs
     return _bcoo_extract(indices, ct), indices
+  elif isinstance(obj, BCSR):
+    _, indices, indptr = bufs
+    return bcsr_extract(indices, indptr, ct), indices, indptr
   elif isinstance(obj, COO):
     _, row, col = bufs
     return _coo_extract(row, col, ct), row, col
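
Note on the hunk above: the new BCSR branch mirrors the BCOO one. The transpose of todense routes the dense cotangent through bcsr_extract, so only the stored coordinates receive a gradient, and the integer buffers (indices, indptr) pass back through unchanged. A minimal sketch of the behavior this enables (not part of the commit; it assumes the public sparse.BCSR.fromdense and sparse.todense wrappers and the BCSR((data, indices, indptr), shape=...) constructor):

    import jax
    import jax.numpy as jnp
    from jax.experimental import sparse

    mat = sparse.BCSR.fromdense(jnp.array([[1., 0., 2.],
                                           [0., 3., 0.]]))

    def loss(data):
      # Rebuild the matrix around the differentiable data buffer.
      m = sparse.BCSR((data, mat.indices, mat.indptr), shape=mat.shape)
      return sparse.todense(m).sum()

    # Each stored element contributes 1.0 to the sum, so the gradient with
    # respect to the data buffer is all ones, with mat's sparsity preserved.
    print(jax.grad(loss)(mat.data))  # [1. 1. 1.]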

@@ -972,8 +972,6 @@ def _bcoo_dot_general_jvp_rhs(rhs_dot, lhs_data, lhs_indices, rhs, *, dimension_
 
 def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: SparseInfo):
   assert not ad.is_undefined_primal(lhs_indices)
-  if type(ct) is ad.Zero:
-    return ad.Zero
   (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
   lhs_ndim = len(lhs_spinfo.shape)
   rhs_ndim = rhs.aval.ndim if ad.is_undefined_primal(rhs) else rhs.ndim
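
The change here appears to drop the Zero short-circuit from _bcoo_dot_general_transpose: it returned the bare ad.Zero class rather than one cotangent per primal input, and the new BCSR transpose rule below unpacks this function's result as a 3-tuple, which the early return would have broken. For context on the rule itself: dot_general is bilinear, so its transpose with respect to either operand is again a contraction. In the plain matrix-matrix case C = A @ B with output cotangent ct, the pullbacks are dA = ct @ B.T and dB = A.T @ ct; when A is the sparse operand, dA is additionally projected onto A's stored indices (via _bcoo_extract / bcsr_extract above). A dense sketch of that special case, illustration only, not the library routine:

    import jax.numpy as jnp

    def matmul_transpose(ct, A, B):
      # Cotangents of C = A @ B.
      dA = ct @ B.T  # the sparse rule keeps only A's stored positions
      dB = A.T @ ct
      return dA, dB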

@@ -37,7 +37,7 @@ from jax.experimental.sparse.util import (
 from jax.util import split_list, safe_zip
 
 from jax._src import api_util
-from jax._src.lax.lax import DotDimensionNumbers
+from jax._src.lax.lax import DotDimensionNumbers, _dot_general_batch_dim_nums
 from jax._src.lib import gpu_sparse
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.interpreters import ad
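
The newly imported _dot_general_batch_dim_nums is the helper lax itself uses when batching dense dot_general: given the ranks of the batched operands, their batch axes, and the current dimension_numbers, it returns the dimension_numbers for the batched call together with the output axis that carries the batch. The batch rule later in this file calls it exactly that way. A hedged illustration (this is a private JAX helper, so treat the exact output below as an assumption):

    from jax._src.lax.lax import _dot_general_batch_dim_nums

    dn = (((1,), (0,)), ((), ()))  # ordinary matmul contraction
    # Both operands mapped at axis 0, so the batched operands are rank 3:
    new_dn, out_bdim = _dot_general_batch_dim_nums((3, 3), (0, 0), dn)
    # Expect the mapped axes to become dot_general batch dimensions,
    # i.e. new_dn == (((2,), (1,)), ((0,), (0,))), and out_bdim == 0.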

@@ -46,7 +46,6 @@ from jax._src.interpreters import mlir
 from jax._src.typing import Array, ArrayLike, DTypeLike
 
 
-
 def bcsr_eliminate_zeros(mat: BCSR, nse: Optional[int] = None) -> BCSR:
   """Eliminate zeros in BCSR representation."""
   return BCSR.from_bcoo(bcoo.bcoo_eliminate_zeros(mat.to_bcoo(), nse=nse))
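
A hedged usage sketch for the helper above, assuming BCSR.fromdense accepts an nse argument that pads storage with explicit zeros, as BCOO.fromdense does:

    import jax.numpy as jnp
    from jax.experimental import sparse
    from jax.experimental.sparse.bcsr import bcsr_eliminate_zeros

    # Request more stored elements than there are nonzeros; the extra
    # slots are filled with explicitly stored zeros.
    m = sparse.BCSR.fromdense(jnp.array([[1., 0., 2.],
                                         [0., 0., 3.]]), nse=5)
    pruned = bcsr_eliminate_zeros(m)
    print(m.nse, pruned.nse)  # 5 3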

@@ -88,7 +87,6 @@ def _bcsr_batch_dims_to_front(batched_args, batch_dims, spinfo, batch_size=None)
   return batched_data, batched_indices, batched_indptr, batched_spinfo
 
 
-
 class BCSRProperties(NamedTuple):
   n_batch: int
   n_dense: int

@@ -549,45 +547,51 @@ def _bcsr_dot_general_abstract_eval(lhs_data, lhs_indices, lhs_indptr, rhs, *,
   return core.ShapedArray(out_shape, lhs_data.dtype)
 
 
-# def _bcsr_dot_general_jvp_lhs(lhs_data_dot, lhs_data, lhs_indices, lhs_indptr,
-#                               rhs, *, dimension_numbers, lhs_spinfo):
-#   del lhs_data
-#   return _bcsr_dot_general(lhs_data_dot, lhs_indices, lhs_indptr, rhs,
-#                            dimension_numbers=dimension_numbers,
-#                            lhs_spinfo=lhs_spinfo)
+def _bcsr_dot_general_jvp_lhs(lhs_data_dot, lhs_data, lhs_indices, lhs_indptr, rhs, *,
+                              dimension_numbers, lhs_spinfo):
+  del lhs_data
+  return _bcsr_dot_general(lhs_data_dot, lhs_indices, lhs_indptr, rhs,
+                           dimension_numbers=dimension_numbers,
+                           lhs_spinfo=lhs_spinfo)
 
 
-# def _bcsr_dot_general_jvp_rhs(rhs_dot, lhs_data, lhs_indices, lhs_indptr, rhs,
-#                               *, dimension_numbers, lhs_spinfo):
-#   del rhs
-#   return _bcsr_dot_general(lhs_data, lhs_indices, lhs_indptr, rhs_dot,
-#                            dimension_numbers=dimension_numbers,
-#                            lhs_spinfo=lhs_spinfo)
+def _bcsr_dot_general_jvp_rhs(rhs_dot, lhs_data, lhs_indices, lhs_indptr, rhs, *,
+                              dimension_numbers, lhs_spinfo):
+  del rhs
+  return _bcsr_dot_general(lhs_data, lhs_indices, lhs_indptr, rhs_dot,
+                           dimension_numbers=dimension_numbers,
+                           lhs_spinfo=lhs_spinfo)
 
 
-# def _bcsr_dot_general_transpose(ct, lhs_data, lhs_indices, lhs_inptr, rhs, *,
-#                                 dimension_numbers, lhs_spinfo):
-#   lhs_bcoo_indices = _bcsr_to_bcoo(
-#       lhs_indices, lhs_inptr, shape=lhs_spinfo.shape)
-#   return bcoo._bcoo_dot_general_transpose(
-#       ct, lhs_data, lhs_bcoo_indices, rhs, dimension_numbers=dimension_numbers,
-#       lhs_spinfo=lhs_spinfo)
+def _bcsr_dot_general_transpose(ct, lhs_data, lhs_indices, lhs_indptr, rhs, *,
+                                dimension_numbers, lhs_spinfo):
+  # TODO(jakevdp): implement this in terms of bcsr_dot_general
+  lhs_bcoo_indices = _bcsr_to_bcoo(
+      lhs_indices, lhs_indptr, shape=lhs_spinfo.shape)
+  data_out, _, rhs_out = bcoo._bcoo_dot_general_transpose(
+      ct, lhs_data, lhs_bcoo_indices, rhs, dimension_numbers=dimension_numbers,
+      lhs_spinfo=lhs_spinfo)
+  return data_out, lhs_indices, lhs_indptr, rhs_out
 
 
-# def _bcsr_dot_general_batch_rule(batched_args, batch_dims, *,
-#                                  dimension_numbers, lhs_spinfo):
-#   lhs_data, lhs_indices, lhs_indptr, rhs = batched_args
-#   lhs_bcoo_indices = _bcsr_to_bcoo(
-#       lhs_indices, lhs_indptr, shape=lhs_spinfo.shape)
-#   return bcoo._bcoo_dot_general_batch_rule(
-#       (lhs_data, lhs_bcoo_indices, rhs), batch_dims,
-#       dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
+def _bcsr_dot_general_batch_rule(batched_args, batch_dims, *,
+                                 dimension_numbers, lhs_spinfo):
+  *lhs_args, rhs = batched_args
+  *lhs_dims, rhs_bdim = batch_dims
+  *new_lhs_args, new_lhs_spinfo = _bcsr_batch_dims_to_front(
+      lhs_args, lhs_dims, lhs_spinfo,
+      batch_size=None if rhs_bdim is None else rhs.shape[rhs_bdim])
+  new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
+      (len(lhs_spinfo.shape), rhs.ndim), (0, rhs_bdim), dimension_numbers)
+  batched_out = _bcsr_dot_general(*new_lhs_args, rhs, lhs_spinfo=new_lhs_spinfo,
+                                  dimension_numbers=new_dimension_numbers)
+  return batched_out, result_batch_dim
 
 
-# ad.defjvp(bcsr_dot_general_p, _bcsr_dot_general_jvp_lhs, None,
-#           _bcsr_dot_general_jvp_rhs)
-# ad.primitive_transposes[bcsr_dot_general_p] = _bcsr_dot_general_transpose
-# batching.primitive_batchers[bcsr_dot_general_p] = _bcsr_dot_general_batch_rule
+ad.defjvp(bcsr_dot_general_p, _bcsr_dot_general_jvp_lhs, None, None,
+          _bcsr_dot_general_jvp_rhs)
+ad.primitive_transposes[bcsr_dot_general_p] = _bcsr_dot_general_transpose
+batching.primitive_batchers[bcsr_dot_general_p] = _bcsr_dot_general_batch_rule
 
 def _bcsr_dot_general_gpu_lowering(
     csr_matvec_lowering, csr_matmat_lowering,
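
Two things worth noting about the block above. First, ad.defjvp takes one JVP rule per positional input of the primitive: the enabled call supplies a rule for lhs_data, None for the two non-differentiable integer buffers (lhs_indices and lhs_indptr), and a rule for rhs, whereas the old commented-out call had only a single None and so had the wrong arity. Second, because dot_general is linear in each operand, each one-sided JVP is simply the primitive re-applied to the tangent. A hedged end-to-end check, not from the commit, assuming bcsr_dot_general and BCSR are importable as shown:

    import jax
    import jax.numpy as jnp
    from jax.experimental import sparse
    from jax.experimental.sparse.bcsr import bcsr_dot_general

    A = jnp.array([[1., 0., 2.],
                   [0., 3., 0.]])
    B = jnp.ones((3, 4))
    M = sparse.BCSR.fromdense(A)
    dn = (((1,), (0,)), ((), ()))  # contract like A @ B

    def f_sparse(data, rhs):
      m = sparse.BCSR((data, M.indices, M.indptr), shape=M.shape)
      return bcsr_dot_general(m, rhs, dimension_numbers=dn)

    # Forward mode exercises the new JVP rules.
    out, tan = jax.jvp(f_sparse, (M.data, B),
                       (jnp.ones_like(M.data), jnp.zeros_like(B)))

    # Reverse mode additionally needs the transpose rule registered above.
    g = jax.grad(lambda data: f_sparse(data, B).sum())(M.data)

    # vmap over the dense operand exercises the batch rule.
    outs = jax.vmap(lambda r: bcsr_dot_general(M, r, dimension_numbers=dn))(
        jnp.ones((5, 3, 4)))  # shape (5, 2, 4)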

@@ -2184,6 +2184,13 @@ class BCSRTest(sptu.SparseTestCase):
            np.float32: 1E-5, np.complex64: 1E-5}
 
     self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol)
+    if jnp.issubdtype(dtype, jnp.floating) and props.n_dense == 0:
+      # Dense dimensions not yet fully supported in reverse mode.
+      modes = ['fwd'] if props.n_dense != 0 else ['fwd', 'rev']
+      self._CheckGradsSparse(dense_fun, sparse_fun, args_maker, modes=modes, atol=tol, rtol=tol)
+    # TODO: add this once bcsr_broadcast_in_dim & bcsr_concatenate are implemented
+    # self._CheckBatchingSparse(dense_fun, sparse_fun, args_maker, atol=tol, rtol=tol,
+    #                           bdims=self._random_bdims(props.n_batch, len(props.rhs_shape)))
 
 class SparseGradTest(sptu.SparseTestCase):
   @jtu.sample_product(has_aux=[True, False])
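
One reading note on the new guard: inside the branch, props.n_dense == 0 already holds, so the expression ['fwd'] if props.n_dense != 0 else ['fwd', 'rev'] always evaluates to ['fwd', 'rev']. The conditional presumably anticipates the outer n_dense restriction being lifted once dense dimensions are supported in reverse mode, at which point the 'fwd'-only fallback becomes reachable.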