2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2019 The JAX Authors.
|
2021-04-15 10:10:40 -07:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
"""
|
|
|
|
cusparse wrappers for performing sparse matrix computations in JAX
|
|
|
|
"""
|
|
|
|
|
2023-08-10 16:25:23 -07:00
|
|
|
import math
|
2022-05-06 09:34:25 -07:00
|
|
|
from functools import partial
|
2025-03-10 08:17:07 -07:00
|
|
|
from typing import Any
|
2022-05-06 09:34:25 -07:00
|
|
|
|
2022-04-06 13:56:01 -07:00
|
|
|
import jaxlib.mlir.ir as ir
|
|
|
|
|
2021-04-15 10:10:40 -07:00
|
|
|
import numpy as np
|
|
|
|
|
2023-08-10 16:25:23 -07:00
|
|
|
from .hlo_helpers import custom_call, mk_result_types_and_shapes
|
2022-05-06 14:50:54 -07:00
|
|
|
|
2025-02-27 11:51:39 -08:00
|
|
|
from .plugin_support import import_from_plugin
|
|
|
|
|
|
|
|
_cusparse = import_from_plugin("cuda", "_sparse")
|
|
|
|
_hipsparse = import_from_plugin("rocm", "_sparse")
|
2023-11-06 09:05:08 -08:00
|
|
|
|
2025-03-10 08:17:07 -07:00
|
|
|
def registrations() -> dict[str, list[tuple[str, Any, int]]]:
  """Collect the custom-call targets exposed by the GPU sparse plugins.

  Returns:
    A dict with the keys ``"CUDA"`` and ``"ROCM"``. Each value is a list of
    ``(name, value, flag)`` tuples built from the corresponding plugin
    module's ``registrations()`` mapping, where ``flag`` is 1 when ``name``
    ends with ``"_ffi"`` and 0 otherwise. A platform whose plugin module is
    falsy keeps an empty list.
  """
  targets_by_platform = {"CUDA": [], "ROCM": []}
  for platform, module in (("CUDA", _cusparse), ("ROCM", _hipsparse)):
    if not module:
      continue
    entries = targets_by_platform[platform]
    for name, value in module.registrations().items():
      entries.append((name, value, int(name.endswith("_ffi"))))
  return targets_by_platform  # pytype: disable=bad-return-type
|
2021-04-15 10:10:40 -07:00
|
|
|
|
2022-05-06 09:34:25 -07:00
|
|
|
|
2024-01-19 09:10:13 -08:00
|
|
|
# A platform counts as supported only when its plugin module was imported
# (is truthy) and that module reports ``sparse_supported``.
cuda_is_supported = bool(_cusparse.sparse_supported) if _cusparse else False
rocm_is_supported = bool(_hipsparse.sparse_supported) if _hipsparse else False
|
2021-04-15 10:10:40 -07:00
|
|
|
|
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
def _validate_csr_hlo(data, indices, indptr, shape):
  """Sanity-check the operands of a CSR matrix.

  Asserts that ``data`` is rank 1 with ``nnz`` entries, that ``indices`` has
  shape ``[nnz]``, that ``indptr`` shares the element type of ``indices`` and
  has shape ``[shape[0] + 1]``.

  Returns:
    ``(data element type, index element type, nnz)``.
  """
  values_ty = ir.RankedTensorType(data.type)
  colidx_ty = ir.RankedTensorType(indices.type)
  rowptr_ty = ir.RankedTensorType(indptr.type)

  # Single-element unpacking doubles as a rank-1 check on ``data``.
  (nnz,) = values_ty.shape
  assert colidx_ty.shape == [nnz]
  assert rowptr_ty.element_type == colidx_ty.element_type
  assert rowptr_ty.shape == [shape[0] + 1]
  return values_ty.element_type, colidx_ty.element_type, nnz
