mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
[MHLO] Add direct MHLO lowerings for sparse primitives.
PiperOrigin-RevId: 440374054
This commit is contained in:
parent
1cb4fccd1d
commit
648a512488
@ -41,6 +41,7 @@ from jax.experimental.sparse.csr import CSR, CSC
|
||||
from jax.experimental.sparse.util import _coo_extract
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax._src import dtypes
|
||||
|
||||
@ -102,6 +103,9 @@ batching.primitive_batchers[todense_p] = _todense_batching_rule
|
||||
xla.register_translation(todense_p, xla.lower_fun(
|
||||
_todense_impl, multiple_results=False, new_style=True))
|
||||
|
||||
mlir.register_lowering(todense_p, mlir.lower_fun(
|
||||
_todense_impl, multiple_results=False))
|
||||
|
||||
|
||||
def empty(shape, dtype=None, index_dtype='int32', sparse_format='bcoo', **kwds):
|
||||
"""Create an empty sparse array.
|
||||
|
@ -29,6 +29,7 @@ from jax.experimental.sparse._base import JAXSparse
|
||||
from jax.experimental.sparse.util import _safe_asarray, CuSparseEfficiencyWarning
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
import jax.numpy as jnp
|
||||
from jax.interpreters import ad
|
||||
@ -38,6 +39,9 @@ from jax._src.api_util import flatten_axes
|
||||
from jax._src.lax.lax import (
|
||||
ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule,
|
||||
DotDimensionNumbers)
|
||||
import jax._src.lib
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.numpy.setops import _unique
|
||||
|
||||
@ -146,6 +150,9 @@ def _bcoo_sort_indices(data, indices, shape):
|
||||
_bcoo_sort_indices_rule = xla.lower_fun(
|
||||
_bcoo_sort_indices, multiple_results=True, new_style=True)
|
||||
|
||||
_bcoo_sort_indices_mhlo = mlir.lower_fun(
|
||||
_bcoo_sort_indices, multiple_results=True)
|
||||
|
||||
def _unbatch_bcoo(data, indices, shape):
|
||||
n_batch = _validate_bcoo(data, indices, shape).n_batch
|
||||
if n_batch == 0:
|
||||
@ -282,6 +289,8 @@ ad.primitive_transposes[bcoo_todense_p] = _bcoo_todense_transpose
|
||||
batching.primitive_batchers[bcoo_todense_p] = _bcoo_todense_batching_rule
|
||||
xla.register_translation(bcoo_todense_p, xla.lower_fun(
|
||||
_bcoo_todense_impl, multiple_results=False, new_style=True))
|
||||
mlir.register_lowering(bcoo_todense_p, mlir.lower_fun(
|
||||
_bcoo_todense_impl, multiple_results=False))
|
||||
|
||||
#--------------------------------------------------------------------
|
||||
# bcoo_fromdense
|
||||
@ -408,6 +417,8 @@ ad.primitive_transposes[bcoo_fromdense_p] = _bcoo_fromdense_transpose
|
||||
batching.primitive_batchers[bcoo_fromdense_p] = _bcoo_fromdense_batching_rule
|
||||
xla.register_translation(bcoo_fromdense_p, xla.lower_fun(
|
||||
_bcoo_fromdense_impl, multiple_results=True, new_style=True))
|
||||
mlir.register_lowering(bcoo_fromdense_p, mlir.lower_fun(
|
||||
_bcoo_fromdense_impl, multiple_results=True))
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
# bcoo_extract
|
||||
@ -487,6 +498,8 @@ ad.primitive_transposes[bcoo_extract_p] = _bcoo_extract_transpose
|
||||
batching.primitive_batchers[bcoo_extract_p] = _bcoo_extract_batching_rule
|
||||
xla.register_translation(bcoo_extract_p, xla.lower_fun(
|
||||
_bcoo_extract_impl, multiple_results=False, new_style=True))
|
||||
mlir.register_lowering(bcoo_extract_p, mlir.lower_fun(
|
||||
_bcoo_extract_impl, multiple_results=False))
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
# bcoo_transpose
|
||||
@ -602,6 +615,8 @@ ad.primitive_transposes[bcoo_transpose_p] = _bcoo_transpose_transpose
|
||||
batching.primitive_batchers[bcoo_transpose_p] = _bcoo_transpose_batch_rule
|
||||
xla.register_translation(bcoo_transpose_p, xla.lower_fun(
|
||||
_bcoo_transpose_impl, multiple_results=True, new_style=True))
|
||||
mlir.register_lowering(bcoo_transpose_p, mlir.lower_fun(
|
||||
_bcoo_transpose_impl, multiple_results=True))
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
# bcoo_dot_general
|
||||
@ -862,6 +877,146 @@ def _bcoo_dot_general_gpu_translation_rule(
|
||||
ctx, avals_in, avals_out, lhs_data, lhs_indices, rhs,
|
||||
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
|
||||
|
||||
_bcoo_dot_general_default_lowering = mlir.lower_fun(
|
||||
_bcoo_dot_general_impl, multiple_results=False)
|
||||
|
||||
def _collapse_mhlo(x, start, end):
|
||||
x_type = ir.RankedTensorType(x.type)
|
||||
shape = x_type.shape
|
||||
shape = (shape[:start]
|
||||
+ [functools.reduce(operator.mul, shape[start:end + 1])]
|
||||
+ shape[end + 1:])
|
||||
return mhlo.ReshapeOp(
|
||||
ir.RankedTensorType.get(shape, x_type.element_type), x).result
|
||||
|
||||
def _bcoo_dot_general_cuda_lowering(
|
||||
ctx, lhs_data, lhs_indices, rhs, *, dimension_numbers,
|
||||
lhs_spinfo: BCOOInfo):
|
||||
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
lhs_data_aval, lhs_indices_aval, rhs_aval, = ctx.avals_in
|
||||
props = _validate_bcoo_indices(lhs_indices_aval, lhs_spinfo.shape)
|
||||
rhs_ndim = len(ir.RankedTensorType(rhs.type).shape)
|
||||
|
||||
# Checks the shapes of lhs and rhs.
|
||||
assert props.n_dense == 0
|
||||
assert props.n_batch == 0
|
||||
assert props.n_sparse in [1, 2]
|
||||
assert rhs_ndim in [1, 2]
|
||||
|
||||
# Checks the operation dimensions.
|
||||
assert len(lhs_batch) == 0
|
||||
assert len(rhs_batch) == 0
|
||||
assert len(lhs_contract) == 1
|
||||
|
||||
# Checks the dtype.
|
||||
assert lhs_data_aval.dtype in [np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
assert lhs_data_aval.dtype == rhs_aval.dtype
|
||||
assert lhs_indices_aval.dtype == np.int32
|
||||
assert sparse_apis is not None
|
||||
|
||||
if rhs_ndim == 1:
|
||||
bcoo_dot_general_fn = sparse_apis.coo_matvec_mhlo
|
||||
elif rhs_ndim == 2:
|
||||
bcoo_dot_general_fn = sparse_apis.coo_matmat_mhlo
|
||||
if rhs_contract[0] == 1:
|
||||
rhs = mhlo.TransposeOp(
|
||||
rhs, permutation=mlir.dense_int_elements([1, 0])).result
|
||||
else:
|
||||
raise ValueError(f"rhs has to be 1d or 2d; get {rhs_ndim}d.")
|
||||
|
||||
lhs_transpose = False
|
||||
if props.n_sparse == 1:
|
||||
# Converts lhs to a row vector.
|
||||
col = _collapse_mhlo(lhs_indices, start=0, end=1)
|
||||
row = mlir.full_like_aval(
|
||||
0, core.ShapedArray(ir.RankedTensorType(col.type).shape,
|
||||
np.dtype(np.int32)))
|
||||
lhs_shape = (1, lhs_spinfo.shape[0])
|
||||
dot_product = bcoo_dot_general_fn(
|
||||
lhs_data, row, col, rhs, shape=lhs_shape, transpose=lhs_transpose,
|
||||
data_dtype=lhs_data_aval.dtype, index_dtype=lhs_indices_aval.dtype,
|
||||
x_dtype=rhs_aval.dtype)
|
||||
|
||||
if rhs_ndim == 1:
|
||||
# Transforms a single-element array to a scalar.
|
||||
return [mhlo.ReshapeOp(
|
||||
ir.RankedTensorType(
|
||||
[], ir.RankedTensorType(dot_product.type).element_type),
|
||||
dot_product).result]
|
||||
else:
|
||||
return [_collapse_mhlo(dot_product, start=0, end=1)]
|
||||
elif props.n_sparse == 2:
|
||||
lhs_indices_shape = ir.RankedTensorType(lhs_indices.type).shape
|
||||
row = _collapse_mhlo(
|
||||
mhlo.SliceOp(
|
||||
lhs_indices,
|
||||
start_indices=mlir.dense_int_elements([0, 0]),
|
||||
limit_indices=mlir.dense_int_elements([lhs_indices_shape[0], 1]),
|
||||
strides=mlir.dense_int_elements([1, 1])).result,
|
||||
start=0, end=1)
|
||||
col = _collapse_mhlo(
|
||||
mhlo.SliceOp(
|
||||
lhs_indices,
|
||||
start_indices=mlir.dense_int_elements([0, 1]),
|
||||
limit_indices=mlir.dense_int_elements([lhs_indices_shape[0], 2]),
|
||||
strides=mlir.dense_int_elements([1, 1])).result,
|
||||
start=0, end=1)
|
||||
|
||||
if lhs_contract[0] == 0:
|
||||
lhs_transpose = True
|
||||
|
||||
return [bcoo_dot_general_fn(
|
||||
lhs_data, row, col, rhs, shape=lhs_spinfo.shape,
|
||||
transpose=lhs_transpose, data_dtype=lhs_data_aval.dtype,
|
||||
index_dtype=lhs_indices_aval.dtype,
|
||||
x_dtype=rhs_aval.dtype)]
|
||||
else:
|
||||
raise ValueError(f"lhs has to be 1d or 2d; get {props.n_sparse}d.")
|
||||
|
||||
def _bcoo_dot_general_gpu_lowering(
|
||||
ctx, lhs_data, lhs_indices, rhs, *, dimension_numbers,
|
||||
lhs_spinfo: BCOOInfo):
|
||||
|
||||
if not config.jax_bcoo_cusparse_lowering:
|
||||
return _bcoo_dot_general_default_lowering(
|
||||
ctx, lhs_data, lhs_indices, rhs,
|
||||
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
|
||||
|
||||
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
lhs_data_aval, lhs_indices_aval, rhs_aval, = ctx.avals_in
|
||||
n_batch, n_sparse, n_dense, nse = _validate_bcoo(
|
||||
lhs_data_aval, lhs_indices_aval, lhs_spinfo.shape)
|
||||
|
||||
dtype = lhs_data_aval.dtype
|
||||
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
|
||||
warnings.warn(f'bcoo_dot_general cusparse/hipsparse lowering not available '
|
||||
f'for dtype={dtype}. Falling back to default implementation.',
|
||||
CuSparseEfficiencyWarning)
|
||||
return _bcoo_dot_general_default_lowering(
|
||||
ctx, lhs_data, lhs_indices, rhs,
|
||||
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
|
||||
|
||||
if (n_batch or n_dense or
|
||||
n_sparse not in [1, 2] or rhs_aval.ndim not in [1, 2] or
|
||||
lhs_batch or rhs_batch or len(lhs_contract) != 1):
|
||||
return _bcoo_dot_general_default_lowering(
|
||||
ctx, lhs_data, lhs_indices, rhs,
|
||||
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
|
||||
else:
|
||||
# Sorts lhs by row indices.
|
||||
sub_ctx = mlir.LoweringRuleContext(module_context=ctx.module_context,
|
||||
primitive=None,
|
||||
avals_in=ctx.avals_in[:2],
|
||||
avals_out=ctx.avals_in[:2])
|
||||
|
||||
(lhs_data,), (lhs_indices,) = _bcoo_sort_indices_mhlo(
|
||||
sub_ctx, lhs_data, lhs_indices, shape=lhs_spinfo.shape)
|
||||
|
||||
return _bcoo_dot_general_cuda_lowering(
|
||||
ctx, lhs_data, lhs_indices, rhs,
|
||||
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
|
||||
|
||||
def _bcoo_dot_general_jvp_lhs(lhs_data_dot, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: BCOOInfo):
|
||||
return _bcoo_dot_general(lhs_data_dot, lhs_indices, rhs, dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
|
||||
|
||||
@ -930,10 +1085,16 @@ batching.primitive_batchers[bcoo_dot_general_p] = _bcoo_dot_general_batch_rule
|
||||
|
||||
xla.register_translation(
|
||||
bcoo_dot_general_p, _bcoo_dot_general_default_translation_rule)
|
||||
mlir.register_lowering(
|
||||
bcoo_dot_general_p, _bcoo_dot_general_default_lowering)
|
||||
if sparse_apis and sparse_apis.is_supported:
|
||||
xla.register_translation(bcoo_dot_general_p,
|
||||
_bcoo_dot_general_gpu_translation_rule,
|
||||
platform='gpu')
|
||||
if jax._src.lib.version > (0, 3, 5):
|
||||
mlir.register_lowering(bcoo_dot_general_p,
|
||||
_bcoo_dot_general_gpu_lowering,
|
||||
platform='gpu')
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
# bcoo_dot_general_sampled
|
||||
@ -1005,6 +1166,9 @@ ad.primitive_transposes[bcoo_dot_general_sampled_p] = _bcoo_dot_general_sampled_
|
||||
batching.primitive_batchers[bcoo_dot_general_sampled_p] = _bcoo_dot_general_sampled_batch_rule
|
||||
xla.register_translation(bcoo_dot_general_sampled_p, xla.lower_fun(
|
||||
_bcoo_dot_general_sampled_impl, multiple_results=False, new_style=True))
|
||||
mlir.register_lowering(
|
||||
bcoo_dot_general_sampled_p,
|
||||
mlir.lower_fun(_bcoo_dot_general_sampled_impl, multiple_results=False))
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
# bcoo_spdot_general
|
||||
@ -1205,6 +1369,8 @@ batching.primitive_batchers[bcoo_spdot_general_p] = _bcoo_spdot_general_batch_ru
|
||||
ad.primitive_jvps[bcoo_spdot_general_p] = _bcoo_spdot_general_jvp
|
||||
xla.register_translation(bcoo_spdot_general_p, xla.lower_fun(
|
||||
_bcoo_spdot_general_impl, multiple_results=True, new_style=True))
|
||||
mlir.register_lowering(bcoo_spdot_general_p, mlir.lower_fun(
|
||||
_bcoo_spdot_general_impl, multiple_results=True))
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
# BCOO functions that maybe should be primitives?
|
||||
|
@ -23,10 +23,13 @@ import numpy as np
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax.experimental.sparse._base import JAXSparse
|
||||
from jax.experimental.sparse.util import _coo_extract, _safe_asarray, CuSparseEfficiencyWarning
|
||||
from jax import tree_util
|
||||
import jax._src.lib
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
from jax._src.lib import sparse_apis
|
||||
from jax._src.numpy.lax_numpy import _promote_dtypes
|
||||
from jax._src.lib import xla_client
|
||||
@ -183,6 +186,37 @@ def _coo_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
|
||||
|
||||
return [xops.Transpose(result, (1, 0))] if transpose else [result]
|
||||
|
||||
_coo_todense_lowering = mlir.lower_fun(
|
||||
_coo_todense_impl, multiple_results=False)
|
||||
|
||||
def _coo_todense_gpu_lowering(ctx, data, row, col, *, spinfo):
|
||||
data_aval, row_aval, _ = ctx.avals_in
|
||||
dtype = data_aval.dtype
|
||||
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
|
||||
warnings.warn(f"coo_todense cusparse/hipsparse lowering not available for dtype={dtype}. "
|
||||
"Falling back to default implementation.", CuSparseEfficiencyWarning)
|
||||
return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo)
|
||||
|
||||
if spinfo.rows_sorted:
|
||||
shape = spinfo.shape
|
||||
transpose = False
|
||||
elif spinfo.cols_sorted:
|
||||
row, col = col, row
|
||||
transpose = True
|
||||
shape = spinfo.shape[::-1]
|
||||
else:
|
||||
warnings.warn("coo_todense GPU lowering requires matrices with sorted rows or sorted cols. "
|
||||
"To sort the rows in your matrix, use e.g. mat = mat._sort_rows(). Falling "
|
||||
"back to the default implementation.", CuSparseEfficiencyWarning)
|
||||
return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo)
|
||||
|
||||
result = sparse_apis.coo_todense_mhlo(
|
||||
data, row, col, shape=shape, data_dtype=dtype, index_dtype=row_aval.dtype)
|
||||
return (
|
||||
[mhlo.TransposeOp(result, mlir.dense_int_elements([1, 0])).result]
|
||||
if transpose else [result])
|
||||
|
||||
|
||||
def _coo_todense_jvp(data_dot, data, row, col, *, spinfo):
|
||||
return _coo_todense(data_dot, row, col, spinfo=spinfo)
|
||||
|
||||
@ -200,9 +234,13 @@ def _coo_todense_transpose(ct, data, row, col, *, spinfo):
|
||||
ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
|
||||
ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
|
||||
xla.register_translation(coo_todense_p, _coo_todense_translation_rule)
|
||||
mlir.register_lowering(coo_todense_p, _coo_todense_lowering)
|
||||
if sparse_apis and sparse_apis.is_supported:
|
||||
xla.register_translation(coo_todense_p, _coo_todense_gpu_translation_rule,
|
||||
platform='gpu')
|
||||
if jax._src.lib.version > (0, 3, 5):
|
||||
mlir.register_lowering(coo_todense_p, _coo_todense_gpu_lowering,
|
||||
platform='gpu')
|
||||
|
||||
#--------------------------------------------------------------------
|
||||
# coo_fromdense
|
||||
@ -278,6 +316,23 @@ def _coo_fromdense_gpu_translation_rule(ctx, avals_in, avals_out, mat, *, nse,
|
||||
ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
|
||||
return [data, row, col]
|
||||
|
||||
_coo_fromdense_lowering = mlir.lower_fun(
|
||||
_coo_fromdense_impl, multiple_results=True)
|
||||
|
||||
def _coo_fromdense_gpu_lowering(ctx, mat, *, nse, index_dtype):
|
||||
dtype = ctx.avals_in[0].dtype
|
||||
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
|
||||
warnings.warn(f"coo_fromdense cusparse/hipsparse lowering not available for dtype={dtype}. "
|
||||
"Falling back to default implementation.", CuSparseEfficiencyWarning)
|
||||
return _coo_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype)
|
||||
data, row, col = sparse_apis.coo_fromdense_mhlo(
|
||||
mat, nnz=nse,
|
||||
data_dtype=dtype,
|
||||
index_dtype=np.dtype(index_dtype),
|
||||
index_type=mlir.dtype_to_ir_type(np.dtype(index_dtype)))
|
||||
return [data, row, col]
|
||||
|
||||
|
||||
def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype):
|
||||
M, = primals
|
||||
Mdot, = tangents
|
||||
@ -307,10 +362,15 @@ ad.primitive_jvps[coo_fromdense_p] = _coo_fromdense_jvp
|
||||
ad.primitive_transposes[coo_fromdense_p] = _coo_fromdense_transpose
|
||||
|
||||
xla.register_translation(coo_fromdense_p, _coo_fromdense_translation_rule)
|
||||
mlir.register_lowering(coo_fromdense_p, _coo_fromdense_lowering)
|
||||
if sparse_apis and sparse_apis.is_supported:
|
||||
xla.register_translation(coo_fromdense_p,
|
||||
_coo_fromdense_gpu_translation_rule,
|
||||
platform='gpu')
|
||||
if jax._src.lib.version > (0, 3, 5):
|
||||
mlir.register_lowering(coo_fromdense_p,
|
||||
_coo_fromdense_gpu_lowering,
|
||||
platform='gpu')
|
||||
|
||||
#--------------------------------------------------------------------
|
||||
# coo_matvec
|
||||
@ -400,6 +460,36 @@ def _coo_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
|
||||
return [sparse_apis.coo_matvec(ctx.builder, data, row, col, v, shape=shape,
|
||||
transpose=transpose)]
|
||||
|
||||
_coo_matvec_lowering = mlir.lower_fun(
|
||||
_coo_matvec_impl, multiple_results=False)
|
||||
|
||||
def _coo_matvec_gpu_lowering(ctx, data, row, col, v, *, spinfo, transpose):
|
||||
data_aval, row_aval, _, x_aval = ctx.avals_in
|
||||
dtype = data_aval.dtype
|
||||
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
|
||||
warnings.warn(f"coo_matvec cusparse/hipsparse lowering not available for dtype={dtype}. "
|
||||
"Falling back to default implementation.", CuSparseEfficiencyWarning)
|
||||
return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo,
|
||||
transpose=transpose)
|
||||
|
||||
if spinfo.rows_sorted:
|
||||
shape = spinfo.shape
|
||||
elif spinfo.cols_sorted:
|
||||
row, col = col, row
|
||||
transpose = not transpose
|
||||
shape = spinfo.shape[::-1]
|
||||
else:
|
||||
warnings.warn("coo_matvec GPU lowering requires matrices with sorted rows or sorted cols. "
|
||||
"To sort the rows in your matrix, use e.g. mat = mat._sort_rows(). Falling "
|
||||
"back to the default implementation.", CuSparseEfficiencyWarning)
|
||||
return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo,
|
||||
transpose=transpose)
|
||||
|
||||
return [sparse_apis.coo_matvec_mhlo(
|
||||
data, row, col, v, shape=shape, transpose=transpose,
|
||||
index_dtype=row_aval.dtype, data_dtype=dtype, x_dtype=x_aval.dtype)]
|
||||
|
||||
|
||||
def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, spinfo, transpose):
|
||||
return _coo_matvec(data_dot, row, col, v, spinfo=spinfo, transpose=transpose)
|
||||
|
||||
@ -421,9 +511,13 @@ def _coo_matvec_transpose(ct, data, row, col, v, *, spinfo, transpose):
|
||||
ad.defjvp(coo_matvec_p, _coo_matvec_jvp_mat, None, None, _coo_matvec_jvp_vec)
|
||||
ad.primitive_transposes[coo_matvec_p] = _coo_matvec_transpose
|
||||
xla.register_translation(coo_matvec_p, _coo_matvec_translation_rule)
|
||||
mlir.register_lowering(coo_matvec_p, _coo_matvec_lowering)
|
||||
if sparse_apis and sparse_apis.is_supported:
|
||||
xla.register_translation(coo_matvec_p, _coo_matvec_gpu_translation_rule,
|
||||
platform='gpu')
|
||||
if jax._src.lib.version > (0, 3, 5):
|
||||
mlir.register_lowering(coo_matvec_p, _coo_matvec_gpu_lowering,
|
||||
platform='gpu')
|
||||
|
||||
#--------------------------------------------------------------------
|
||||
# coo_matmat
|
||||
@ -511,6 +605,35 @@ def _coo_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
|
||||
return [sparse_apis.coo_matmat(ctx.builder, data, row, col, B, shape=shape,
|
||||
transpose=transpose)]
|
||||
|
||||
_coo_matmat_lowering = mlir.lower_fun(_coo_matmat_impl, multiple_results=False)
|
||||
|
||||
def _coo_matmat_gpu_lowering(ctx, data, row, col, B, *, spinfo, transpose):
|
||||
data_aval, row_aval, _, B_aval = ctx.avals_in
|
||||
dtype = data_aval.dtype
|
||||
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
|
||||
warnings.warn(f"coo_matmat cusparse/hipsprse lowering not available for dtype={dtype}. "
|
||||
"Falling back to default implementation.", CuSparseEfficiencyWarning)
|
||||
return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo,
|
||||
transpose=transpose)
|
||||
if spinfo.rows_sorted:
|
||||
shape = spinfo.shape
|
||||
elif spinfo.cols_sorted:
|
||||
row, col = col, row
|
||||
transpose = not transpose
|
||||
shape = spinfo.shape[::-1]
|
||||
else:
|
||||
warnings.warn("coo_matmat GPU lowering requires matrices with sorted rows or sorted cols. "
|
||||
"To sort the rows in your matrix, use e.g. mat = mat._sort_rows(). Falling "
|
||||
"back to the default implementation.", CuSparseEfficiencyWarning)
|
||||
return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo,
|
||||
transpose=transpose)
|
||||
|
||||
return [sparse_apis.coo_matmat_mhlo(data, row, col, B, shape=shape,
|
||||
transpose=transpose, x_dtype=B_aval.dtype,
|
||||
data_dtype=data_aval.dtype,
|
||||
index_dtype=row_aval.dtype)]
|
||||
|
||||
|
||||
def _coo_matmat_jvp_left(data_dot, data, row, col, B, *, spinfo, transpose):
|
||||
return _coo_matmat(data_dot, row, col, B, spinfo=spinfo, transpose=transpose)
|
||||
|
||||
@ -529,6 +652,10 @@ def _coo_matmat_transpose(ct, data, row, col, B, *, spinfo, transpose):
|
||||
ad.defjvp(coo_matmat_p, _coo_matmat_jvp_left, None, None, _coo_matmat_jvp_right)
|
||||
ad.primitive_transposes[coo_matmat_p] = _coo_matmat_transpose
|
||||
xla.register_translation(coo_matmat_p, _coo_matmat_translation_rule)
|
||||
mlir.register_lowering(coo_matmat_p, _coo_matmat_lowering)
|
||||
if sparse_apis and sparse_apis.is_supported:
|
||||
xla.register_translation(coo_matmat_p, _coo_matmat_gpu_translation_rule,
|
||||
platform='gpu')
|
||||
if jax._src.lib.version > (0, 3, 5):
|
||||
mlir.register_lowering(coo_matmat_p, _coo_matmat_gpu_lowering,
|
||||
platform='gpu')
|
||||
|
@ -22,11 +22,13 @@ import numpy as np
|
||||
|
||||
from jax import core
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax.experimental.sparse._base import JAXSparse
|
||||
from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo
|
||||
from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, _safe_asarray, CuSparseEfficiencyWarning
|
||||
from jax import tree_util
|
||||
import jax._src.lib
|
||||
from jax._src.lib import sparse_apis
|
||||
from jax._src.numpy.lax_numpy import _promote_dtypes
|
||||
import jax.numpy as jnp
|
||||
@ -184,6 +186,21 @@ def _csr_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
|
||||
indptr, shape=shape)
|
||||
return [sparse_apis.csr_todense(ctx.builder, data, indices, indptr, shape=shape)]
|
||||
|
||||
_csr_todense_lowering = mlir.lower_fun(
|
||||
_csr_todense_impl, multiple_results=False)
|
||||
|
||||
def _csr_todense_gpu_lowering(ctx, data, indices, indptr, *, shape):
|
||||
data_aval, indices_aval, _ = ctx.avals_in
|
||||
dtype = data_aval.dtype
|
||||
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
|
||||
warnings.warn(f"csr_todense cusparse/hipsparse lowering not available for dtype={dtype}. "
|
||||
"Falling back to default implementation.", CuSparseEfficiencyWarning)
|
||||
return _csr_todense_lowering(ctx, data, indices, indptr, shape=shape)
|
||||
return [sparse_apis.csr_todense_mhlo(
|
||||
data, indices, indptr, shape=shape, data_dtype=dtype,
|
||||
index_dtype=indices_aval.dtype)]
|
||||
|
||||
|
||||
def _csr_todense_jvp(data_dot, data, indices, indptr, *, shape):
|
||||
return csr_todense(data_dot, indices, indptr, shape=shape)
|
||||
|
||||
@ -201,9 +218,13 @@ def _csr_todense_transpose(ct, data, indices, indptr, *, shape):
|
||||
ad.defjvp(csr_todense_p, _csr_todense_jvp, None, None)
|
||||
ad.primitive_transposes[csr_todense_p] = _csr_todense_transpose
|
||||
xla.register_translation(csr_todense_p, _csr_todense_translation_rule)
|
||||
mlir.register_lowering(csr_todense_p, _csr_todense_lowering)
|
||||
if sparse_apis and sparse_apis.is_supported:
|
||||
xla.register_translation(csr_todense_p, _csr_todense_gpu_translation_rule,
|
||||
platform='gpu')
|
||||
if jax._src.lib.version > (0, 3, 5):
|
||||
mlir.register_lowering(csr_todense_p, _csr_todense_gpu_lowering,
|
||||
platform='gpu')
|
||||
|
||||
#--------------------------------------------------------------------
|
||||
# csr_fromdense
|
||||
@ -267,6 +288,21 @@ def _csr_fromdense_gpu_translation_rule(ctx, avals_in, avals_out, mat, *, nse,
|
||||
ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
|
||||
return [data, indices, indptr]
|
||||
|
||||
_csr_fromdense_lowering = mlir.lower_fun(_csr_fromdense_impl,
|
||||
multiple_results=True)
|
||||
|
||||
def _csr_fromdense_gpu_lowering(ctx, mat, *, nse, index_dtype):
|
||||
dtype = ctx.avals_in[0].dtype
|
||||
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
|
||||
warnings.warn(f"csr_fromdense cusparse/hipsparse lowering not available for dtype={dtype}. "
|
||||
"Falling back to default implementation.", CuSparseEfficiencyWarning)
|
||||
return _csr_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype)
|
||||
data, indices, indptr = sparse_apis.csr_fromdense_mhlo(
|
||||
mat, nnz=nse, index_dtype=np.dtype(index_dtype),
|
||||
data_dtype=dtype, index_type=mlir.dtype_to_ir_type(np.dtype(index_dtype)))
|
||||
return [data, indices, indptr]
|
||||
|
||||
|
||||
def _csr_fromdense_jvp(primals, tangents, *, nse, index_dtype):
|
||||
M, = primals
|
||||
Mdot, = tangents
|
||||
@ -295,10 +331,15 @@ def _csr_fromdense_transpose(ct, M, *, nse, index_dtype):
|
||||
ad.primitive_jvps[csr_fromdense_p] = _csr_fromdense_jvp
|
||||
ad.primitive_transposes[csr_fromdense_p] = _csr_fromdense_transpose
|
||||
xla.register_translation(csr_fromdense_p, _csr_fromdense_translation_rule)
|
||||
mlir.register_lowering(csr_fromdense_p, _csr_fromdense_lowering)
|
||||
if sparse_apis and sparse_apis.is_supported:
|
||||
xla.register_translation(csr_fromdense_p,
|
||||
_csr_fromdense_gpu_translation_rule,
|
||||
platform='gpu')
|
||||
if jax._src.lib.version > (0, 3, 5):
|
||||
mlir.register_lowering(csr_fromdense_p,
|
||||
_csr_fromdense_gpu_lowering,
|
||||
platform='gpu')
|
||||
|
||||
#--------------------------------------------------------------------
|
||||
# csr_matvec
|
||||
@ -354,6 +395,22 @@ def _csr_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
|
||||
return [sparse_apis.csr_matvec(ctx.builder, data, indices, indptr, v,
|
||||
shape=shape, transpose=transpose)]
|
||||
|
||||
_csr_matvec_lowering = mlir.lower_fun(_csr_matvec_impl, multiple_results=False)
|
||||
|
||||
def _csr_matvec_gpu_lowering(ctx, data, indices, indptr, v, *,
|
||||
shape, transpose):
|
||||
data_aval, indices_aval, _, v_aval = ctx.avals_in
|
||||
dtype = data_aval.dtype
|
||||
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
|
||||
warnings.warn(f"csr_matvec cusparse/hipsparse lowering not available for dtype={dtype}. "
|
||||
"Falling back to default implementation.", CuSparseEfficiencyWarning)
|
||||
return _csr_matvec_lowering(ctx, data, indices, indptr, v, shape=shape,
|
||||
transpose=transpose)
|
||||
return [sparse_apis.csr_matvec_mhlo(
|
||||
data, indices, indptr, v, shape=shape, transpose=transpose,
|
||||
data_dtype=dtype, index_dtype=indices_aval.dtype, x_dtype=v_aval.dtype)]
|
||||
|
||||
|
||||
def _csr_matvec_jvp_mat(data_dot, data, indices, indptr, v, *, shape, transpose):
|
||||
return csr_matvec(data_dot, indices, indptr, v, shape=shape, transpose=transpose)
|
||||
|
||||
@ -376,9 +433,13 @@ def _csr_matvec_transpose(ct, data, indices, indptr, v, *, shape, transpose):
|
||||
ad.defjvp(csr_matvec_p, _csr_matvec_jvp_mat, None, None, _csr_matvec_jvp_vec)
|
||||
ad.primitive_transposes[csr_matvec_p] = _csr_matvec_transpose
|
||||
xla.register_translation(csr_matvec_p, _csr_matvec_translation_rule)
|
||||
mlir.register_lowering(csr_matvec_p, _csr_matvec_lowering)
|
||||
if sparse_apis and sparse_apis.is_supported:
|
||||
xla.register_translation(csr_matvec_p, _csr_matvec_gpu_translation_rule,
|
||||
platform='gpu')
|
||||
if jax._src.lib.version > (0, 3, 5):
|
||||
mlir.register_lowering(csr_matvec_p, _csr_matvec_gpu_lowering,
|
||||
platform='gpu')
|
||||
|
||||
|
||||
#--------------------------------------------------------------------
|
||||
@ -436,6 +497,22 @@ def _csr_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
|
||||
return [sparse_apis.csr_matmat(ctx.builder, data, indices, indptr, B,
|
||||
shape=shape, transpose=transpose)]
|
||||
|
||||
_csr_matmat_lowering = mlir.lower_fun(_csr_matmat_impl, multiple_results=False)
|
||||
|
||||
def _csr_matmat_gpu_lowering(ctx, data, indices, indptr, B, *, shape, transpose):
|
||||
data_aval, indices_aval, _, B_aval = ctx.avals_in
|
||||
dtype = data_aval.dtype
|
||||
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
|
||||
warnings.warn(f"csr_matmat cusparse/hipsparse lowering not available for dtype={dtype}. "
|
||||
"Falling back to default implementation.", CuSparseEfficiencyWarning)
|
||||
return _csr_matmat_lowering(ctx, data, indices, indptr, B, shape=shape,
|
||||
transpose=transpose)
|
||||
return [sparse_apis.csr_matmat_mhlo(
|
||||
data, indices, indptr, B, shape=shape, transpose=transpose,
|
||||
index_dtype=indices_aval.dtype, data_dtype=data_aval.dtype,
|
||||
B_dtype=B_aval.dtype)]
|
||||
|
||||
|
||||
def _csr_matmat_jvp_left(data_dot, data, indices, indptr, B, *, shape, transpose):
|
||||
return csr_matmat(data_dot, indices, indptr, B, shape=shape, transpose=transpose)
|
||||
|
||||
@ -456,6 +533,10 @@ def _csr_matmat_transpose(ct, data, indices, indptr, B, *, shape, transpose):
|
||||
ad.defjvp(csr_matmat_p, _csr_matmat_jvp_left, None, None, _csr_matmat_jvp_right)
|
||||
ad.primitive_transposes[csr_matmat_p] = _csr_matmat_transpose
|
||||
xla.register_translation(csr_matmat_p, _csr_matmat_translation_rule)
|
||||
mlir.register_lowering(csr_matmat_p, _csr_matmat_lowering)
|
||||
if sparse_apis and sparse_apis.is_supported:
|
||||
xla.register_translation(csr_matmat_p, _csr_matmat_gpu_translation_rule,
|
||||
platform='gpu')
|
||||
if jax._src.lib.version > (0, 3, 5):
|
||||
mlir.register_lowering(csr_matmat_p, _csr_matmat_gpu_lowering,
|
||||
platform='gpu')
|
||||
|
@ -46,6 +46,18 @@ def _validate_csr(c, data, indices, indptr, shape):
|
||||
assert c.get_shape(indptr).dimensions() == (shape[0] + 1,)
|
||||
return data_dtype, index_dtype, nnz
|
||||
|
||||
def _validate_csr_mhlo(data, indices, indptr, shape):
|
||||
data_type = ir.RankedTensorType(data.type)
|
||||
indices_type = ir.RankedTensorType(indices.type)
|
||||
indptr_type = ir.RankedTensorType(indptr.type)
|
||||
|
||||
nnz, = data_type.shape
|
||||
assert indices_type.shape == [nnz]
|
||||
assert indptr_type.element_type == indices_type.element_type
|
||||
assert indptr_type.shape == [shape[0] + 1]
|
||||
return data_type.element_type, indices_type.element_type, nnz
|
||||
|
||||
|
||||
def _validate_coo(c, data, row, col, shape):
|
||||
data_dtype = np.dtype(c.get_shape(data).element_type())
|
||||
index_dtype = np.dtype(c.get_shape(row).element_type())
|
||||
@ -55,6 +67,17 @@ def _validate_coo(c, data, row, col, shape):
|
||||
assert c.get_shape(col).dimensions() == (nnz,)
|
||||
return data_dtype, index_dtype, nnz
|
||||
|
||||
def _validate_coo_mhlo(data, row, col, shape):
|
||||
data_type = ir.RankedTensorType(data.type)
|
||||
row_type = ir.RankedTensorType(row.type)
|
||||
col_type = ir.RankedTensorType(col.type)
|
||||
|
||||
nnz, = data_type.shape
|
||||
assert row_type.shape == [nnz]
|
||||
assert col_type.element_type == row_type.element_type
|
||||
assert col_type.shape == [nnz]
|
||||
return data_type.element_type, row_type.element_type, nnz
|
||||
|
||||
def csr_todense(c, data, indices, indptr, *, shape):
|
||||
"""CSR to dense matrix."""
|
||||
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
|
||||
@ -83,6 +106,38 @@ def csr_todense(c, data, indices, indptr, *, shape):
|
||||
return _ops.GetTupleElement(out, 0)
|
||||
|
||||
|
||||
def csr_todense_mhlo(data, indices, indptr, *, shape, data_dtype, index_dtype):
|
||||
"""CSR to dense matrix."""
|
||||
data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape)
|
||||
rows, cols = shape
|
||||
|
||||
buffer_size, opaque = _cusparse.build_csr_todense_descriptor(
|
||||
data_dtype, index_dtype, rows, cols, nnz)
|
||||
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = mhlo.CustomCallOp(
|
||||
[ir.TupleType.get_tuple([
|
||||
ir.RankedTensorType.get(shape, data_type),
|
||||
ir.RankedTensorType.get([buffer_size],
|
||||
ir.IntegerType.get_signless(8)),
|
||||
])],
|
||||
[data, indices, indptr],
|
||||
call_target_name=ir.StringAttr.get("cusparse_csr_todense"),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
backend_config=ir.StringAttr.get(opaque),
|
||||
api_version=ir.IntegerAttr.get(i32_type, 2),
|
||||
called_computations=ir.ArrayAttr.get([]),
|
||||
operand_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
] * 3),
|
||||
result_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([1, 0]),
|
||||
type=ir.IndexType.get()),
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
]))
|
||||
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
|
||||
|
||||
|
||||
def csr_fromdense(c, mat, *, nnz, index_dtype):
|
||||
"""CSR from dense matrix."""
|
||||
data_dtype = np.dtype(c.get_shape(mat).element_type())
|
||||
@ -113,7 +168,43 @@ def csr_fromdense(c, mat, *, nnz, index_dtype):
|
||||
return tuple(_ops.GetTupleElement(out, i) for i in range(3))
|
||||
|
||||
|
||||
def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_dtype=None):
|
||||
def csr_fromdense_mhlo(mat, *, nnz, index_dtype, data_dtype, index_type):
|
||||
"""CSR from dense matrix."""
|
||||
mat_type = ir.RankedTensorType(mat.type)
|
||||
rows, cols = mat_type.shape
|
||||
|
||||
buffer_size, opaque = _cusparse.build_csr_fromdense_descriptor(
|
||||
data_dtype, index_dtype, rows, cols, nnz)
|
||||
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = mhlo.CustomCallOp(
|
||||
[ir.TupleType.get_tuple([
|
||||
ir.RankedTensorType.get([nnz], mat_type.element_type),
|
||||
ir.RankedTensorType.get([nnz], index_type),
|
||||
ir.RankedTensorType.get([rows + 1], index_type),
|
||||
ir.RankedTensorType.get([buffer_size],
|
||||
ir.IntegerType.get_signless(8)),
|
||||
])],
|
||||
[mat],
|
||||
call_target_name=ir.StringAttr.get("cusparse_csr_fromdense"),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
backend_config=ir.StringAttr.get(opaque),
|
||||
api_version=ir.IntegerAttr.get(i32_type, 2),
|
||||
called_computations=ir.ArrayAttr.get([]),
|
||||
operand_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([1, 0]),
|
||||
type=ir.IndexType.get()),
|
||||
]),
|
||||
result_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
] * 4))
|
||||
return [
|
||||
mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False,
|
||||
compute_dtype=None):
|
||||
"""CSR matrix/vector multiply."""
|
||||
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
|
||||
rows, cols = shape
|
||||
@ -147,8 +238,46 @@ def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_d
|
||||
)
|
||||
return _ops.GetTupleElement(out, 0)
|
||||
|
||||
def csr_matvec_mhlo(data, indices, indptr, x, *, shape, transpose=False,
|
||||
compute_dtype=None, compute_type=None, data_dtype,
|
||||
index_dtype, x_dtype):
|
||||
"""CSR matrix/vector multiply."""
|
||||
data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape)
|
||||
rows, cols = shape
|
||||
|
||||
def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_dtype=None):
|
||||
if compute_dtype is None:
|
||||
compute_dtype = data_dtype
|
||||
compute_type = data_type
|
||||
|
||||
buffer_size, opaque = _cusparse.build_csr_matvec_descriptor(
|
||||
data_dtype, x_dtype, compute_dtype, index_dtype,
|
||||
rows, cols, nnz, transpose)
|
||||
out_size = cols if transpose else rows
|
||||
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = mhlo.CustomCallOp(
|
||||
[ir.TupleType.get_tuple([
|
||||
ir.RankedTensorType.get([out_size], compute_type),
|
||||
ir.RankedTensorType.get([buffer_size],
|
||||
ir.IntegerType.get_signless(8)),
|
||||
])],
|
||||
[data, indices, indptr, x],
|
||||
call_target_name=ir.StringAttr.get("cusparse_csr_matvec"),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
backend_config=ir.StringAttr.get(opaque),
|
||||
api_version=ir.IntegerAttr.get(i32_type, 2),
|
||||
called_computations=ir.ArrayAttr.get([]),
|
||||
operand_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
] * 4),
|
||||
result_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
] * 2))
|
||||
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
|
||||
|
||||
|
||||
def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False,
|
||||
compute_dtype=None):
|
||||
"""CSR from dense matrix."""
|
||||
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
|
||||
rows, cols = shape
|
||||
@ -183,6 +312,50 @@ def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_d
|
||||
)
|
||||
return _ops.GetTupleElement(out, 0)
|
||||
|
||||
def csr_matmat_mhlo(data, indices, indptr, B, *, shape, transpose=False,
|
||||
compute_dtype=None, compute_type=None, index_dtype,
|
||||
data_dtype, B_dtype):
|
||||
"""CSR from dense matrix."""
|
||||
data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape)
|
||||
rows, cols = shape
|
||||
B_shape = ir.RankedTensorType(B.type).shape
|
||||
_, Ccols = B_shape
|
||||
|
||||
if compute_dtype is None:
|
||||
compute_dtype = data_dtype
|
||||
compute_type = data_type
|
||||
|
||||
buffer_size, opaque = _cusparse.build_csr_matmat_descriptor(
|
||||
data_dtype, B_dtype, compute_dtype, index_dtype,
|
||||
rows, cols, Ccols, nnz, transpose)
|
||||
out_size = cols if transpose else rows
|
||||
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = mhlo.CustomCallOp(
|
||||
[ir.TupleType.get_tuple([
|
||||
ir.RankedTensorType.get([out_size, Ccols], compute_type),
|
||||
ir.RankedTensorType.get([buffer_size],
|
||||
ir.IntegerType.get_signless(8)),
|
||||
])],
|
||||
[data, indices, indptr, B],
|
||||
call_target_name=ir.StringAttr.get("cusparse_csr_matmat"),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
backend_config=ir.StringAttr.get(opaque),
|
||||
api_version=ir.IntegerAttr.get(i32_type, 2),
|
||||
called_computations=ir.ArrayAttr.get([]),
|
||||
operand_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
ir.DenseIntElementsAttr.get(np.array([1, 0]),
|
||||
type=ir.IndexType.get()),
|
||||
]),
|
||||
result_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([1, 0]), type=ir.IndexType.get()),
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
]))
|
||||
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
|
||||
|
||||
|
||||
def coo_todense(c, data, row, col, *, shape):
|
||||
"""COO to dense matrix."""
|
||||
@ -211,6 +384,37 @@ def coo_todense(c, data, row, col, *, shape):
|
||||
)
|
||||
return _ops.GetTupleElement(out, 0)
|
||||
|
||||
def coo_todense_mhlo(data, row, col, *, shape, data_dtype, index_dtype):
|
||||
"""COO to dense matrix."""
|
||||
data_type, _, nnz = _validate_coo_mhlo(data, row, col, shape)
|
||||
rows, cols = shape
|
||||
|
||||
buffer_size, opaque = _cusparse.build_coo_todense_descriptor(
|
||||
data_dtype, index_dtype, rows, cols, nnz)
|
||||
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = mhlo.CustomCallOp(
|
||||
[ir.TupleType.get_tuple([
|
||||
ir.RankedTensorType.get(shape, data_type),
|
||||
ir.RankedTensorType.get([buffer_size],
|
||||
ir.IntegerType.get_signless(8)),
|
||||
])],
|
||||
[data, row, col],
|
||||
call_target_name=ir.StringAttr.get("cusparse_coo_todense"),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
backend_config=ir.StringAttr.get(opaque),
|
||||
api_version=ir.IntegerAttr.get(i32_type, 2),
|
||||
called_computations=ir.ArrayAttr.get([]),
|
||||
operand_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
] * 3),
|
||||
result_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([1, 0]),
|
||||
type=ir.IndexType.get()),
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
]))
|
||||
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
|
||||
|
||||
|
||||
def coo_fromdense(c, mat, *, nnz, index_dtype):
|
||||
"""COO from dense matrix."""
|
||||
@ -241,7 +445,44 @@ def coo_fromdense(c, mat, *, nnz, index_dtype):
|
||||
|
||||
return tuple(_ops.GetTupleElement(out, i) for i in range(3))
|
||||
|
||||
def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=None):
|
||||
def coo_fromdense_mhlo(mat, *, nnz, data_dtype, index_dtype,
|
||||
index_type):
|
||||
"""COO from dense matrix."""
|
||||
mat_type = ir.RankedTensorType(mat.type)
|
||||
rows, cols = mat_type.shape
|
||||
|
||||
buffer_size, opaque = _cusparse.build_coo_fromdense_descriptor(
|
||||
data_dtype, index_dtype, rows, cols, nnz)
|
||||
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = mhlo.CustomCallOp(
|
||||
[ir.TupleType.get_tuple([
|
||||
ir.RankedTensorType.get([nnz], mat_type.element_type),
|
||||
ir.RankedTensorType.get([nnz], index_type),
|
||||
ir.RankedTensorType.get([nnz], index_type),
|
||||
ir.RankedTensorType.get([buffer_size],
|
||||
ir.IntegerType.get_signless(8)),
|
||||
])],
|
||||
[mat],
|
||||
call_target_name=ir.StringAttr.get("cusparse_coo_fromdense"),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
backend_config=ir.StringAttr.get(opaque),
|
||||
api_version=ir.IntegerAttr.get(i32_type, 2),
|
||||
called_computations=ir.ArrayAttr.get([]),
|
||||
operand_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([1, 0]),
|
||||
type=ir.IndexType.get()),
|
||||
]),
|
||||
result_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
] * 4))
|
||||
return [
|
||||
mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
def coo_matvec(c, data, row, col, x, *, shape, transpose=False,
|
||||
compute_dtype=None):
|
||||
"""COO matrix/vector multiply."""
|
||||
data_dtype, index_dtype, nnz = _validate_coo(c, data, row, col, shape)
|
||||
rows, cols = shape
|
||||
@ -276,7 +517,46 @@ def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=No
|
||||
return _ops.GetTupleElement(out, 0)
|
||||
|
||||
|
||||
def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=None):
|
||||
def coo_matvec_mhlo(data, row, col, x, *, shape, transpose=False,
|
||||
compute_dtype=None,
|
||||
compute_type=None, index_dtype, data_dtype, x_dtype):
|
||||
"""COO matrix/vector multiply."""
|
||||
data_type, index_type, nnz = _validate_coo_mhlo(data, row, col, shape)
|
||||
rows, cols = shape
|
||||
|
||||
if compute_dtype is None:
|
||||
compute_dtype = data_dtype
|
||||
compute_type = data_type
|
||||
|
||||
buffer_size, opaque = _cusparse.build_coo_matvec_descriptor(
|
||||
data_dtype, x_dtype, compute_dtype, index_dtype,
|
||||
rows, cols, nnz, transpose)
|
||||
out_size = cols if transpose else rows
|
||||
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = mhlo.CustomCallOp(
|
||||
[ir.TupleType.get_tuple([
|
||||
ir.RankedTensorType.get([out_size], compute_type),
|
||||
ir.RankedTensorType.get([buffer_size],
|
||||
ir.IntegerType.get_signless(8)),
|
||||
])],
|
||||
[data, row, col, x],
|
||||
call_target_name=ir.StringAttr.get("cusparse_coo_matvec"),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
backend_config=ir.StringAttr.get(opaque),
|
||||
api_version=ir.IntegerAttr.get(i32_type, 2),
|
||||
called_computations=ir.ArrayAttr.get([]),
|
||||
operand_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
] * 4),
|
||||
result_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
] * 2))
|
||||
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
|
||||
|
||||
|
||||
def coo_matmat(c, data, row, col, B, *, shape, transpose=False,
|
||||
compute_dtype=None):
|
||||
"""COO from dense matrix."""
|
||||
data_dtype, index_dtype, nnz = _validate_coo(c, data, row, col, shape)
|
||||
rows, cols = shape
|
||||
@ -311,6 +591,51 @@ def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=No
|
||||
)
|
||||
return _ops.GetTupleElement(out, 0)
|
||||
|
||||
def coo_matmat_mhlo(data, row, col, B, *, shape, transpose=False,
|
||||
compute_dtype=None, compute_type=None, x_dtype,
|
||||
data_dtype, index_dtype):
|
||||
"""COO from dense matrix."""
|
||||
data_type, index_type, nnz = _validate_coo_mhlo(data, row, col, shape)
|
||||
rows, cols = shape
|
||||
B_shape = ir.RankedTensorType(B.type).shape
|
||||
_, Ccols = B_shape
|
||||
|
||||
if compute_dtype is None:
|
||||
compute_dtype = data_dtype
|
||||
compute_type = data_type
|
||||
|
||||
buffer_size, opaque = _cusparse.build_coo_matmat_descriptor(
|
||||
data_dtype, x_dtype, compute_dtype, index_dtype,
|
||||
rows, cols, Ccols, nnz, transpose)
|
||||
out_size = cols if transpose else rows
|
||||
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = mhlo.CustomCallOp(
|
||||
[ir.TupleType.get_tuple([
|
||||
ir.RankedTensorType.get([out_size, Ccols], compute_type),
|
||||
ir.RankedTensorType.get([buffer_size],
|
||||
ir.IntegerType.get_signless(8)),
|
||||
])],
|
||||
[data, row, col, B],
|
||||
call_target_name=ir.StringAttr.get("cusparse_coo_matmat"),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
backend_config=ir.StringAttr.get(opaque),
|
||||
api_version=ir.IntegerAttr.get(i32_type, 2),
|
||||
called_computations=ir.ArrayAttr.get([]),
|
||||
operand_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
ir.DenseIntElementsAttr.get(np.array([1, 0]),
|
||||
type=ir.IndexType.get()),
|
||||
]),
|
||||
result_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([1, 0]),
|
||||
type=ir.IndexType.get()),
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
]))
|
||||
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
|
||||
|
||||
|
||||
def gtsv2(c, dl, d, du, B, *, m, n, ldb, t):
|
||||
"""Calls `cusparse<t>gtsv2(dl, d, du, B, m, n, ldb)`."""
|
||||
@ -347,7 +672,8 @@ def gtsv2_mhlo(dl, d, du, B, *, m, n, ldb, t):
|
||||
[ir.TupleType.get_tuple([
|
||||
ir.RankedTensorType.get(
|
||||
[ldb, n], ir.F32Type.get() if f32 else ir.F64Type.get()),
|
||||
ir.RankedTensorType.get([buffer_size], ir.IntegerType.get_signless(8)),
|
||||
ir.RankedTensorType.get([buffer_size],
|
||||
ir.IntegerType.get_signless(8)),
|
||||
])],
|
||||
[dl, d, du, B],
|
||||
call_target_name = ir.StringAttr.get(
|
||||
@ -363,7 +689,8 @@ def gtsv2_mhlo(dl, d, du, B, *, m, n, ldb, t):
|
||||
ir.DenseIntElementsAttr.get(np.array([1, 0]), type=ir.IndexType.get())
|
||||
]),
|
||||
result_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([1, 0]), type=ir.IndexType.get()),
|
||||
ir.DenseIntElementsAttr.get(np.array([1, 0]),
|
||||
type=ir.IndexType.get()),
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
]))
|
||||
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
|
||||
|
@ -46,6 +46,17 @@ def _validate_csr(c, data, indices, indptr, shape):
|
||||
assert c.get_shape(indptr).dimensions() == (shape[0] + 1,)
|
||||
return data_dtype, index_dtype, nnz
|
||||
|
||||
def _validate_csr_mhlo(data, indices, indptr, shape):
|
||||
data_type = ir.RankedTensorType(data.type)
|
||||
indices_type = ir.RankedTensorType(indices.type)
|
||||
indptr_type = ir.RankedTensorType(indptr.type)
|
||||
|
||||
nnz, = data_type.shape
|
||||
assert indices_type.shape == [nnz]
|
||||
assert indptr_type.element_type == indices_type.element_type
|
||||
assert indptr_type.shape == [shape[0] + 1]
|
||||
return data_type.element_type, indices_type.element_type, nnz
|
||||
|
||||
def _validate_coo(c, data, row, col, shape):
|
||||
data_dtype = np.dtype(c.get_shape(data).element_type())
|
||||
index_dtype = np.dtype(c.get_shape(row).element_type())
|
||||
@ -55,6 +66,17 @@ def _validate_coo(c, data, row, col, shape):
|
||||
assert c.get_shape(col).dimensions() == (nnz,)
|
||||
return data_dtype, index_dtype, nnz
|
||||
|
||||
def _validate_coo_mhlo(data, row, col, shape):
|
||||
data_type = ir.RankedTensorType(data.type)
|
||||
row_type = ir.RankedTensorType(row.type)
|
||||
col_type = ir.RankedTensorType(col.type)
|
||||
|
||||
nnz, = data_type.shape
|
||||
assert row_type.shape == [nnz]
|
||||
assert col_type.element_type == row_type.element_type
|
||||
assert col_type.shape == [nnz]
|
||||
return data_type.element_type, row_type.element_type, nnz
|
||||
|
||||
def csr_todense(c, data, indices, indptr, *, shape):
|
||||
"""CSR to dense matrix."""
|
||||
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
|
||||
@ -83,6 +105,39 @@ def csr_todense(c, data, indices, indptr, *, shape):
|
||||
return _ops.GetTupleElement(out, 0)
|
||||
|
||||
|
||||
def csr_todense_mhlo(data, indices, indptr, *, shape, data_dtype, index_dtype):
|
||||
"""CSR to dense matrix."""
|
||||
data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape)
|
||||
rows, cols = shape
|
||||
|
||||
buffer_size, opaque = _hipsparse.build_csr_todense_descriptor(
|
||||
data_dtype, index_dtype, rows, cols, nnz)
|
||||
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = mhlo.CustomCallOp(
|
||||
[ir.TupleType.get_tuple([
|
||||
ir.RankedTensorType.get(shape, data_type),
|
||||
ir.RankedTensorType.get([buffer_size],
|
||||
ir.IntegerType.get_signless(8)),
|
||||
])],
|
||||
[data, indices, indptr],
|
||||
call_target_name=ir.StringAttr.get("hipsparse_csr_todense"),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
backend_config=ir.StringAttr.get(opaque),
|
||||
api_version=ir.IntegerAttr.get(i32_type, 2),
|
||||
called_computations=ir.ArrayAttr.get([]),
|
||||
operand_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
] * 3),
|
||||
result_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([1, 0]),
|
||||
type=ir.IndexType.get()),
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
]))
|
||||
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
|
||||
|
||||
|
||||
|
||||
def csr_fromdense(c, mat, *, nnz, index_dtype):
|
||||
"""CSR from dense matrix."""
|
||||
data_dtype = np.dtype(c.get_shape(mat).element_type())
|
||||
@ -112,6 +167,40 @@ def csr_fromdense(c, mat, *, nnz, index_dtype):
|
||||
|
||||
return tuple(_ops.GetTupleElement(out, i) for i in range(3))
|
||||
|
||||
def csr_fromdense_mhlo(mat, *, nnz, index_dtype, data_dtype, index_type):
|
||||
"""CSR from dense matrix."""
|
||||
mat_type = ir.RankedTensorType(mat.type)
|
||||
rows, cols = mat_type.shape
|
||||
|
||||
buffer_size, opaque = _hipsparse.build_csr_fromdense_descriptor(
|
||||
data_dtype, index_dtype, rows, cols, nnz)
|
||||
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = mhlo.CustomCallOp(
|
||||
[ir.TupleType.get_tuple([
|
||||
ir.RankedTensorType.get([nnz], mat_type.element_type),
|
||||
ir.RankedTensorType.get([nnz], index_type),
|
||||
ir.RankedTensorType.get([rows + 1], index_type),
|
||||
ir.RankedTensorType.get([buffer_size],
|
||||
ir.IntegerType.get_signless(8)),
|
||||
])],
|
||||
[mat],
|
||||
call_target_name=ir.StringAttr.get("hipsparse_csr_fromdense"),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
backend_config=ir.StringAttr.get(opaque),
|
||||
api_version=ir.IntegerAttr.get(i32_type, 2),
|
||||
called_computations=ir.ArrayAttr.get([]),
|
||||
operand_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([1, 0]),
|
||||
type=ir.IndexType.get()),
|
||||
]),
|
||||
result_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
] * 4))
|
||||
return [
|
||||
mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_dtype=None):
|
||||
"""CSR matrix/vector multiply."""
|
||||
@ -147,6 +236,43 @@ def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_d
|
||||
)
|
||||
return _ops.GetTupleElement(out, 0)
|
||||
|
||||
def csr_matvec_mhlo(data, indices, indptr, x, *, shape, transpose=False,
|
||||
compute_dtype=None, compute_type=None, data_dtype,
|
||||
index_dtype, x_dtype):
|
||||
"""CSR matrix/vector multiply."""
|
||||
data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape)
|
||||
rows, cols = shape
|
||||
|
||||
if compute_dtype is None:
|
||||
compute_dtype = data_dtype
|
||||
compute_type = data_type
|
||||
|
||||
buffer_size, opaque = _hipsparse.build_csr_matvec_descriptor(
|
||||
data_dtype, x_dtype, compute_dtype, index_dtype,
|
||||
rows, cols, nnz, transpose)
|
||||
out_size = cols if transpose else rows
|
||||
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = mhlo.CustomCallOp(
|
||||
[ir.TupleType.get_tuple([
|
||||
ir.RankedTensorType.get([out_size], compute_type),
|
||||
ir.RankedTensorType.get([buffer_size],
|
||||
ir.IntegerType.get_signless(8)),
|
||||
])],
|
||||
[data, indices, indptr, x],
|
||||
call_target_name=ir.StringAttr.get("hipsparse_csr_matvec"),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
backend_config=ir.StringAttr.get(opaque),
|
||||
api_version=ir.IntegerAttr.get(i32_type, 2),
|
||||
called_computations=ir.ArrayAttr.get([]),
|
||||
operand_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
] * 4),
|
||||
result_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
] * 2))
|
||||
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
|
||||
|
||||
|
||||
def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_dtype=None):
|
||||
"""CSR from dense matrix."""
|
||||
@ -183,6 +309,50 @@ def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_d
|
||||
)
|
||||
return _ops.GetTupleElement(out, 0)
|
||||
|
||||
def csr_matmat_mhlo(data, indices, indptr, B, *, shape, transpose=False,
|
||||
compute_dtype=None, compute_type=None, index_dtype,
|
||||
data_dtype, B_dtype):
|
||||
"""CSR from dense matrix."""
|
||||
data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape)
|
||||
rows, cols = shape
|
||||
B_shape = ir.RankedTensorType(B.type).shape
|
||||
_, Ccols = B_shape
|
||||
|
||||
if compute_dtype is None:
|
||||
compute_dtype = data_dtype
|
||||
compute_type = data_type
|
||||
|
||||
buffer_size, opaque = _hipsparse.build_csr_matmat_descriptor(
|
||||
data_dtype, B_dtype, compute_dtype, index_dtype,
|
||||
rows, cols, Ccols, nnz, transpose)
|
||||
out_size = cols if transpose else rows
|
||||
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = mhlo.CustomCallOp(
|
||||
[ir.TupleType.get_tuple([
|
||||
ir.RankedTensorType.get([out_size, Ccols], compute_type),
|
||||
ir.RankedTensorType.get([buffer_size],
|
||||
ir.IntegerType.get_signless(8)),
|
||||
])],
|
||||
[data, indices, indptr, B],
|
||||
call_target_name=ir.StringAttr.get("hipsparse_csr_matmat"),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
backend_config=ir.StringAttr.get(opaque),
|
||||
api_version=ir.IntegerAttr.get(i32_type, 2),
|
||||
called_computations=ir.ArrayAttr.get([]),
|
||||
operand_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
ir.DenseIntElementsAttr.get(np.array([1, 0]),
|
||||
type=ir.IndexType.get()),
|
||||
]),
|
||||
result_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([1, 0]), type=ir.IndexType.get()),
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
]))
|
||||
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
|
||||
|
||||
|
||||
def coo_todense(c, data, row, col, *, shape):
|
||||
"""COO to dense matrix."""
|
||||
@ -211,6 +381,36 @@ def coo_todense(c, data, row, col, *, shape):
|
||||
)
|
||||
return _ops.GetTupleElement(out, 0)
|
||||
|
||||
def coo_todense_mhlo(data, row, col, *, shape, data_dtype, index_dtype):
|
||||
"""COO to dense matrix."""
|
||||
data_type, _, nnz = _validate_coo_mhlo(data, row, col, shape)
|
||||
rows, cols = shape
|
||||
|
||||
buffer_size, opaque = _hipsparse.build_coo_todense_descriptor(
|
||||
data_dtype, index_dtype, rows, cols, nnz)
|
||||
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = mhlo.CustomCallOp(
|
||||
[ir.TupleType.get_tuple([
|
||||
ir.RankedTensorType.get(shape, data_type),
|
||||
ir.RankedTensorType.get([buffer_size],
|
||||
ir.IntegerType.get_signless(8)),
|
||||
])],
|
||||
[data, row, col],
|
||||
call_target_name=ir.StringAttr.get("hipsparse_coo_todense"),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
backend_config=ir.StringAttr.get(opaque),
|
||||
api_version=ir.IntegerAttr.get(i32_type, 2),
|
||||
called_computations=ir.ArrayAttr.get([]),
|
||||
operand_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
] * 3),
|
||||
result_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([1, 0]),
|
||||
type=ir.IndexType.get()),
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
]))
|
||||
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
|
||||
|
||||
def coo_fromdense(c, mat, *, nnz, index_dtype):
|
||||
"""COO from dense matrix."""
|
||||
@ -241,6 +441,42 @@ def coo_fromdense(c, mat, *, nnz, index_dtype):
|
||||
|
||||
return tuple(_ops.GetTupleElement(out, i) for i in range(3))
|
||||
|
||||
def coo_fromdense_mhlo(mat, *, nnz, data_dtype, index_dtype,
|
||||
index_type):
|
||||
"""COO from dense matrix."""
|
||||
mat_type = ir.RankedTensorType(mat.type)
|
||||
rows, cols = mat_type.shape
|
||||
|
||||
buffer_size, opaque = _hipsparse.build_coo_fromdense_descriptor(
|
||||
data_dtype, index_dtype, rows, cols, nnz)
|
||||
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = mhlo.CustomCallOp(
|
||||
[ir.TupleType.get_tuple([
|
||||
ir.RankedTensorType.get([nnz], mat_type.element_type),
|
||||
ir.RankedTensorType.get([nnz], index_type),
|
||||
ir.RankedTensorType.get([nnz], index_type),
|
||||
ir.RankedTensorType.get([buffer_size],
|
||||
ir.IntegerType.get_signless(8)),
|
||||
])],
|
||||
[mat],
|
||||
call_target_name=ir.StringAttr.get("hipsparse_coo_fromdense"),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
backend_config=ir.StringAttr.get(opaque),
|
||||
api_version=ir.IntegerAttr.get(i32_type, 2),
|
||||
called_computations=ir.ArrayAttr.get([]),
|
||||
operand_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([1, 0]),
|
||||
type=ir.IndexType.get()),
|
||||
]),
|
||||
result_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
] * 4))
|
||||
return [
|
||||
mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=None):
|
||||
"""COO matrix/vector multiply."""
|
||||
data_dtype, index_dtype, nnz = _validate_coo(c, data, row, col, shape)
|
||||
@ -275,6 +511,43 @@ def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=No
|
||||
)
|
||||
return _ops.GetTupleElement(out, 0)
|
||||
|
||||
def coo_matvec_mhlo(data, row, col, x, *, shape, transpose=False,
|
||||
compute_dtype=None,
|
||||
compute_type=None, index_dtype, data_dtype, x_dtype):
|
||||
"""COO matrix/vector multiply."""
|
||||
data_type, index_type, nnz = _validate_coo_mhlo(data, row, col, shape)
|
||||
rows, cols = shape
|
||||
|
||||
if compute_dtype is None:
|
||||
compute_dtype = data_dtype
|
||||
compute_type = data_type
|
||||
|
||||
buffer_size, opaque = _hipsparse.build_coo_matvec_descriptor(
|
||||
data_dtype, x_dtype, compute_dtype, index_dtype,
|
||||
rows, cols, nnz, transpose)
|
||||
out_size = cols if transpose else rows
|
||||
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = mhlo.CustomCallOp(
|
||||
[ir.TupleType.get_tuple([
|
||||
ir.RankedTensorType.get([out_size], compute_type),
|
||||
ir.RankedTensorType.get([buffer_size],
|
||||
ir.IntegerType.get_signless(8)),
|
||||
])],
|
||||
[data, row, col, x],
|
||||
call_target_name=ir.StringAttr.get("hipsparse_coo_matvec"),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
backend_config=ir.StringAttr.get(opaque),
|
||||
api_version=ir.IntegerAttr.get(i32_type, 2),
|
||||
called_computations=ir.ArrayAttr.get([]),
|
||||
operand_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
] * 4),
|
||||
result_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
|
||||
] * 2))
|
||||
return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
|
||||
|
||||
|
||||
def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=None):
|
||||
"""COO from dense matrix."""
|
||||
@ -311,6 +584,51 @@ def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=No
|
||||
)
|
||||
return _ops.GetTupleElement(out, 0)
|
||||
|
||||
def coo_matmat_mhlo(data, row, col, B, *, shape, transpose=False,
                    compute_dtype=None, compute_type=None, x_dtype,
                    data_dtype, index_dtype):
  """Emit an MHLO custom call computing a COO matrix/matrix product.

  Lowers ``A @ B`` (or ``A.T @ B`` when ``transpose``) to the
  ``hipsparse_coo_matmat`` custom-call target, where ``A`` is the sparse COO
  matrix given by (data, row, col) with dense ``shape = (rows, cols)`` and
  ``B`` is a dense matrix. Returns the dense result matrix; the hipsparse
  scratch buffer in the call's tuple result is discarded.
  """
  # NOTE: original docstring said "COO from dense matrix", which described a
  # different operation; this function performs a sparse-dense matmul.
  data_type, index_type, nnz = _validate_coo_mhlo(data, row, col, shape)
  rows, cols = shape
  B_shape = ir.RankedTensorType(B.type).shape
  _, Ccols = B_shape

  # When no accumulation dtype is requested, accumulate in the data's dtype.
  if compute_dtype is None:
    compute_dtype = data_dtype
    compute_type = data_type

  # The descriptor encodes problem sizes/dtypes for the C++ custom call and
  # reports the scratch-buffer size hipsparse requires.
  buffer_size, opaque = _hipsparse.build_coo_matmat_descriptor(
      data_dtype, x_dtype, compute_dtype, index_dtype,
      rows, cols, Ccols, nnz, transpose)
  out_size = cols if transpose else rows

  i32_type = ir.IntegerType.get_signless(32)
  out = mhlo.CustomCallOp(
      [ir.TupleType.get_tuple([
          ir.RankedTensorType.get([out_size, Ccols], compute_type),
          ir.RankedTensorType.get([buffer_size],
                                  ir.IntegerType.get_signless(8)),
      ])],
      [data, row, col, B],
      call_target_name=ir.StringAttr.get("hipsparse_coo_matmat"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([
          # data, row, col are rank-1; B is rank-2 in row-major [1, 0] layout.
          ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array([1, 0]),
                                      type=ir.IndexType.get()),
      ]),
      result_layouts=ir.ArrayAttr.get([
          # Result matrix is row-major; scratch buffer is rank-1.
          ir.DenseIntElementsAttr.get(np.array([1, 0]),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
      ]))
  # Only the first tuple element (the result matrix) is returned.
  return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
|
||||
|
||||
|
||||
def gtsv2(c, dl, d, du, B, *, m, n, ldb, t):
|
||||
"""Calls `hipsparse<t>gtsv2(dl, d, du, B, m, n, ldb)`."""
|
||||
|
@ -236,6 +236,9 @@ def _split_abstract_eval(mat):
|
||||
# Legacy XLA translation rule for split_p: lower by tracing the impl.
xla.register_translation(
    split_p, xla.lower_fun(_split_impl, multiple_results=True, new_style=True))

# New-style MLIR/MHLO lowering for split_p, derived from the same impl.
mlir.register_lowering(
    split_p, mlir.lower_fun(_split_impl, multiple_results=True))
|
||||
|
||||
def make_sparse_array(rng, shape, dtype, nnz=0.2):
|
||||
mat = rng(shape, dtype)
|
||||
size = int(np.prod(shape))
|
||||
|
Loading…
x
Reference in New Issue
Block a user