mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sparse] add cusparse lowering for simplest cases of bcsr_dot_general
PiperOrigin-RevId: 508473938
This commit is contained in:
parent
253cd4d9d1
commit
15c9bca67f
@ -17,11 +17,13 @@ from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
import operator
|
||||
import warnings
|
||||
|
||||
from typing import NamedTuple, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import config
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import tree_util
|
||||
@ -30,13 +32,15 @@ from jax.experimental.sparse import bcoo
|
||||
from jax.experimental.sparse.util import (
|
||||
nfold_vmap, _count_stored_elements,
|
||||
_csr_to_coo, _dot_general_validated_shape,
|
||||
SparseInfo, Shape)
|
||||
CuSparseEfficiencyWarning, SparseInfo, Shape)
|
||||
import jax.numpy as jnp
|
||||
from jax._src import api_util
|
||||
from jax._src.lax.lax import DotDimensionNumbers
|
||||
from jax._src.lib import gpu_sparse
|
||||
from jax.util import split_list, safe_zip
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.typing import Array, ArrayLike, DTypeLike
|
||||
|
||||
|
||||
@ -488,12 +492,84 @@ def _bcsr_dot_general_abstract_eval(lhs_data, lhs_indices, lhs_indptr, rhs, *,
|
||||
# ad.primitive_transposes[bcsr_dot_general_p] = _bcsr_dot_general_transpose
|
||||
# batching.primitive_batchers[bcsr_dot_general_p] = _bcsr_dot_general_batch_rule
|
||||
|
||||
def _bcsr_dot_general_gpu_lowering(
    csr_matvec_lowering, csr_matmat_lowering,
    ctx, lhs_data, lhs_indices, lhs_indptr, rhs, *, dimension_numbers,
    lhs_spinfo: SparseInfo):
  """Lower ``bcsr_dot_general`` to a cusparse/hipsparse CSR kernel if possible.

  The GPU kernels are used only for the simplest cases: a 2D, unbatched BCSR
  matrix with no dense dimensions, contracted over its column dimension with
  a 1D or 2D dense operand, and with a float32/float64/complex64/complex128
  data dtype. Every other case is delegated to
  ``_bcsr_dot_general_default_lowering``.

  Args:
    csr_matvec_lowering: lowering callable used when ``rhs`` is 1D.
    csr_matmat_lowering: lowering callable used when ``rhs`` is 2D.
    ctx: lowering rule context; ``ctx.avals_in`` supplies the abstract values
      of the four array operands.
    lhs_data: data buffer of the BCSR left-hand side.
    lhs_indices: column-index buffer of the BCSR left-hand side.
    lhs_indptr: row-pointer buffer of the BCSR left-hand side.
    rhs: dense right-hand side operand.
    dimension_numbers: ``((lhs_contract, rhs_contract), (lhs_batch,
      rhs_batch))`` as in ``lax.dot_general``.
    lhs_spinfo: sparse metadata; only ``lhs_spinfo.shape`` is read here.

  Returns:
    The lowered result: either the value produced by the default lowering, or
    a one-element list holding the result of the selected GPU kernel.

  Raises:
    ValueError: if ``rhs`` is neither 1D nor 2D on the GPU path. This is a
      defensive check; such operands are routed to the default lowering first.
  """

  def _default_lowering():
    # Single definition of the fallback call; it is needed both when the
    # config flag is off and when the operands are unsupported on GPU.
    return _bcsr_dot_general_default_lowering(
        ctx, lhs_data, lhs_indices, lhs_indptr, rhs,
        dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)

  if not config.jax_bcoo_cusparse_lowering:
    return _default_lowering()

  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  lhs_data_aval, lhs_indices_aval, lhs_indptr_aval, rhs_aval = ctx.avals_in
  props = _validate_bcsr(
      lhs_data_aval, lhs_indices_aval, lhs_indptr_aval, lhs_spinfo.shape)

  use_default_lowering = False
  dtype = lhs_data_aval.dtype
  # TODO(vanderplas, tianjianlu): lower batched matmuls to GPU
  if lhs_batch or rhs_batch:
    # batch dimensions in dot_general are not supported
    use_default_lowering = True
  elif len(lhs_spinfo.shape) != 2 or rhs_aval.ndim not in [1, 2]:
    # only matmat / matvec supported
    use_default_lowering = True
  elif props.n_batch or props.n_dense:
    # batch and dense dimensions in BCSR not supported
    use_default_lowering = True
  elif list(lhs_contract) != [1]:
    # cusparse cannot contract over more than one dimension
    use_default_lowering = True
  elif dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
    # This would be supported if not for the dtype.
    warnings.warn(f'bcsr_dot_general cusparse/hipsparse lowering not available '
                  f'for {dtype=}. Falling back to default implementation.',
                  CuSparseEfficiencyWarning)
    use_default_lowering = True

  if use_default_lowering:
    return _default_lowering()

  # The jaxlib CSR kernels name the dense operand's dtype keyword differently:
  # the matvec lowering takes `x_dtype`, the matmat lowering takes `B_dtype`.
  # Passing `B_dtype` to the matvec lowering fails with a TypeError.
  if rhs_aval.ndim == 1:
    dot_general_fn = csr_matvec_lowering
    rhs_dtype_kwarg = 'x_dtype'
  elif rhs_aval.ndim == 2:
    dot_general_fn = csr_matmat_lowering
    rhs_dtype_kwarg = 'B_dtype'
    if rhs_contract[0] == 1:
      # Contraction is over rhs's second axis: swap the two axes so that the
      # contracted dimension becomes the leading one.
      rhs = hlo.TransposeOp(
          rhs, permutation=mlir.dense_int_elements([1, 0])).result
  else:
    raise ValueError(f"rhs has to be 1d or 2d; get {rhs_aval.ndim}d.")

  return [dot_general_fn(lhs_data, lhs_indices, lhs_indptr, rhs,
                         shape=lhs_spinfo.shape, transpose=False,
                         data_dtype=lhs_data_aval.dtype,
                         index_dtype=lhs_indices_aval.dtype,
                         **{rhs_dtype_kwarg: rhs_aval.dtype})]
|
||||
|
||||
# Platform-independent lowering, built from `_bcsr_dot_general_impl` via
# `mlir.lower_fun`. It is registered without a platform and is also the
# fallback used by the GPU lowering rule for unsupported cases.
_bcsr_dot_general_default_lowering = mlir.lower_fun(
    _bcsr_dot_general_impl, multiple_results=False)
mlir.register_lowering(bcsr_dot_general_p, _bcsr_dot_general_default_lowering)


def _register_bcsr_dot_general_gpu_lowerings():
  """Register the sparse-library lowering for each supported GPU platform.

  For 'cuda' and then 'rocm', checks `gpu_sparse.<platform>_is_supported` and,
  if set, registers `_bcsr_dot_general_gpu_lowering` bound to that platform's
  `csr_matvec` / `csr_matmat` lowerings. The kernel attributes are looked up
  only for supported platforms.
  """
  for platform in ('cuda', 'rocm'):
    if not getattr(gpu_sparse, f'{platform}_is_supported'):
      continue
    matvec = getattr(gpu_sparse, f'{platform}_csr_matvec')
    matmat = getattr(gpu_sparse, f'{platform}_csr_matmat')
    mlir.register_lowering(
        bcsr_dot_general_p,
        partial(_bcsr_dot_general_gpu_lowering, matvec, matmat),
        platform=platform)


_register_bcsr_dot_general_gpu_lowerings()
|
||||
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
# BCOO functions that maybe should be primitives?
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user