|
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
def _validate_coo_hlo(data, row, col):
  """Sanity-check the operands of a COO matrix.

  Asserts that ``data`` is rank 1 with ``nnz`` entries and that ``row`` and
  ``col`` both have shape ``[nnz]`` and the same element type.

  Returns:
    ``(data element type, index element type, nnz)``.
  """
  values_ty = ir.RankedTensorType(data.type)
  rowidx_ty = ir.RankedTensorType(row.type)
  colidx_ty = ir.RankedTensorType(col.type)

  # Single-element unpacking doubles as a rank-1 check on ``data``.
  (nnz,) = values_ty.shape
  assert rowidx_ty.shape == [nnz]
  assert colidx_ty.element_type == rowidx_ty.element_type
  assert colidx_ty.shape == [nnz]
  return values_ty.element_type, rowidx_ty.element_type, nnz
|
|
|
|
|
2021-04-15 10:10:40 -07:00
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
def _csr_todense_hlo(platform, gpu_sparse, data, indices, indptr, *, shape,
                     data_dtype, index_dtype):
  """CSR to dense matrix.

  Args:
    platform: prefix of the custom-call target name (this file binds "cu"
      and "hip").
    gpu_sparse: module providing ``build_csr_todense_descriptor``.
    data, indices, indptr: IR values holding the CSR arrays; checked by
      ``_validate_csr_hlo``.
    shape: ``(rows, cols)`` of the dense result.
    data_dtype, index_dtype: dtypes forwarded to the descriptor builder.

  Returns:
    The first custom-call result: a tensor of ``shape`` with the element
    type of ``data``.
  """
  elt_type, _, nnz = _validate_csr_hlo(data, indices, indptr, shape)
  n_rows, n_cols = shape

  workspace_size, descriptor = gpu_sparse.build_csr_todense_descriptor(
      data_dtype, index_dtype, n_rows, n_cols, nnz)

  dense_type = ir.RankedTensorType.get(shape, elt_type)
  # Second result is an i8 buffer of the size reported by the descriptor
  # builder (presumably scratch space); it is not returned to the caller.
  workspace_type = ir.RankedTensorType.get(
      [workspace_size], ir.IntegerType.get_signless(8))

  call = custom_call(
      f"{platform}sparse_csr_todense_ffi",
      result_types=[dense_type, workspace_type],
      operands=[data, indices, indptr],
      backend_config={"opaque": ir.StringAttr.get(descriptor)},
      api_version=4,
      operand_layouts=[[0]] * 3,
      result_layouts=[[1, 0], [0]])
  return call.results[0]
|
2022-04-08 08:43:23 -07:00
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
# Platform-specific entry points: bind the target-name prefix and the plugin
# module for CUDA ("cu") and ROCm ("hip").
cuda_csr_todense = partial(_csr_todense_hlo, "cu", _cusparse)
rocm_csr_todense = partial(_csr_todense_hlo, "hip", _hipsparse)
|
2022-05-06 09:34:25 -07:00
|
|
|
|
2022-04-08 08:43:23 -07:00
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
def _csr_fromdense_hlo(platform, gpu_sparse, mat, *, nnz, index_dtype,
                       data_dtype, index_type):
  """CSR from dense matrix.

  Args:
    platform: prefix of the custom-call target name (this file binds "cu"
      and "hip").
    gpu_sparse: module providing ``build_csr_fromdense_descriptor``.
    mat: rank-2 IR value holding the dense matrix.
    nnz: number of entries in the ``data`` and ``indices`` results.
    index_dtype, data_dtype: dtypes forwarded to the descriptor builder.
    index_type: IR element type of the ``indices`` and ``indptr`` results.

  Returns:
    The first three custom-call results, typed ``[nnz]`` (element type of
    ``mat``), ``[nnz]`` and ``[rows + 1]`` (both ``index_type``).
  """
  dense_type = ir.RankedTensorType(mat.type)
  n_rows, n_cols = dense_type.shape

  workspace_size, descriptor = gpu_sparse.build_csr_fromdense_descriptor(
      data_dtype, index_dtype, n_rows, n_cols, nnz)

  result_types = [
      ir.RankedTensorType.get([nnz], dense_type.element_type),
      ir.RankedTensorType.get([nnz], index_type),
      ir.RankedTensorType.get([n_rows + 1], index_type),
      # Trailing i8 buffer (presumably scratch space); dropped below.
      ir.RankedTensorType.get([workspace_size],
                              ir.IntegerType.get_signless(8)),
  ]
  call = custom_call(
      f"{platform}sparse_csr_fromdense_ffi",
      result_types=result_types,
      operands=[mat],
      backend_config={"opaque": ir.StringAttr.get(descriptor)},
      api_version=4,
      operand_layouts=[[1, 0]],
      result_layouts=[[0]] * 4)
  return call.results[:3]
