[MHLO] Add direct MHLO lowerings for sparse primitives.

PiperOrigin-RevId: 440374054
Peter Hawkins 2022-04-08 08:43:23 -07:00 committed by jax authors
parent 1cb4fccd1d
commit 648a512488
7 changed files with 1032 additions and 6 deletions
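The pattern applied throughout this change: each sparse primitive keeps its existing implementation function and registers it as an MHLO lowering via mlir.lower_fun, with cusparse/hipsparse-backed rules layered on top for GPU when jaxlib is new enough. Below is a minimal sketch of that registration shape using a hypothetical my_prim primitive (my_prim and _my_prim_impl are placeholders, not part of this commit):

from jax import core
from jax.interpreters import mlir

# Hypothetical placeholders: `my_prim` and `_my_prim_impl` stand in for the
# real primitives in this commit (todense_p, bcoo_todense_p, coo_matvec_p, ...).
my_prim = core.Primitive("my_prim")
my_prim.def_abstract_eval(lambda aval: aval)  # unary, shape/dtype-preserving

def _my_prim_impl(x):
  # Reference implementation in ordinary JAX ops; mlir.lower_fun traces it
  # and emits the corresponding MHLO, so no hand-written lowering is needed.
  return x + 1

my_prim.def_impl(_my_prim_impl)
mlir.register_lowering(my_prim, mlir.lower_fun(_my_prim_impl,
                                               multiple_results=False))

# A platform-specific rule (e.g. a cusparse-backed lowering) can then be
# layered on for GPU only, guarded on the jaxlib version:
#   if jax._src.lib.version > (0, 3, 5):
#     mlir.register_lowering(my_prim, _my_prim_gpu_lowering, platform='gpu')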

View File

@ -41,6 +41,7 @@ from jax.experimental.sparse.csr import CSR, CSC
from jax.experimental.sparse.util import _coo_extract
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src import dtypes
@ -102,6 +103,9 @@ batching.primitive_batchers[todense_p] = _todense_batching_rule
xla.register_translation(todense_p, xla.lower_fun(
_todense_impl, multiple_results=False, new_style=True))
mlir.register_lowering(todense_p, mlir.lower_fun(
_todense_impl, multiple_results=False))
def empty(shape, dtype=None, index_dtype='int32', sparse_format='bcoo', **kwds):
"""Create an empty sparse array.

View File

@ -29,6 +29,7 @@ from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.util import _safe_asarray, CuSparseEfficiencyWarning
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import xla
import jax.numpy as jnp
from jax.interpreters import ad
@ -38,6 +39,9 @@ from jax._src.api_util import flatten_axes
from jax._src.lax.lax import (
ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule,
DotDimensionNumbers)
import jax._src.lib
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib import xla_client as xc
from jax._src.numpy.setops import _unique
@ -146,6 +150,9 @@ def _bcoo_sort_indices(data, indices, shape):
_bcoo_sort_indices_rule = xla.lower_fun(
_bcoo_sort_indices, multiple_results=True, new_style=True)
_bcoo_sort_indices_mhlo = mlir.lower_fun(
_bcoo_sort_indices, multiple_results=True)
def _unbatch_bcoo(data, indices, shape):
n_batch = _validate_bcoo(data, indices, shape).n_batch
if n_batch == 0:
@ -282,6 +289,8 @@ ad.primitive_transposes[bcoo_todense_p] = _bcoo_todense_transpose
batching.primitive_batchers[bcoo_todense_p] = _bcoo_todense_batching_rule
xla.register_translation(bcoo_todense_p, xla.lower_fun(
_bcoo_todense_impl, multiple_results=False, new_style=True))
mlir.register_lowering(bcoo_todense_p, mlir.lower_fun(
_bcoo_todense_impl, multiple_results=False))
#--------------------------------------------------------------------
# bcoo_fromdense
@ -408,6 +417,8 @@ ad.primitive_transposes[bcoo_fromdense_p] = _bcoo_fromdense_transpose
batching.primitive_batchers[bcoo_fromdense_p] = _bcoo_fromdense_batching_rule
xla.register_translation(bcoo_fromdense_p, xla.lower_fun(
_bcoo_fromdense_impl, multiple_results=True, new_style=True))
mlir.register_lowering(bcoo_fromdense_p, mlir.lower_fun(
_bcoo_fromdense_impl, multiple_results=True))
#----------------------------------------------------------------------
# bcoo_extract
@ -487,6 +498,8 @@ ad.primitive_transposes[bcoo_extract_p] = _bcoo_extract_transpose
batching.primitive_batchers[bcoo_extract_p] = _bcoo_extract_batching_rule
xla.register_translation(bcoo_extract_p, xla.lower_fun(
_bcoo_extract_impl, multiple_results=False, new_style=True))
mlir.register_lowering(bcoo_extract_p, mlir.lower_fun(
_bcoo_extract_impl, multiple_results=False))
#----------------------------------------------------------------------
# bcoo_transpose
@ -602,6 +615,8 @@ ad.primitive_transposes[bcoo_transpose_p] = _bcoo_transpose_transpose
batching.primitive_batchers[bcoo_transpose_p] = _bcoo_transpose_batch_rule
xla.register_translation(bcoo_transpose_p, xla.lower_fun(
_bcoo_transpose_impl, multiple_results=True, new_style=True))
mlir.register_lowering(bcoo_transpose_p, mlir.lower_fun(
_bcoo_transpose_impl, multiple_results=True))
#----------------------------------------------------------------------
# bcoo_dot_general
@ -862,6 +877,146 @@ def _bcoo_dot_general_gpu_translation_rule(
ctx, avals_in, avals_out, lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
_bcoo_dot_general_default_lowering = mlir.lower_fun(
_bcoo_dot_general_impl, multiple_results=False)
def _collapse_mhlo(x, start, end):
x_type = ir.RankedTensorType(x.type)
shape = x_type.shape
shape = (shape[:start]
+ [functools.reduce(operator.mul, shape[start:end + 1])]
+ shape[end + 1:])
return mhlo.ReshapeOp(
ir.RankedTensorType.get(shape, x_type.element_type), x).result
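For reference, the shape arithmetic performed by _collapse_mhlo can be checked with plain Python (illustrative only; the helper itself builds an mhlo.ReshapeOp):

import functools, operator

def collapsed_shape(shape, start, end):
  # Dimensions start..end (inclusive) are folded into a single dimension.
  return (shape[:start]
          + [functools.reduce(operator.mul, shape[start:end + 1])]
          + shape[end + 1:])

print(collapsed_shape([7, 3, 5], start=0, end=1))  # [21, 5]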
def _bcoo_dot_general_cuda_lowering(
ctx, lhs_data, lhs_indices, rhs, *, dimension_numbers,
lhs_spinfo: BCOOInfo):
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
lhs_data_aval, lhs_indices_aval, rhs_aval, = ctx.avals_in
props = _validate_bcoo_indices(lhs_indices_aval, lhs_spinfo.shape)
rhs_ndim = len(ir.RankedTensorType(rhs.type).shape)
# Checks the shapes of lhs and rhs.
assert props.n_dense == 0
assert props.n_batch == 0
assert props.n_sparse in [1, 2]
assert rhs_ndim in [1, 2]
# Checks the operation dimensions.
assert len(lhs_batch) == 0
assert len(rhs_batch) == 0
assert len(lhs_contract) == 1
# Checks the dtype.
assert lhs_data_aval.dtype in [np.float32, np.float64, np.complex64,
np.complex128]
assert lhs_data_aval.dtype == rhs_aval.dtype
assert lhs_indices_aval.dtype == np.int32
assert sparse_apis is not None
if rhs_ndim == 1:
bcoo_dot_general_fn = sparse_apis.coo_matvec_mhlo
elif rhs_ndim == 2:
bcoo_dot_general_fn = sparse_apis.coo_matmat_mhlo
if rhs_contract[0] == 1:
rhs = mhlo.TransposeOp(
rhs, permutation=mlir.dense_int_elements([1, 0])).result
else:
raise ValueError(f"rhs has to be 1d or 2d; get {rhs_ndim}d.")
lhs_transpose = False
if props.n_sparse == 1:
# Converts lhs to a row vector.
col = _collapse_mhlo(lhs_indices, start=0, end=1)
row = mlir.full_like_aval(
0, core.ShapedArray(ir.RankedTensorType(col.type).shape,
np.dtype(np.int32)))
lhs_shape = (1, lhs_spinfo.shape[0])
dot_product = bcoo_dot_general_fn(
lhs_data, row, col, rhs, shape=lhs_shape, transpose=lhs_transpose,
data_dtype=lhs_data_aval.dtype, index_dtype=lhs_indices_aval.dtype,
x_dtype=rhs_aval.dtype)
if rhs_ndim == 1:
# Transforms a single-element array to a scalar.
return [mhlo.ReshapeOp(
ir.RankedTensorType.get(
[], ir.RankedTensorType(dot_product.type).element_type),
dot_product).result]
else:
return [_collapse_mhlo(dot_product, start=0, end=1)]
elif props.n_sparse == 2:
lhs_indices_shape = ir.RankedTensorType(lhs_indices.type).shape
row = _collapse_mhlo(
mhlo.SliceOp(
lhs_indices,
start_indices=mlir.dense_int_elements([0, 0]),
limit_indices=mlir.dense_int_elements([lhs_indices_shape[0], 1]),
strides=mlir.dense_int_elements([1, 1])).result,
start=0, end=1)
col = _collapse_mhlo(
mhlo.SliceOp(
lhs_indices,
start_indices=mlir.dense_int_elements([0, 1]),
limit_indices=mlir.dense_int_elements([lhs_indices_shape[0], 2]),
strides=mlir.dense_int_elements([1, 1])).result,
start=0, end=1)
if lhs_contract[0] == 0:
lhs_transpose = True
return [bcoo_dot_general_fn(
lhs_data, row, col, rhs, shape=lhs_spinfo.shape,
transpose=lhs_transpose, data_dtype=lhs_data_aval.dtype,
index_dtype=lhs_indices_aval.dtype,
x_dtype=rhs_aval.dtype)]
else:
raise ValueError(f"lhs has to be 1d or 2d; get {props.n_sparse}d.")
def _bcoo_dot_general_gpu_lowering(
ctx, lhs_data, lhs_indices, rhs, *, dimension_numbers,
lhs_spinfo: BCOOInfo):
if not config.jax_bcoo_cusparse_lowering:
return _bcoo_dot_general_default_lowering(
ctx, lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
lhs_data_aval, lhs_indices_aval, rhs_aval, = ctx.avals_in
n_batch, n_sparse, n_dense, nse = _validate_bcoo(
lhs_data_aval, lhs_indices_aval, lhs_spinfo.shape)
dtype = lhs_data_aval.dtype
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
warnings.warn(f'bcoo_dot_general cusparse/hipsparse lowering not available '
f'for dtype={dtype}. Falling back to default implementation.',
CuSparseEfficiencyWarning)
return _bcoo_dot_general_default_lowering(
ctx, lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
if (n_batch or n_dense or
n_sparse not in [1, 2] or rhs_aval.ndim not in [1, 2] or
lhs_batch or rhs_batch or len(lhs_contract) != 1):
return _bcoo_dot_general_default_lowering(
ctx, lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
else:
# Sorts lhs by row indices.
sub_ctx = mlir.LoweringRuleContext(module_context=ctx.module_context,
primitive=None,
avals_in=ctx.avals_in[:2],
avals_out=ctx.avals_in[:2])
(lhs_data,), (lhs_indices,) = _bcoo_sort_indices_mhlo(
sub_ctx, lhs_data, lhs_indices, shape=lhs_spinfo.shape)
return _bcoo_dot_general_cuda_lowering(
ctx, lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
def _bcoo_dot_general_jvp_lhs(lhs_data_dot, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: BCOOInfo):
return _bcoo_dot_general(lhs_data_dot, lhs_indices, rhs, dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
@ -930,10 +1085,16 @@ batching.primitive_batchers[bcoo_dot_general_p] = _bcoo_dot_general_batch_rule
xla.register_translation(
bcoo_dot_general_p, _bcoo_dot_general_default_translation_rule)
mlir.register_lowering(
bcoo_dot_general_p, _bcoo_dot_general_default_lowering)
if sparse_apis and sparse_apis.is_supported:
xla.register_translation(bcoo_dot_general_p,
_bcoo_dot_general_gpu_translation_rule,
platform='gpu')
if jax._src.lib.version > (0, 3, 5):
mlir.register_lowering(bcoo_dot_general_p,
_bcoo_dot_general_gpu_lowering,
platform='gpu')
#----------------------------------------------------------------------
# bcoo_dot_general_sampled
@ -1005,6 +1166,9 @@ ad.primitive_transposes[bcoo_dot_general_sampled_p] = _bcoo_dot_general_sampled_
batching.primitive_batchers[bcoo_dot_general_sampled_p] = _bcoo_dot_general_sampled_batch_rule
xla.register_translation(bcoo_dot_general_sampled_p, xla.lower_fun(
_bcoo_dot_general_sampled_impl, multiple_results=False, new_style=True))
mlir.register_lowering(
bcoo_dot_general_sampled_p,
mlir.lower_fun(_bcoo_dot_general_sampled_impl, multiple_results=False))
#----------------------------------------------------------------------
# bcoo_spdot_general
@ -1205,6 +1369,8 @@ batching.primitive_batchers[bcoo_spdot_general_p] = _bcoo_spdot_general_batch_ru
ad.primitive_jvps[bcoo_spdot_general_p] = _bcoo_spdot_general_jvp
xla.register_translation(bcoo_spdot_general_p, xla.lower_fun(
_bcoo_spdot_general_impl, multiple_results=True, new_style=True))
mlir.register_lowering(bcoo_spdot_general_p, mlir.lower_fun(
_bcoo_spdot_general_impl, multiple_results=True))
#----------------------------------------------------------------------
# BCOO functions that maybe should be primitives?

View File

@ -23,10 +23,13 @@ import numpy as np
from jax import core
from jax import lax
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.util import _coo_extract, _safe_asarray, CuSparseEfficiencyWarning
from jax import tree_util
import jax._src.lib
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib import sparse_apis
from jax._src.numpy.lax_numpy import _promote_dtypes
from jax._src.lib import xla_client
@ -183,6 +186,37 @@ def _coo_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
return [xops.Transpose(result, (1, 0))] if transpose else [result]
_coo_todense_lowering = mlir.lower_fun(
_coo_todense_impl, multiple_results=False)
def _coo_todense_gpu_lowering(ctx, data, row, col, *, spinfo):
data_aval, row_aval, _ = ctx.avals_in
dtype = data_aval.dtype
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
warnings.warn(f"coo_todense cusparse/hipsparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo)
if spinfo.rows_sorted:
shape = spinfo.shape
transpose = False
elif spinfo.cols_sorted:
row, col = col, row
transpose = True
shape = spinfo.shape[::-1]
else:
warnings.warn("coo_todense GPU lowering requires matrices with sorted rows or sorted cols. "
"To sort the rows in your matrix, use e.g. mat = mat._sort_rows(). Falling "
"back to the default implementation.", CuSparseEfficiencyWarning)
return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo)
result = sparse_apis.coo_todense_mhlo(
data, row, col, shape=shape, data_dtype=dtype, index_dtype=row_aval.dtype)
return (
[mhlo.TransposeOp(result, mlir.dense_int_elements([1, 0])).result]
if transpose else [result])
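The cols_sorted branch above relies on the identity that densifying the transposed COO matrix (row/col swapped, shape reversed) and then transposing the result reproduces the original matrix; a small NumPy check (illustrative only):

import numpy as np

def coo_todense_reference(data, row, col, shape):
  out = np.zeros(shape, dtype=data.dtype)
  np.add.at(out, (row, col), data)
  return out

data = np.array([1.0, 2.0, 3.0])
row, col = np.array([2, 0, 1]), np.array([0, 1, 2])   # cols sorted, rows not
shape = (3, 4)

direct = coo_todense_reference(data, row, col, shape)
via_transpose = coo_todense_reference(data, col, row, shape[::-1]).T
print(np.array_equal(direct, via_transpose))  # True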
def _coo_todense_jvp(data_dot, data, row, col, *, spinfo):
return _coo_todense(data_dot, row, col, spinfo=spinfo)
@ -200,9 +234,13 @@ def _coo_todense_transpose(ct, data, row, col, *, spinfo):
ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
xla.register_translation(coo_todense_p, _coo_todense_translation_rule)
mlir.register_lowering(coo_todense_p, _coo_todense_lowering)
if sparse_apis and sparse_apis.is_supported:
xla.register_translation(coo_todense_p, _coo_todense_gpu_translation_rule,
platform='gpu')
if jax._src.lib.version > (0, 3, 5):
mlir.register_lowering(coo_todense_p, _coo_todense_gpu_lowering,
platform='gpu')
#--------------------------------------------------------------------
# coo_fromdense
@ -278,6 +316,23 @@ def _coo_fromdense_gpu_translation_rule(ctx, avals_in, avals_out, mat, *, nse,
ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
return [data, row, col]
_coo_fromdense_lowering = mlir.lower_fun(
_coo_fromdense_impl, multiple_results=True)
def _coo_fromdense_gpu_lowering(ctx, mat, *, nse, index_dtype):
dtype = ctx.avals_in[0].dtype
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
warnings.warn(f"coo_fromdense cusparse/hipsparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype)
data, row, col = sparse_apis.coo_fromdense_mhlo(
mat, nnz=nse,
data_dtype=dtype,
index_dtype=np.dtype(index_dtype),
index_type=mlir.dtype_to_ir_type(np.dtype(index_dtype)))
return [data, row, col]
def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype):
M, = primals
Mdot, = tangents
@ -307,10 +362,15 @@ ad.primitive_jvps[coo_fromdense_p] = _coo_fromdense_jvp
ad.primitive_transposes[coo_fromdense_p] = _coo_fromdense_transpose
xla.register_translation(coo_fromdense_p, _coo_fromdense_translation_rule)
mlir.register_lowering(coo_fromdense_p, _coo_fromdense_lowering)
if sparse_apis and sparse_apis.is_supported:
xla.register_translation(coo_fromdense_p,
_coo_fromdense_gpu_translation_rule,
platform='gpu')
if jax._src.lib.version > (0, 3, 5):
mlir.register_lowering(coo_fromdense_p,
_coo_fromdense_gpu_lowering,
platform='gpu')
#--------------------------------------------------------------------
# coo_matvec
@ -400,6 +460,36 @@ def _coo_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
return [sparse_apis.coo_matvec(ctx.builder, data, row, col, v, shape=shape,
transpose=transpose)]
_coo_matvec_lowering = mlir.lower_fun(
_coo_matvec_impl, multiple_results=False)
def _coo_matvec_gpu_lowering(ctx, data, row, col, v, *, spinfo, transpose):
data_aval, row_aval, _, x_aval = ctx.avals_in
dtype = data_aval.dtype
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
warnings.warn(f"coo_matvec cusparse/hipsparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo,
transpose=transpose)
if spinfo.rows_sorted:
shape = spinfo.shape
elif spinfo.cols_sorted:
row, col = col, row
transpose = not transpose
shape = spinfo.shape[::-1]
else:
warnings.warn("coo_matvec GPU lowering requires matrices with sorted rows or sorted cols. "
"To sort the rows in your matrix, use e.g. mat = mat._sort_rows(). Falling "
"back to the default implementation.", CuSparseEfficiencyWarning)
return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo,
transpose=transpose)
return [sparse_apis.coo_matvec_mhlo(
data, row, col, v, shape=shape, transpose=transpose,
index_dtype=row_aval.dtype, data_dtype=dtype, x_dtype=x_aval.dtype)]
def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, spinfo, transpose):
return _coo_matvec(data_dot, row, col, v, spinfo=spinfo, transpose=transpose)
@ -421,9 +511,13 @@ def _coo_matvec_transpose(ct, data, row, col, v, *, spinfo, transpose):
ad.defjvp(coo_matvec_p, _coo_matvec_jvp_mat, None, None, _coo_matvec_jvp_vec)
ad.primitive_transposes[coo_matvec_p] = _coo_matvec_transpose
xla.register_translation(coo_matvec_p, _coo_matvec_translation_rule)
mlir.register_lowering(coo_matvec_p, _coo_matvec_lowering)
if sparse_apis and sparse_apis.is_supported:
xla.register_translation(coo_matvec_p, _coo_matvec_gpu_translation_rule,
platform='gpu')
if jax._src.lib.version > (0, 3, 5):
mlir.register_lowering(coo_matvec_p, _coo_matvec_gpu_lowering,
platform='gpu')
#--------------------------------------------------------------------
# coo_matmat
@ -511,6 +605,35 @@ def _coo_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
return [sparse_apis.coo_matmat(ctx.builder, data, row, col, B, shape=shape,
transpose=transpose)]
_coo_matmat_lowering = mlir.lower_fun(_coo_matmat_impl, multiple_results=False)
def _coo_matmat_gpu_lowering(ctx, data, row, col, B, *, spinfo, transpose):
data_aval, row_aval, _, B_aval = ctx.avals_in
dtype = data_aval.dtype
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
warnings.warn(f"coo_matmat cusparse/hipsprse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo,
transpose=transpose)
if spinfo.rows_sorted:
shape = spinfo.shape
elif spinfo.cols_sorted:
row, col = col, row
transpose = not transpose
shape = spinfo.shape[::-1]
else:
warnings.warn("coo_matmat GPU lowering requires matrices with sorted rows or sorted cols. "
"To sort the rows in your matrix, use e.g. mat = mat._sort_rows(). Falling "
"back to the default implementation.", CuSparseEfficiencyWarning)
return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo,
transpose=transpose)
return [sparse_apis.coo_matmat_mhlo(data, row, col, B, shape=shape,
transpose=transpose, x_dtype=B_aval.dtype,
data_dtype=data_aval.dtype,
index_dtype=row_aval.dtype)]
def _coo_matmat_jvp_left(data_dot, data, row, col, B, *, spinfo, transpose):
return _coo_matmat(data_dot, row, col, B, spinfo=spinfo, transpose=transpose)
@ -529,6 +652,10 @@ def _coo_matmat_transpose(ct, data, row, col, B, *, spinfo, transpose):
ad.defjvp(coo_matmat_p, _coo_matmat_jvp_left, None, None, _coo_matmat_jvp_right)
ad.primitive_transposes[coo_matmat_p] = _coo_matmat_transpose
xla.register_translation(coo_matmat_p, _coo_matmat_translation_rule)
mlir.register_lowering(coo_matmat_p, _coo_matmat_lowering)
if sparse_apis and sparse_apis.is_supported:
xla.register_translation(coo_matmat_p, _coo_matmat_gpu_translation_rule,
platform='gpu')
if jax._src.lib.version > (0, 3, 5):
mlir.register_lowering(coo_matmat_p, _coo_matmat_gpu_lowering,
platform='gpu')

View File

@ -22,11 +22,13 @@ import numpy as np
from jax import core
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo
from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, _safe_asarray, CuSparseEfficiencyWarning
from jax import tree_util
import jax._src.lib
from jax._src.lib import sparse_apis
from jax._src.numpy.lax_numpy import _promote_dtypes
import jax.numpy as jnp
@ -184,6 +186,21 @@ def _csr_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
indptr, shape=shape)
return [sparse_apis.csr_todense(ctx.builder, data, indices, indptr, shape=shape)]
_csr_todense_lowering = mlir.lower_fun(
_csr_todense_impl, multiple_results=False)
def _csr_todense_gpu_lowering(ctx, data, indices, indptr, *, shape):
data_aval, indices_aval, _ = ctx.avals_in
dtype = data_aval.dtype
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
warnings.warn(f"csr_todense cusparse/hipsparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_todense_lowering(ctx, data, indices, indptr, shape=shape)
return [sparse_apis.csr_todense_mhlo(
data, indices, indptr, shape=shape, data_dtype=dtype,
index_dtype=indices_aval.dtype)]
def _csr_todense_jvp(data_dot, data, indices, indptr, *, shape):
return csr_todense(data_dot, indices, indptr, shape=shape)
@ -201,9 +218,13 @@ def _csr_todense_transpose(ct, data, indices, indptr, *, shape):
ad.defjvp(csr_todense_p, _csr_todense_jvp, None, None)
ad.primitive_transposes[csr_todense_p] = _csr_todense_transpose
xla.register_translation(csr_todense_p, _csr_todense_translation_rule)
mlir.register_lowering(csr_todense_p, _csr_todense_lowering)
if sparse_apis and sparse_apis.is_supported:
xla.register_translation(csr_todense_p, _csr_todense_gpu_translation_rule,
platform='gpu')
if jax._src.lib.version > (0, 3, 5):
mlir.register_lowering(csr_todense_p, _csr_todense_gpu_lowering,
platform='gpu')
#--------------------------------------------------------------------
# csr_fromdense
@ -267,6 +288,21 @@ def _csr_fromdense_gpu_translation_rule(ctx, avals_in, avals_out, mat, *, nse,
ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
return [data, indices, indptr]
_csr_fromdense_lowering = mlir.lower_fun(_csr_fromdense_impl,
multiple_results=True)
def _csr_fromdense_gpu_lowering(ctx, mat, *, nse, index_dtype):
dtype = ctx.avals_in[0].dtype
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
warnings.warn(f"csr_fromdense cusparse/hipsparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype)
data, indices, indptr = sparse_apis.csr_fromdense_mhlo(
mat, nnz=nse, index_dtype=np.dtype(index_dtype),
data_dtype=dtype, index_type=mlir.dtype_to_ir_type(np.dtype(index_dtype)))
return [data, indices, indptr]
def _csr_fromdense_jvp(primals, tangents, *, nse, index_dtype):
M, = primals
Mdot, = tangents
@ -295,10 +331,15 @@ def _csr_fromdense_transpose(ct, M, *, nse, index_dtype):
ad.primitive_jvps[csr_fromdense_p] = _csr_fromdense_jvp
ad.primitive_transposes[csr_fromdense_p] = _csr_fromdense_transpose
xla.register_translation(csr_fromdense_p, _csr_fromdense_translation_rule)
mlir.register_lowering(csr_fromdense_p, _csr_fromdense_lowering)
if sparse_apis and sparse_apis.is_supported:
xla.register_translation(csr_fromdense_p,
_csr_fromdense_gpu_translation_rule,
platform='gpu')
if jax._src.lib.version > (0, 3, 5):
mlir.register_lowering(csr_fromdense_p,
_csr_fromdense_gpu_lowering,
platform='gpu')
#--------------------------------------------------------------------
# csr_matvec
@ -354,6 +395,22 @@ def _csr_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
return [sparse_apis.csr_matvec(ctx.builder, data, indices, indptr, v,
shape=shape, transpose=transpose)]
_csr_matvec_lowering = mlir.lower_fun(_csr_matvec_impl, multiple_results=False)
def _csr_matvec_gpu_lowering(ctx, data, indices, indptr, v, *,
shape, transpose):
data_aval, indices_aval, _, v_aval = ctx.avals_in
dtype = data_aval.dtype
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
warnings.warn(f"csr_matvec cusparse/hipsparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_matvec_lowering(ctx, data, indices, indptr, v, shape=shape,
transpose=transpose)
return [sparse_apis.csr_matvec_mhlo(
data, indices, indptr, v, shape=shape, transpose=transpose,
data_dtype=dtype, index_dtype=indices_aval.dtype, x_dtype=v_aval.dtype)]
def _csr_matvec_jvp_mat(data_dot, data, indices, indptr, v, *, shape, transpose):
return csr_matvec(data_dot, indices, indptr, v, shape=shape, transpose=transpose)
@ -376,9 +433,13 @@ def _csr_matvec_transpose(ct, data, indices, indptr, v, *, shape, transpose):
ad.defjvp(csr_matvec_p, _csr_matvec_jvp_mat, None, None, _csr_matvec_jvp_vec)
ad.primitive_transposes[csr_matvec_p] = _csr_matvec_transpose
xla.register_translation(csr_matvec_p, _csr_matvec_translation_rule)
mlir.register_lowering(csr_matvec_p, _csr_matvec_lowering)
if sparse_apis and sparse_apis.is_supported:
xla.register_translation(csr_matvec_p, _csr_matvec_gpu_translation_rule,
platform='gpu')
if jax._src.lib.version > (0, 3, 5):
mlir.register_lowering(csr_matvec_p, _csr_matvec_gpu_lowering,
platform='gpu')
#--------------------------------------------------------------------
@ -436,6 +497,22 @@ def _csr_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
return [sparse_apis.csr_matmat(ctx.builder, data, indices, indptr, B,
shape=shape, transpose=transpose)]
_csr_matmat_lowering = mlir.lower_fun(_csr_matmat_impl, multiple_results=False)
def _csr_matmat_gpu_lowering(ctx, data, indices, indptr, B, *, shape, transpose):
data_aval, indices_aval, _, B_aval = ctx.avals_in
dtype = data_aval.dtype
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
warnings.warn(f"csr_matmat cusparse/hipsparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_matmat_lowering(ctx, data, indices, indptr, B, shape=shape,
transpose=transpose)
return [sparse_apis.csr_matmat_mhlo(
data, indices, indptr, B, shape=shape, transpose=transpose,
index_dtype=indices_aval.dtype, data_dtype=data_aval.dtype,
B_dtype=B_aval.dtype)]
def _csr_matmat_jvp_left(data_dot, data, indices, indptr, B, *, shape, transpose):
return csr_matmat(data_dot, indices, indptr, B, shape=shape, transpose=transpose)
@ -456,6 +533,10 @@ def _csr_matmat_transpose(ct, data, indices, indptr, B, *, shape, transpose):
ad.defjvp(csr_matmat_p, _csr_matmat_jvp_left, None, None, _csr_matmat_jvp_right)
ad.primitive_transposes[csr_matmat_p] = _csr_matmat_transpose
xla.register_translation(csr_matmat_p, _csr_matmat_translation_rule)
mlir.register_lowering(csr_matmat_p, _csr_matmat_lowering)
if sparse_apis and sparse_apis.is_supported:
xla.register_translation(csr_matmat_p, _csr_matmat_gpu_translation_rule,
platform='gpu')
if jax._src.lib.version > (0, 3, 5):
mlir.register_lowering(csr_matmat_p, _csr_matmat_gpu_lowering,
platform='gpu')

View File

@ -46,6 +46,18 @@ def _validate_csr(c, data, indices, indptr, shape):
assert c.get_shape(indptr).dimensions() == (shape[0] + 1,)
return data_dtype, index_dtype, nnz
def _validate_csr_mhlo(data, indices, indptr, shape):
data_type = ir.RankedTensorType(data.type)
indices_type = ir.RankedTensorType(indices.type)
indptr_type = ir.RankedTensorType(indptr.type)
nnz, = data_type.shape
assert indices_type.shape == [nnz]
assert indptr_type.element_type == indices_type.element_type
assert indptr_type.shape == [shape[0] + 1]
return data_type.element_type, indices_type.element_type, nnz
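The invariants asserted here match the standard CSR layout: one column index per stored value and rows + 1 entries in indptr. A quick check against SciPy's CSR arrays (illustrative only, assuming SciPy is available):

import numpy as np
import scipy.sparse as sp

m = sp.random(5, 7, density=0.3, format="csr", dtype=np.float32)
assert m.data.shape == m.indices.shape == (m.nnz,)   # nnz stored values
assert m.indptr.shape == (m.shape[0] + 1,)           # rows + 1 row pointers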
def _validate_coo(c, data, row, col, shape):
data_dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(row).element_type())
@ -55,6 +67,17 @@ def _validate_coo(c, data, row, col, shape):
assert c.get_shape(col).dimensions() == (nnz,)
return data_dtype, index_dtype, nnz
def _validate_coo_mhlo(data, row, col, shape):
data_type = ir.RankedTensorType(data.type)
row_type = ir.RankedTensorType(row.type)
col_type = ir.RankedTensorType(col.type)
nnz, = data_type.shape
assert row_type.shape == [nnz]
assert col_type.element_type == row_type.element_type
assert col_type.shape == [nnz]
return data_type.element_type, row_type.element_type, nnz
def csr_todense(c, data, indices, indptr, *, shape):
"""CSR to dense matrix."""
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
@ -83,6 +106,38 @@ def csr_todense(c, data, indices, indptr, *, shape):
return _ops.GetTupleElement(out, 0)
def csr_todense_mhlo(data, indices, indptr, *, shape, data_dtype, index_dtype):
"""CSR to dense matrix."""
data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape)
rows, cols = shape
buffer_size, opaque = _cusparse.build_csr_todense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
[ir.TupleType.get_tuple([
ir.RankedTensorType.get(shape, data_type),
ir.RankedTensorType.get([buffer_size],
ir.IntegerType.get_signless(8)),
])],
[data, indices, indptr],
call_target_name=ir.StringAttr.get("cusparse_csr_todense"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
] * 3),
result_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([1, 0]),
type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
]))
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
def csr_fromdense(c, mat, *, nnz, index_dtype):
"""CSR from dense matrix."""
data_dtype = np.dtype(c.get_shape(mat).element_type())
@ -113,7 +168,43 @@ def csr_fromdense(c, mat, *, nnz, index_dtype):
return tuple(_ops.GetTupleElement(out, i) for i in range(3))
def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_dtype=None):
def csr_fromdense_mhlo(mat, *, nnz, index_dtype, data_dtype, index_type):
"""CSR from dense matrix."""
mat_type = ir.RankedTensorType(mat.type)
rows, cols = mat_type.shape
buffer_size, opaque = _cusparse.build_csr_fromdense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
[ir.TupleType.get_tuple([
ir.RankedTensorType.get([nnz], mat_type.element_type),
ir.RankedTensorType.get([nnz], index_type),
ir.RankedTensorType.get([rows + 1], index_type),
ir.RankedTensorType.get([buffer_size],
ir.IntegerType.get_signless(8)),
])],
[mat],
call_target_name=ir.StringAttr.get("cusparse_csr_fromdense"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([1, 0]),
type=ir.IndexType.get()),
]),
result_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
] * 4))
return [
mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
for i in range(3)
]
def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False,
compute_dtype=None):
"""CSR matrix/vector multiply."""
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
rows, cols = shape
@ -147,8 +238,46 @@ def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_d
)
return _ops.GetTupleElement(out, 0)
def csr_matvec_mhlo(data, indices, indptr, x, *, shape, transpose=False,
compute_dtype=None, compute_type=None, data_dtype,
index_dtype, x_dtype):
"""CSR matrix/vector multiply."""
data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape)
rows, cols = shape
def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_dtype=None):
if compute_dtype is None:
compute_dtype = data_dtype
compute_type = data_type
buffer_size, opaque = _cusparse.build_csr_matvec_descriptor(
data_dtype, x_dtype, compute_dtype, index_dtype,
rows, cols, nnz, transpose)
out_size = cols if transpose else rows
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
[ir.TupleType.get_tuple([
ir.RankedTensorType.get([out_size], compute_type),
ir.RankedTensorType.get([buffer_size],
ir.IntegerType.get_signless(8)),
])],
[data, indices, indptr, x],
call_target_name=ir.StringAttr.get("cusparse_csr_matvec"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
] * 4),
result_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
] * 2))
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False,
compute_dtype=None):
"""CSR from dense matrix."""
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
rows, cols = shape
@ -183,6 +312,50 @@ def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_d
)
return _ops.GetTupleElement(out, 0)
def csr_matmat_mhlo(data, indices, indptr, B, *, shape, transpose=False,
compute_dtype=None, compute_type=None, index_dtype,
data_dtype, B_dtype):
"""CSR from dense matrix."""
data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape)
rows, cols = shape
B_shape = ir.RankedTensorType(B.type).shape
_, Ccols = B_shape
if compute_dtype is None:
compute_dtype = data_dtype
compute_type = data_type
buffer_size, opaque = _cusparse.build_csr_matmat_descriptor(
data_dtype, B_dtype, compute_dtype, index_dtype,
rows, cols, Ccols, nnz, transpose)
out_size = cols if transpose else rows
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
[ir.TupleType.get_tuple([
ir.RankedTensorType.get([out_size, Ccols], compute_type),
ir.RankedTensorType.get([buffer_size],
ir.IntegerType.get_signless(8)),
])],
[data, indices, indptr, B],
call_target_name=ir.StringAttr.get("cusparse_csr_matmat"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([1, 0]),
type=ir.IndexType.get()),
]),
result_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([1, 0]), type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
]))
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
def coo_todense(c, data, row, col, *, shape):
"""COO to dense matrix."""
@ -211,6 +384,37 @@ def coo_todense(c, data, row, col, *, shape):
)
return _ops.GetTupleElement(out, 0)
def coo_todense_mhlo(data, row, col, *, shape, data_dtype, index_dtype):
"""COO to dense matrix."""
data_type, _, nnz = _validate_coo_mhlo(data, row, col, shape)
rows, cols = shape
buffer_size, opaque = _cusparse.build_coo_todense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
[ir.TupleType.get_tuple([
ir.RankedTensorType.get(shape, data_type),
ir.RankedTensorType.get([buffer_size],
ir.IntegerType.get_signless(8)),
])],
[data, row, col],
call_target_name=ir.StringAttr.get("cusparse_coo_todense"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
] * 3),
result_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([1, 0]),
type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
]))
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
def coo_fromdense(c, mat, *, nnz, index_dtype):
"""COO from dense matrix."""
@ -241,7 +445,44 @@ def coo_fromdense(c, mat, *, nnz, index_dtype):
return tuple(_ops.GetTupleElement(out, i) for i in range(3))
def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=None):
def coo_fromdense_mhlo(mat, *, nnz, data_dtype, index_dtype,
index_type):
"""COO from dense matrix."""
mat_type = ir.RankedTensorType(mat.type)
rows, cols = mat_type.shape
buffer_size, opaque = _cusparse.build_coo_fromdense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
[ir.TupleType.get_tuple([
ir.RankedTensorType.get([nnz], mat_type.element_type),
ir.RankedTensorType.get([nnz], index_type),
ir.RankedTensorType.get([nnz], index_type),
ir.RankedTensorType.get([buffer_size],
ir.IntegerType.get_signless(8)),
])],
[mat],
call_target_name=ir.StringAttr.get("cusparse_coo_fromdense"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([1, 0]),
type=ir.IndexType.get()),
]),
result_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
] * 4))
return [
mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
for i in range(3)
]
def coo_matvec(c, data, row, col, x, *, shape, transpose=False,
compute_dtype=None):
"""COO matrix/vector multiply."""
data_dtype, index_dtype, nnz = _validate_coo(c, data, row, col, shape)
rows, cols = shape
@ -276,7 +517,46 @@ def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=No
return _ops.GetTupleElement(out, 0)
def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=None):
def coo_matvec_mhlo(data, row, col, x, *, shape, transpose=False,
compute_dtype=None,
compute_type=None, index_dtype, data_dtype, x_dtype):
"""COO matrix/vector multiply."""
data_type, index_type, nnz = _validate_coo_mhlo(data, row, col, shape)
rows, cols = shape
if compute_dtype is None:
compute_dtype = data_dtype
compute_type = data_type
buffer_size, opaque = _cusparse.build_coo_matvec_descriptor(
data_dtype, x_dtype, compute_dtype, index_dtype,
rows, cols, nnz, transpose)
out_size = cols if transpose else rows
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
[ir.TupleType.get_tuple([
ir.RankedTensorType.get([out_size], compute_type),
ir.RankedTensorType.get([buffer_size],
ir.IntegerType.get_signless(8)),
])],
[data, row, col, x],
call_target_name=ir.StringAttr.get("cusparse_coo_matvec"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
] * 4),
result_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
] * 2))
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
def coo_matmat(c, data, row, col, B, *, shape, transpose=False,
compute_dtype=None):
"""COO from dense matrix."""
data_dtype, index_dtype, nnz = _validate_coo(c, data, row, col, shape)
rows, cols = shape
@ -311,6 +591,51 @@ def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=No
)
return _ops.GetTupleElement(out, 0)
def coo_matmat_mhlo(data, row, col, B, *, shape, transpose=False,
compute_dtype=None, compute_type=None, x_dtype,
data_dtype, index_dtype):
"""COO from dense matrix."""
data_type, index_type, nnz = _validate_coo_mhlo(data, row, col, shape)
rows, cols = shape
B_shape = ir.RankedTensorType(B.type).shape
_, Ccols = B_shape
if compute_dtype is None:
compute_dtype = data_dtype
compute_type = data_type
buffer_size, opaque = _cusparse.build_coo_matmat_descriptor(
data_dtype, x_dtype, compute_dtype, index_dtype,
rows, cols, Ccols, nnz, transpose)
out_size = cols if transpose else rows
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
[ir.TupleType.get_tuple([
ir.RankedTensorType.get([out_size, Ccols], compute_type),
ir.RankedTensorType.get([buffer_size],
ir.IntegerType.get_signless(8)),
])],
[data, row, col, B],
call_target_name=ir.StringAttr.get("cusparse_coo_matmat"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([1, 0]),
type=ir.IndexType.get()),
]),
result_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([1, 0]),
type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
]))
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
def gtsv2(c, dl, d, du, B, *, m, n, ldb, t):
"""Calls `cusparse<t>gtsv2(dl, d, du, B, m, n, ldb)`."""
@ -347,7 +672,8 @@ def gtsv2_mhlo(dl, d, du, B, *, m, n, ldb, t):
[ir.TupleType.get_tuple([
ir.RankedTensorType.get(
[ldb, n], ir.F32Type.get() if f32 else ir.F64Type.get()),
ir.RankedTensorType.get([buffer_size], ir.IntegerType.get_signless(8)),
ir.RankedTensorType.get([buffer_size],
ir.IntegerType.get_signless(8)),
])],
[dl, d, du, B],
call_target_name = ir.StringAttr.get(
@ -363,7 +689,8 @@ def gtsv2_mhlo(dl, d, du, B, *, m, n, ldb, t):
ir.DenseIntElementsAttr.get(np.array([1, 0]), type=ir.IndexType.get())
]),
result_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([1, 0]), type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([1, 0]),
type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
]))
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result

View File

@ -46,6 +46,17 @@ def _validate_csr(c, data, indices, indptr, shape):
assert c.get_shape(indptr).dimensions() == (shape[0] + 1,)
return data_dtype, index_dtype, nnz
def _validate_csr_mhlo(data, indices, indptr, shape):
data_type = ir.RankedTensorType(data.type)
indices_type = ir.RankedTensorType(indices.type)
indptr_type = ir.RankedTensorType(indptr.type)
nnz, = data_type.shape
assert indices_type.shape == [nnz]
assert indptr_type.element_type == indices_type.element_type
assert indptr_type.shape == [shape[0] + 1]
return data_type.element_type, indices_type.element_type, nnz
def _validate_coo(c, data, row, col, shape):
data_dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(row).element_type())
@ -55,6 +66,17 @@ def _validate_coo(c, data, row, col, shape):
assert c.get_shape(col).dimensions() == (nnz,)
return data_dtype, index_dtype, nnz
def _validate_coo_mhlo(data, row, col, shape):
data_type = ir.RankedTensorType(data.type)
row_type = ir.RankedTensorType(row.type)
col_type = ir.RankedTensorType(col.type)
nnz, = data_type.shape
assert row_type.shape == [nnz]
assert col_type.element_type == row_type.element_type
assert col_type.shape == [nnz]
return data_type.element_type, row_type.element_type, nnz
def csr_todense(c, data, indices, indptr, *, shape):
"""CSR to dense matrix."""
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
@ -83,6 +105,39 @@ def csr_todense(c, data, indices, indptr, *, shape):
return _ops.GetTupleElement(out, 0)
def csr_todense_mhlo(data, indices, indptr, *, shape, data_dtype, index_dtype):
"""CSR to dense matrix."""
data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape)
rows, cols = shape
buffer_size, opaque = _hipsparse.build_csr_todense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
[ir.TupleType.get_tuple([
ir.RankedTensorType.get(shape, data_type),
ir.RankedTensorType.get([buffer_size],
ir.IntegerType.get_signless(8)),
])],
[data, indices, indptr],
call_target_name=ir.StringAttr.get("hipsparse_csr_todense"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
] * 3),
result_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([1, 0]),
type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
]))
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
def csr_fromdense(c, mat, *, nnz, index_dtype):
"""CSR from dense matrix."""
data_dtype = np.dtype(c.get_shape(mat).element_type())
@ -112,6 +167,40 @@ def csr_fromdense(c, mat, *, nnz, index_dtype):
return tuple(_ops.GetTupleElement(out, i) for i in range(3))
def csr_fromdense_mhlo(mat, *, nnz, index_dtype, data_dtype, index_type):
"""CSR from dense matrix."""
mat_type = ir.RankedTensorType(mat.type)
rows, cols = mat_type.shape
buffer_size, opaque = _hipsparse.build_csr_fromdense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
[ir.TupleType.get_tuple([
ir.RankedTensorType.get([nnz], mat_type.element_type),
ir.RankedTensorType.get([nnz], index_type),
ir.RankedTensorType.get([rows + 1], index_type),
ir.RankedTensorType.get([buffer_size],
ir.IntegerType.get_signless(8)),
])],
[mat],
call_target_name=ir.StringAttr.get("hipsparse_csr_fromdense"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([1, 0]),
type=ir.IndexType.get()),
]),
result_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
] * 4))
return [
mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
for i in range(3)
]
def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_dtype=None):
"""CSR matrix/vector multiply."""
@ -147,6 +236,43 @@ def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_d
)
return _ops.GetTupleElement(out, 0)
def csr_matvec_mhlo(data, indices, indptr, x, *, shape, transpose=False,
compute_dtype=None, compute_type=None, data_dtype,
index_dtype, x_dtype):
"""CSR matrix/vector multiply."""
data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape)
rows, cols = shape
if compute_dtype is None:
compute_dtype = data_dtype
compute_type = data_type
buffer_size, opaque = _hipsparse.build_csr_matvec_descriptor(
data_dtype, x_dtype, compute_dtype, index_dtype,
rows, cols, nnz, transpose)
out_size = cols if transpose else rows
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
[ir.TupleType.get_tuple([
ir.RankedTensorType.get([out_size], compute_type),
ir.RankedTensorType.get([buffer_size],
ir.IntegerType.get_signless(8)),
])],
[data, indices, indptr, x],
call_target_name=ir.StringAttr.get("hipsparse_csr_matvec"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
] * 4),
result_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
] * 2))
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_dtype=None):
"""CSR from dense matrix."""
@ -183,6 +309,50 @@ def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_d
)
return _ops.GetTupleElement(out, 0)
def csr_matmat_mhlo(data, indices, indptr, B, *, shape, transpose=False,
compute_dtype=None, compute_type=None, index_dtype,
data_dtype, B_dtype):
"""CSR from dense matrix."""
data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape)
rows, cols = shape
B_shape = ir.RankedTensorType(B.type).shape
_, Ccols = B_shape
if compute_dtype is None:
compute_dtype = data_dtype
compute_type = data_type
buffer_size, opaque = _hipsparse.build_csr_matmat_descriptor(
data_dtype, B_dtype, compute_dtype, index_dtype,
rows, cols, Ccols, nnz, transpose)
out_size = cols if transpose else rows
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
[ir.TupleType.get_tuple([
ir.RankedTensorType.get([out_size, Ccols], compute_type),
ir.RankedTensorType.get([buffer_size],
ir.IntegerType.get_signless(8)),
])],
[data, indices, indptr, B],
call_target_name=ir.StringAttr.get("hipsparse_csr_matmat"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([1, 0]),
type=ir.IndexType.get()),
]),
result_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([1, 0]), type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
]))
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
def coo_todense(c, data, row, col, *, shape):
"""COO to dense matrix."""
@ -211,6 +381,36 @@ def coo_todense(c, data, row, col, *, shape):
)
return _ops.GetTupleElement(out, 0)
def coo_todense_mhlo(data, row, col, *, shape, data_dtype, index_dtype):
"""COO to dense matrix."""
data_type, _, nnz = _validate_coo_mhlo(data, row, col, shape)
rows, cols = shape
buffer_size, opaque = _hipsparse.build_coo_todense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
[ir.TupleType.get_tuple([
ir.RankedTensorType.get(shape, data_type),
ir.RankedTensorType.get([buffer_size],
ir.IntegerType.get_signless(8)),
])],
[data, row, col],
call_target_name=ir.StringAttr.get("hipsparse_coo_todense"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
] * 3),
result_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([1, 0]),
type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
]))
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
def coo_fromdense(c, mat, *, nnz, index_dtype):
"""COO from dense matrix."""
@ -241,6 +441,42 @@ def coo_fromdense(c, mat, *, nnz, index_dtype):
return tuple(_ops.GetTupleElement(out, i) for i in range(3))
def coo_fromdense_mhlo(mat, *, nnz, data_dtype, index_dtype,
index_type):
"""COO from dense matrix."""
mat_type = ir.RankedTensorType(mat.type)
rows, cols = mat_type.shape
buffer_size, opaque = _hipsparse.build_coo_fromdense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
[ir.TupleType.get_tuple([
ir.RankedTensorType.get([nnz], mat_type.element_type),
ir.RankedTensorType.get([nnz], index_type),
ir.RankedTensorType.get([nnz], index_type),
ir.RankedTensorType.get([buffer_size],
ir.IntegerType.get_signless(8)),
])],
[mat],
call_target_name=ir.StringAttr.get("hipsparse_coo_fromdense"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([1, 0]),
type=ir.IndexType.get()),
]),
result_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
] * 4))
return [
mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
for i in range(3)
]
def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=None):
"""COO matrix/vector multiply."""
data_dtype, index_dtype, nnz = _validate_coo(c, data, row, col, shape)
@ -275,6 +511,43 @@ def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=No
)
return _ops.GetTupleElement(out, 0)
def coo_matvec_mhlo(data, row, col, x, *, shape, transpose=False,
compute_dtype=None,
compute_type=None, index_dtype, data_dtype, x_dtype):
"""COO matrix/vector multiply."""
data_type, index_type, nnz = _validate_coo_mhlo(data, row, col, shape)
rows, cols = shape
if compute_dtype is None:
compute_dtype = data_dtype
compute_type = data_type
buffer_size, opaque = _hipsparse.build_coo_matvec_descriptor(
data_dtype, x_dtype, compute_dtype, index_dtype,
rows, cols, nnz, transpose)
out_size = cols if transpose else rows
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
[ir.TupleType.get_tuple([
ir.RankedTensorType.get([out_size], compute_type),
ir.RankedTensorType.get([buffer_size],
ir.IntegerType.get_signless(8)),
])],
[data, row, col, x],
call_target_name=ir.StringAttr.get("hipsparse_coo_matvec"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
] * 4),
result_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
] * 2))
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=None):
"""COO from dense matrix."""
@ -311,6 +584,51 @@ def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=No
)
return _ops.GetTupleElement(out, 0)
def coo_matmat_mhlo(data, row, col, B, *, shape, transpose=False,
compute_dtype=None, compute_type=None, x_dtype,
data_dtype, index_dtype):
"""COO from dense matrix."""
data_type, index_type, nnz = _validate_coo_mhlo(data, row, col, shape)
rows, cols = shape
B_shape = ir.RankedTensorType(B.type).shape
_, Ccols = B_shape
if compute_dtype is None:
compute_dtype = data_dtype
compute_type = data_type
buffer_size, opaque = _hipsparse.build_coo_matmat_descriptor(
data_dtype, x_dtype, compute_dtype, index_dtype,
rows, cols, Ccols, nnz, transpose)
out_size = cols if transpose else rows
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
[ir.TupleType.get_tuple([
ir.RankedTensorType.get([out_size, Ccols], compute_type),
ir.RankedTensorType.get([buffer_size],
ir.IntegerType.get_signless(8)),
])],
[data, row, col, B],
call_target_name=ir.StringAttr.get("hipsparse_coo_matmat"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([1, 0]),
type=ir.IndexType.get()),
]),
result_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(np.array([1, 0]),
type=ir.IndexType.get()),
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
]))
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
def gtsv2(c, dl, d, du, B, *, m, n, ldb, t):
"""Calls `hipsparse<t>gtsv2(dl, d, du, B, m, n, ldb)`."""

View File

@ -236,6 +236,9 @@ def _split_abstract_eval(mat):
xla.register_translation(
split_p, xla.lower_fun(_split_impl, multiple_results=True, new_style=True))
mlir.register_lowering(
split_p, mlir.lower_fun(_split_impl, multiple_results=True))
def make_sparse_array(rng, shape, dtype, nnz=0.2):
mat = rng(shape, dtype)
size = int(np.prod(shape))