[sparse] add cusparse lowering for simplest cases of bcsr_dot_general

PiperOrigin-RevId: 508473938
commit 15c9bca67f
parent 253cd4d9d1
Author: Jake VanderPlas (committed by jax authors)
Date:   2023-02-09 14:18:07 -08:00


@@ -17,11 +17,13 @@ from __future__ import annotations
from functools import partial
import operator
import warnings
from typing import NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
from jax import config
from jax import core
from jax import lax
from jax import tree_util
@@ -30,13 +32,15 @@ from jax.experimental.sparse import bcoo
from jax.experimental.sparse.util import (
    nfold_vmap, _count_stored_elements,
    _csr_to_coo, _dot_general_validated_shape,
    CuSparseEfficiencyWarning, SparseInfo, Shape)
import jax.numpy as jnp
from jax._src import api_util
from jax._src.lax.lax import DotDimensionNumbers
from jax._src.lib import gpu_sparse
from jax.util import split_list, safe_zip
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src.lib.mlir.dialects import hlo
from jax._src.typing import Array, ArrayLike, DTypeLike
@@ -488,12 +492,84 @@ def _bcsr_dot_general_abstract_eval(lhs_data, lhs_indices, lhs_indptr, rhs, *,
# ad.primitive_transposes[bcsr_dot_general_p] = _bcsr_dot_general_transpose
# batching.primitive_batchers[bcsr_dot_general_p] = _bcsr_dot_general_batch_rule
def _bcsr_dot_general_gpu_lowering(
    csr_matvec_lowering, csr_matmat_lowering,
    ctx, lhs_data, lhs_indices, lhs_indptr, rhs, *, dimension_numbers,
    lhs_spinfo: SparseInfo):
  if not config.jax_bcoo_cusparse_lowering:
    return _bcsr_dot_general_default_lowering(
        ctx, lhs_data, lhs_indices, lhs_indptr, rhs,
        dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)

  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  lhs_data_aval, lhs_indices_aval, lhs_indptr_aval, rhs_aval = ctx.avals_in
  props = _validate_bcsr(
      lhs_data_aval, lhs_indices_aval, lhs_indptr_aval, lhs_spinfo.shape)

  use_default_lowering = False
  dtype = lhs_data_aval.dtype

  # TODO(vanderplas, tianjianlu): lower batched matmuls to GPU
  if lhs_batch or rhs_batch:
    # batch dimensions in dot_general are not supported
    use_default_lowering = True
  elif len(lhs_spinfo.shape) != 2 or rhs_aval.ndim not in [1, 2]:
    # only matmat / matvec are supported
    use_default_lowering = True
  elif props.n_batch or props.n_dense:
    # batch and dense dimensions in BCSR are not supported
    use_default_lowering = True
  elif list(lhs_contract) != [1]:
    # cusparse cannot contract over more than one dimension
    use_default_lowering = True
  elif dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
    # This would be supported if not for the dtype.
    warnings.warn(f'bcsr_dot_general cusparse/hipsparse lowering not available '
                  f'for {dtype=}. Falling back to default implementation.',
                  CuSparseEfficiencyWarning)
    use_default_lowering = True

  if use_default_lowering:
    return _bcsr_dot_general_default_lowering(
        ctx, lhs_data, lhs_indices, lhs_indptr, rhs,
        dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)

  if rhs_aval.ndim == 1:
    dot_general_fn = csr_matvec_lowering
  elif rhs_aval.ndim == 2:
    dot_general_fn = csr_matmat_lowering
    if rhs_contract[0] == 1:
      # cusparse expects the contracted dimension of rhs to be the leading
      # one, so transpose rhs when contracting over its second dimension.
      rhs = hlo.TransposeOp(
          rhs, permutation=mlir.dense_int_elements([1, 0])).result
  else:
    raise ValueError(f"rhs has to be 1d or 2d; got {rhs_aval.ndim}d.")

  return [dot_general_fn(lhs_data, lhs_indices, lhs_indptr, rhs,
                         shape=lhs_spinfo.shape, transpose=False,
                         data_dtype=lhs_data_aval.dtype,
                         index_dtype=lhs_indices_aval.dtype,
                         B_dtype=rhs_aval.dtype)]
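
# Illustrative sketch, not from this commit: of the fallback guards above,
# only the dtype check warns; the other cases fall back silently. Assuming
# the public jax.experimental.sparse API, an integer matrix on a GPU backend
# would take the warning branch:
#
#   mat = sparse.BCSR.fromdense(jnp.arange(6, dtype=jnp.int32).reshape(2, 3))
#   vec = jnp.ones(3, dtype=jnp.int32)
#   sparse.bcsr_dot_general(mat, vec,
#                           dimension_numbers=(([1], [0]), ([], [])))
#   # -> CuSparseEfficiencyWarning: bcsr_dot_general cusparse/hipsparse
#   #    lowering not available for dtype=... Falling back to default
#   #    implementation.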
_bcsr_dot_general_default_lowering = mlir.lower_fun(
    _bcsr_dot_general_impl, multiple_results=False)
mlir.register_lowering(
    bcsr_dot_general_p, _bcsr_dot_general_default_lowering)

if gpu_sparse.cuda_is_supported:
  mlir.register_lowering(bcsr_dot_general_p,
                         partial(_bcsr_dot_general_gpu_lowering,
                                 gpu_sparse.cuda_csr_matvec,
                                 gpu_sparse.cuda_csr_matmat),
                         platform='cuda')

if gpu_sparse.rocm_is_supported:
  mlir.register_lowering(bcsr_dot_general_p,
                         partial(_bcsr_dot_general_gpu_lowering,
                                 gpu_sparse.rocm_csr_matvec,
                                 gpu_sparse.rocm_csr_matmat),
                         platform='rocm')
#----------------------------------------------------------------------
# BCOO functions that maybe should be primitives?
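
A minimal end-to-end sketch (an editor's illustration, not part of this commit), assuming the public jax.experimental.sparse API: the matvec below is 2D x 1D, float32, has no batch or dense dimensions, and contracts over the last lhs dimension, so it satisfies every guard in _bcsr_dot_general_gpu_lowering and should take the new csr_matvec path on a CUDA or ROCm backend.

import jax.numpy as jnp
from jax.experimental import sparse

# A 2D float32 BCSR matrix; fromdense defaults give n_batch == n_dense == 0.
mat = sparse.BCSR.fromdense(jnp.array([[1., 0., 2.],
                                       [0., 3., 0.]], dtype=jnp.float32))
vec = jnp.ones(3, dtype=jnp.float32)

# Contract lhs dimension 1 with rhs dimension 0, no batch dimensions:
# the "simplest case" the commit message refers to.
out = sparse.bcsr_dot_general(mat, vec,
                              dimension_numbers=(([1], [0]), ([], [])))

A 2D vec would route through csr_matmat instead; batched inputs, extra BCSR batch or dense dimensions, or other dtypes fall back to the default lowering.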