|
2022-04-08 08:43:23 -07:00
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
# Platform-specific entry points: bind the target-name prefix and the plugin
# module for CUDA ("cu") and ROCm ("hip").
cuda_csr_fromdense = partial(_csr_fromdense_hlo, "cu", _cusparse)
rocm_csr_fromdense = partial(_csr_fromdense_hlo, "hip", _hipsparse)
|
2022-05-06 09:34:25 -07:00
|
|
|
|
2021-04-15 10:10:40 -07:00
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
def _csr_matvec_hlo(platform, gpu_sparse, data, indices, indptr, x, *, shape,
                    transpose=False, compute_dtype=None, compute_type=None,
                    data_dtype, index_dtype, x_dtype):
  """CSR matrix/vector multiply.

  Args:
    platform: prefix of the custom-call target name (this file binds "cu"
      and "hip").
    gpu_sparse: module providing ``build_csr_matvec_descriptor``.
    data, indices, indptr: IR values holding the CSR arrays; checked by
      ``_validate_csr_hlo``.
    x: IR value holding the vector operand.
    shape: ``(rows, cols)`` of the sparse matrix.
    transpose: forwarded to the descriptor builder; also selects ``cols``
      instead of ``rows`` as the result length.
    compute_dtype, compute_type: dtype and IR element type of the result.
      When ``compute_dtype`` is None, both fall back to the matrix's own.
    data_dtype, index_dtype, x_dtype: dtypes forwarded to the descriptor
      builder.

  Returns:
    The first custom-call result, a rank-1 tensor of ``compute_type``.
  """
  elt_type, _, nnz = _validate_csr_hlo(data, indices, indptr, shape)
  n_rows, n_cols = shape

  if compute_dtype is None:
    compute_dtype, compute_type = data_dtype, elt_type

  workspace_size, descriptor = gpu_sparse.build_csr_matvec_descriptor(
      data_dtype, x_dtype, compute_dtype, index_dtype,
      n_rows, n_cols, nnz, transpose)
  result_len = n_cols if transpose else n_rows

  result_types = [
      ir.RankedTensorType.get([result_len], compute_type),
      # Trailing i8 buffer (presumably scratch space); dropped below.
      ir.RankedTensorType.get([workspace_size],
                              ir.IntegerType.get_signless(8)),
  ]
  call = custom_call(
      f"{platform}sparse_csr_matvec_ffi",
      result_types=result_types,
      operands=[data, indices, indptr, x],
      backend_config={"opaque": ir.StringAttr.get(descriptor)},
      api_version=4,
      operand_layouts=[[0]] * 4,
      result_layouts=[[0]] * 2)
  return call.results[0]
|
2022-04-08 08:43:23 -07:00
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
# Platform-specific entry points: bind the target-name prefix and the plugin
# module for CUDA ("cu") and ROCm ("hip").
cuda_csr_matvec = partial(_csr_matvec_hlo, "cu", _cusparse)
rocm_csr_matvec = partial(_csr_matvec_hlo, "hip", _hipsparse)
|
2021-04-15 10:10:40 -07:00
|
|
|
|
2022-05-06 09:34:25 -07:00
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
def _csr_matmat_hlo(platform, gpu_sparse, data, indices, indptr, B, *, shape,
                    transpose=False, compute_dtype=None, compute_type=None,
                    index_dtype, data_dtype, B_dtype):
  """CSR matrix/matrix multiply.

  Args:
    platform: prefix of the custom-call target name (this file binds "cu"
      and "hip").
    gpu_sparse: module providing ``build_csr_matmat_descriptor``.
    data, indices, indptr: IR values holding the CSR arrays; checked by
      ``_validate_csr_hlo``.
    B: rank-2 IR value holding the dense right-hand operand.
    shape: ``(rows, cols)`` of the sparse matrix.
    transpose: forwarded to the descriptor builder; also selects ``cols``
      instead of ``rows`` as the number of result rows.
    compute_dtype, compute_type: dtype and IR element type of the result.
      When ``compute_dtype`` is None, both fall back to the matrix's own.
    index_dtype, data_dtype, B_dtype: dtypes forwarded to the descriptor
      builder.

  Returns:
    The first custom-call result, a rank-2 tensor of ``compute_type`` with
    as many columns as ``B``.
  """
  data_type, _, nnz = _validate_csr_hlo(data, indices, indptr, shape)
  rows, cols = shape
  # Two-element unpacking doubles as a rank-2 check on ``B``.
  _, Ccols = ir.RankedTensorType(B.type).shape

  if compute_dtype is None:
    compute_dtype = data_dtype
    compute_type = data_type

  buffer_size, opaque = gpu_sparse.build_csr_matmat_descriptor(
      data_dtype, B_dtype, compute_dtype, index_dtype,
      rows, cols, Ccols, nnz, transpose)
  out_size = cols if transpose else rows

  out = custom_call(
      f"{platform}sparse_csr_matmat_ffi",
      result_types=[
          ir.RankedTensorType.get([out_size, Ccols], compute_type),
          # Trailing i8 buffer (presumably scratch space); dropped below.
          ir.RankedTensorType.get([buffer_size],
                                  ir.IntegerType.get_signless(8)),
      ],
      operands=[data, indices, indptr, B],
      backend_config={"opaque": ir.StringAttr.get(opaque)},
      api_version=4,
      operand_layouts=[[0], [0], [0], [1, 0]],
      result_layouts=[[1, 0], [0]]).results
  return out[0]
|
2022-04-08 08:43:23 -07:00
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
# Platform-specific entry points: bind the target-name prefix and the plugin
# module for CUDA ("cu") and ROCm ("hip").
cuda_csr_matmat = partial(_csr_matmat_hlo, "cu", _cusparse)
rocm_csr_matmat = partial(_csr_matmat_hlo, "hip", _hipsparse)
|
2022-05-06 09:34:25 -07:00
|
|
|
|
2021-04-15 10:10:40 -07:00
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
def _coo_todense_hlo(platform, gpu_sparse, data, row, col, *, shape,
                     data_dtype, index_dtype):
  """COO to dense matrix.

  Args:
    platform: prefix of the custom-call target name (this file binds "cu"
      and "hip").
    gpu_sparse: module providing ``build_coo_todense_descriptor``.
    data, row, col: IR values holding the COO arrays; checked by
      ``_validate_coo_hlo``.
    shape: ``(rows, cols)`` of the dense result.
    data_dtype, index_dtype: dtypes forwarded to the descriptor builder.

  Returns:
    The first custom-call result: a tensor of ``shape`` with the element
    type of ``data``.
  """
  elt_type, _, nnz = _validate_coo_hlo(data, row, col)
  n_rows, n_cols = shape

  workspace_size, descriptor = gpu_sparse.build_coo_todense_descriptor(
      data_dtype, index_dtype, n_rows, n_cols, nnz)

  dense_type = ir.RankedTensorType.get(shape, elt_type)
  # Second result is an i8 buffer of the size reported by the descriptor
  # builder (presumably scratch space); it is not returned to the caller.
  workspace_type = ir.RankedTensorType.get(
      [workspace_size], ir.IntegerType.get_signless(8))

  call = custom_call(
      f"{platform}sparse_coo_todense_ffi",
      result_types=[dense_type, workspace_type],
      operands=[data, row, col],
      backend_config={"opaque": ir.StringAttr.get(descriptor)},
      api_version=4,
      operand_layouts=[[0]] * 3,
      result_layouts=[[1, 0], [0]])
  return call.results[0]
|
2022-04-08 08:43:23 -07:00
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
# Platform-specific entry points: bind the target-name prefix and the plugin
# module for CUDA ("cu") and ROCm ("hip").
cuda_coo_todense = partial(_coo_todense_hlo, "cu", _cusparse)
rocm_coo_todense = partial(_coo_todense_hlo, "hip", _hipsparse)
|
2021-04-15 10:10:40 -07:00
|
|
|
|
2022-05-06 09:34:25 -07:00
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
def _coo_fromdense_hlo(platform, gpu_sparse, mat, *, nnz, data_dtype,
                       index_dtype, index_type):
  """COO from dense matrix.

  Args:
    platform: prefix of the custom-call target name (this file binds "cu"
      and "hip").
    gpu_sparse: module providing ``build_coo_fromdense_descriptor``.
    mat: rank-2 IR value holding the dense matrix.
    nnz: number of entries in each of the three returned arrays.
    data_dtype, index_dtype: dtypes forwarded to the descriptor builder.
    index_type: IR element type of the two index results.

  Returns:
    The first three custom-call results, each of shape ``[nnz]``: one with
    the element type of ``mat`` followed by two of ``index_type``.
  """
  dense_type = ir.RankedTensorType(mat.type)
  n_rows, n_cols = dense_type.shape

  workspace_size, descriptor = gpu_sparse.build_coo_fromdense_descriptor(
      data_dtype, index_dtype, n_rows, n_cols, nnz)

  result_types = [
      ir.RankedTensorType.get([nnz], dense_type.element_type),
      ir.RankedTensorType.get([nnz], index_type),
      ir.RankedTensorType.get([nnz], index_type),
      # Trailing i8 buffer (presumably scratch space); dropped below.
      ir.RankedTensorType.get([workspace_size],
                              ir.IntegerType.get_signless(8)),
  ]
  call = custom_call(
      f"{platform}sparse_coo_fromdense_ffi",
      result_types=result_types,
      operands=[mat],
      backend_config={"opaque": ir.StringAttr.get(descriptor)},
      api_version=4,
      operand_layouts=[[1, 0]],
      result_layouts=[[0]] * 4)
  return call.results[:3]
|
2022-04-08 08:43:23 -07:00
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
# Platform-specific entry points: bind the target-name prefix and the plugin
# module for CUDA ("cu") and ROCm ("hip").
cuda_coo_fromdense = partial(_coo_fromdense_hlo, "cu", _cusparse)
rocm_coo_fromdense = partial(_coo_fromdense_hlo, "hip", _hipsparse)
|
2022-05-06 09:34:25 -07:00
|
|
|
|
2021-04-15 10:10:40 -07:00
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
def _coo_matvec_hlo(platform, gpu_sparse, data, row, col, x, *, shape,
                    transpose=False, compute_dtype=None, compute_type=None,
                    index_dtype, data_dtype, x_dtype):
  """COO matrix/vector multiply.

  Args:
    platform: prefix of the custom-call target name (this file binds "cu"
      and "hip").
    gpu_sparse: module providing ``build_coo_matvec_descriptor``.
    data, row, col: IR values holding the COO arrays; checked by
      ``_validate_coo_hlo``.
    x: IR value holding the vector operand.
    shape: ``(rows, cols)`` of the sparse matrix.
    transpose: forwarded to the descriptor builder; also selects ``cols``
      instead of ``rows`` as the result length.
    compute_dtype, compute_type: dtype and IR element type of the result.
      When ``compute_dtype`` is None, both fall back to the matrix's own.
    index_dtype, data_dtype, x_dtype: dtypes forwarded to the descriptor
      builder.

  Returns:
    The first custom-call result, a rank-1 tensor of ``compute_type``.
  """
  elt_type, _, nnz = _validate_coo_hlo(data, row, col)
  n_rows, n_cols = shape

  if compute_dtype is None:
    compute_dtype, compute_type = data_dtype, elt_type

  workspace_size, descriptor = gpu_sparse.build_coo_matvec_descriptor(
      data_dtype, x_dtype, compute_dtype, index_dtype,
      n_rows, n_cols, nnz, transpose)
  result_len = n_cols if transpose else n_rows

  result_types = [
      ir.RankedTensorType.get([result_len], compute_type),
      # Trailing i8 buffer (presumably scratch space); dropped below.
      ir.RankedTensorType.get([workspace_size],
                              ir.IntegerType.get_signless(8)),
  ]
  call = custom_call(
      f"{platform}sparse_coo_matvec_ffi",
      result_types=result_types,
      operands=[data, row, col, x],
      backend_config={"opaque": ir.StringAttr.get(descriptor)},
      api_version=4,
      operand_layouts=[[0]] * 4,
      result_layouts=[[0]] * 2)
  return call.results[0]
|
2022-04-08 08:43:23 -07:00
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
# Platform-specific entry points: bind the target-name prefix and the plugin
# module for CUDA ("cu") and ROCm ("hip").
cuda_coo_matvec = partial(_coo_matvec_hlo, "cu", _cusparse)
rocm_coo_matvec = partial(_coo_matvec_hlo, "hip", _hipsparse)
|
2022-04-08 08:43:23 -07:00
|
|
|
|
2022-05-06 09:34:25 -07:00
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
def _coo_matmat_hlo(platform, gpu_sparse, data, row, col, B, *, shape,
                    transpose=False, compute_dtype=None, compute_type=None,
                    x_dtype, data_dtype, index_dtype):
  """COO matrix/matrix multiply, optionally batched.

  Args:
    platform: prefix of the custom-call target name (this file binds "cu"
      and "hip").
    gpu_sparse: module providing ``build_coo_matmat_descriptor``.
    data, row, col: IR values holding the COO arrays; checked by
      ``_validate_coo_hlo``.
    B: rank-2 IR value holding the dense right-hand operand.
    shape: ``(rows, cols)`` for a single matrix, or
      ``(batch_count, rows, cols)`` for the batched case.
    transpose: forwarded to the descriptor builder; also selects ``cols``
      instead of ``rows`` as the number of result rows.
    compute_dtype, compute_type: dtype and IR element type of the result.
      When ``compute_dtype`` is None, both fall back to the matrix's own.
    x_dtype, data_dtype, index_dtype: dtypes forwarded to the descriptor
      builder.

  Returns:
    The first custom-call result, of shape ``[out_size, Ccols]`` or, when
    batched, ``[batch_count, out_size, Ccols]``.

  Raises:
    ValueError: if ``shape`` does not have 2 or 3 dimensions.
  """
  # Reject unsupported ranks up front; otherwise ``rows``/``cols`` would be
  # left unbound and the failure would surface later as UnboundLocalError.
  if len(shape) not in (2, 3):
    raise ValueError(
        f"shape must have 2 or 3 dimensions, got {len(shape)}: {shape}")

  data_type, _, nnz = _validate_coo_hlo(data, row, col)
  is_batched_matmat = len(shape) == 3
  if is_batched_matmat:
    batch_count, rows, cols = shape
    # Redefine nnz as nnz per batch.
    nnz = nnz // batch_count
  else:
    batch_count = 1
    rows, cols = shape

  # Two-element unpacking doubles as a rank-2 check on ``B``.
  B_shape = ir.RankedTensorType(B.type).shape
  _, Ccols = B_shape

  if compute_dtype is None:
    compute_dtype = data_dtype
    compute_type = data_type

  # TODO(tianjianlu): use batch stride to trigger different mode of batch
  # computation. Currently batch_stride = 0 is not allowed because of the issue
  # in cusparse https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643
  # Set batch stride to be the matrix size for now.
  lhs_batch_stride = nnz
  B_rows = rows if transpose else cols
  rhs_batch_stride = B_rows * Ccols

  buffer_size, opaque = gpu_sparse.build_coo_matmat_descriptor(
      data_dtype, x_dtype, compute_dtype, index_dtype,
      rows, cols, Ccols, nnz, transpose, batch_count, lhs_batch_stride,
      rhs_batch_stride)
  out_size = cols if transpose else rows

  if is_batched_matmat:
    out_shape = [batch_count, out_size, Ccols]
    out_layout = [2, 1, 0]
  else:
    out_shape = [out_size, Ccols]
    out_layout = [1, 0]

  out = custom_call(
      f"{platform}sparse_coo_matmat_ffi",
      result_types=[
          ir.RankedTensorType.get(out_shape, compute_type),
          # Trailing i8 buffer (presumably scratch space); dropped below.
          ir.RankedTensorType.get([buffer_size],
                                  ir.IntegerType.get_signless(8)),
      ],
      operands=[data, row, col, B],
      backend_config={"opaque": ir.StringAttr.get(opaque)},
      api_version=4,
      operand_layouts=[[0], [0], [0], [1, 0]],
      result_layouts=[out_layout, [0]]).results
  return out[0]
|
2022-04-08 08:43:23 -07:00
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
# Platform-specific entry points: bind the target-name prefix and the plugin
# module for CUDA ("cu") and ROCm ("hip").
cuda_coo_matmat = partial(_coo_matmat_hlo, "cu", _cusparse)
rocm_coo_matmat = partial(_coo_matmat_hlo, "hip", _hipsparse)
|
2022-05-06 09:34:25 -07:00
|
|
|
|
2021-05-26 19:14:37 +00:00
|
|
|
|
2023-08-10 16:25:23 -07:00
|
|
|
def _gtsv2_hlo(
    platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t, b_shape_vals=None):
  """Calls `cusparse<t>gtsv2(dl, d, du, B, m, n, ldb)`.

  Args:
    platform: prefix of the custom-call target name (this file binds "cu"
      and "hip").
    gpu_sparse: module providing ``gtsv2_f32_buffer_size``,
      ``gtsv2_f64_buffer_size`` and ``build_gtsv2_descriptor``.
    dl, d, du: IR values passed as the first three operands.
    B: IR value passed as the fourth operand; aliased to the result.
    m, n, ldb: sizes forwarded to the buffer-size and descriptor builders.
    t: dtype selector; ``np.float32`` picks the f32 kernel, anything else
      the f64 kernel.
    b_shape_vals: dimensions of ``B``, at least two entries; leading entries
      are batch dimensions. Required despite the ``None`` default.

  Returns:
    The first custom-call result, of shape ``batch dims + (ldb, n)`` and the
    element type of ``B``.

  Raises:
    TypeError: if ``b_shape_vals`` is None.
  """
  # The None default exists only for signature compatibility; fail with a
  # message that names the argument instead of an opaque ``len(None)`` error.
  if b_shape_vals is None:
    raise TypeError(
        "_gtsv2_hlo requires b_shape_vals (the dimensions of B); got None")
  assert len(b_shape_vals) >= 2
  # Normalise to a tuple so the concatenation with ``(ldb, n)`` below also
  # works when the caller passes a list.
  batch_dim_vals = tuple(b_shape_vals[:-2])
  batch_size = math.prod(batch_dim_vals)
  num_bd = len(b_shape_vals) - 2
  f32 = (t == np.float32)
  if f32:
    buffer_size = gpu_sparse.gtsv2_f32_buffer_size(m, n, ldb)
  else:
    buffer_size = gpu_sparse.gtsv2_f64_buffer_size(m, n, ldb)

  # Layout tuples list the two matrix dims (resp. the single diagonal dim)
  # first, then the batch dims in reverse order -- presumably minor-to-major
  # order as expected by ``custom_call``.
  b_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  d_layout = (num_bd,) + tuple(range(num_bd - 1, -1, -1))
  b_type = ir.RankedTensorType(B.type)

  shape_type_pairs = [
      (batch_dim_vals + (ldb, n), b_type.element_type),
      ((buffer_size,), ir.IntegerType.get_signless(8))
  ]
  result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
  opaque = gpu_sparse.build_gtsv2_descriptor(batch_size, m, n, ldb)
  out = custom_call(
      f"{platform}sparse_gtsv2_" + ("f32" if f32 else "f64") + "_ffi",
      result_types=result_types,
      operands=[dl, d, du, B],
      backend_config={"opaque": ir.StringAttr.get(opaque)},
      api_version=4,
      operand_layouts=[d_layout] * 3 + [b_layout],
      result_layouts=[b_layout, [0]],
      # Operand 3 (B) is aliased to result 0.
      operand_output_aliases={3: 0},
      result_shapes=result_shapes).results
  return out[0]
|
2022-05-06 09:34:25 -07:00
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
# Platform-specific entry points: bind the target-name prefix and the plugin
# module for CUDA ("cu") and ROCm ("hip").
cuda_gtsv2 = partial(_gtsv2_hlo, "cu", _cusparse)
rocm_gtsv2 = partial(_gtsv2_hlo, "hip", _hipsparse)
|