[ROCm]: Lower sparse(some) ops correctly for ROCm

-Lower coo_spmv, coo_spmm, csr_spmv and csr_spmm
	correctly for ROCm
This commit is contained in:
Rahul Batra 2023-12-13 22:08:36 +00:00
parent 0017b3a6d3
commit ebe7fb1765

View File

@ -17,6 +17,8 @@ In general, these primitives are not meant to be used directly, but rather
are used internally in GPU translation rules of higher-level primitives.
"""
from functools import partial
from jax import core
from jax._src import dispatch
from jax._src.interpreters import mlir
@ -52,9 +54,9 @@ def _coo_spmv_abstract_eval(data, row, col, x, *, transpose, shape):
shape=shape[1:] if transpose else shape[:1],
dtype=x.dtype)
def _coo_spmv_gpu_lowering(ctx, data, row, col, x, *, transpose, shape):
def _coo_spmv_gpu_lowering(coo_spmv_hlo, ctx, data, row, col, x, *, transpose, shape):
data_aval, row_aval, _, x_aval = ctx.avals_in
return [gpu_sparse.cuda_coo_matvec(
return [coo_spmv_hlo(
data, row, col, x,
shape=shape,
transpose=transpose,
@ -65,9 +67,15 @@ def _coo_spmv_gpu_lowering(ctx, data, row, col, x, *, transpose, shape):
coo_spmv_p.def_abstract_eval(_coo_spmv_abstract_eval)
dispatch.simple_impl(coo_spmv_p)
if gpu_sparse.cuda_is_supported:
mlir.register_lowering(coo_spmv_p, _coo_spmv_gpu_lowering, platform='cuda')
mlir.register_lowering(
coo_spmv_p,
partial(_coo_spmv_gpu_lowering, gpu_sparse.cuda_coo_matvec),
platform='cuda')
if gpu_sparse.rocm_is_supported:
mlir.register_lowering(coo_spmv_p, _coo_spmv_gpu_lowering, platform='rocm')
mlir.register_lowering(
coo_spmv_p,
partial(_coo_spmv_gpu_lowering, gpu_sparse.rocm_coo_matvec),
platform='rocm')
# coo_spmm_p
@ -95,9 +103,9 @@ def _coo_spmm_abstract_eval(data, row, col, x, *, transpose, shape):
shape=(shape[1] if transpose else shape[0], x.shape[1]),
dtype=x.dtype)
def _coo_spmm_gpu_lowering(ctx, data, row, col, x, *, transpose, shape):
def _coo_spmm_gpu_lowering(coo_spmm_hlo, ctx, data, row, col, x, *, transpose, shape):
data_aval, row_aval, _, x_aval = ctx.avals_in
return [gpu_sparse.cuda_coo_matmat(
return [coo_spmm_hlo(
data, row, col, x,
shape=shape,
transpose=transpose,
@ -108,9 +116,15 @@ def _coo_spmm_gpu_lowering(ctx, data, row, col, x, *, transpose, shape):
coo_spmm_p.def_abstract_eval(_coo_spmm_abstract_eval)
dispatch.simple_impl(coo_spmm_p)
if gpu_sparse.cuda_is_supported:
mlir.register_lowering(coo_spmm_p, _coo_spmm_gpu_lowering, platform='cuda')
mlir.register_lowering(
coo_spmm_p,
partial(_coo_spmm_gpu_lowering, gpu_sparse.cuda_coo_matmat),
platform='cuda')
if gpu_sparse.rocm_is_supported:
mlir.register_lowering(coo_spmm_p, _coo_spmm_gpu_lowering, platform='rocm')
mlir.register_lowering(
coo_spmm_p,
partial(_coo_spmm_gpu_lowering, gpu_sparse.rocm_coo_matmat),
platform='rocm')
# csr_spmv_p
# This is an internal-only primitive that calls into cusparse csr SpMV.
@ -137,9 +151,9 @@ def _csr_spmv_abstract_eval(data, indices, indptr, x, *, transpose, shape):
shape=shape[1:] if transpose else shape[:1],
dtype=x.dtype)
def _csr_spmv_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape):
def _csr_spmv_gpu_lowering(csr_spmv_hlo, ctx, data, indices, indptr, x, *, transpose, shape):
data_aval, indices_aval, _, x_aval = ctx.avals_in
return [gpu_sparse.cuda_csr_matvec(
return [csr_spmv_hlo(
data, indices, indptr, x,
shape=shape,
transpose=transpose,
@ -150,9 +164,15 @@ def _csr_spmv_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape):
csr_spmv_p.def_abstract_eval(_csr_spmv_abstract_eval)
dispatch.simple_impl(csr_spmv_p)
if gpu_sparse.cuda_is_supported:
mlir.register_lowering(csr_spmv_p, _csr_spmv_gpu_lowering, platform='cuda')
mlir.register_lowering(
csr_spmv_p,
partial(_csr_spmv_gpu_lowering, gpu_sparse.cuda_csr_matvec),
platform='cuda')
if gpu_sparse.rocm_is_supported:
mlir.register_lowering(csr_spmv_p, _csr_spmv_gpu_lowering, platform='rocm')
mlir.register_lowering(
csr_spmv_p,
partial(_csr_spmv_gpu_lowering, gpu_sparse.rocm_csr_matvec),
platform='rocm')
# csr_spmm_p
@ -180,9 +200,9 @@ def _csr_spmm_abstract_eval(data, indices, indptr, x, *, transpose, shape):
shape=(shape[1] if transpose else shape[0], x.shape[1]),
dtype=x.dtype)
def _csr_spmm_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape):
def _csr_spmm_gpu_lowering(csr_spmm_hlo, ctx, data, indices, indptr, x, *, transpose, shape):
data_aval, indices_aval, _, x_aval = ctx.avals_in
return [gpu_sparse.cuda_csr_matmat(
return [csr_spmm_hlo(
data, indices, indptr, x,
shape=shape,
transpose=transpose,
@ -193,6 +213,12 @@ def _csr_spmm_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape):
csr_spmm_p.def_abstract_eval(_csr_spmm_abstract_eval)
dispatch.simple_impl(csr_spmm_p)
if gpu_sparse.cuda_is_supported:
mlir.register_lowering(csr_spmm_p, _csr_spmm_gpu_lowering, platform='cuda')
mlir.register_lowering(
csr_spmm_p,
partial(_csr_spmm_gpu_lowering, gpu_sparse.cuda_csr_matmat),
platform='cuda')
if gpu_sparse.rocm_is_supported:
mlir.register_lowering(csr_spmm_p, _csr_spmm_gpu_lowering, platform='rocm')
mlir.register_lowering(
csr_spmm_p,
partial(_csr_spmm_gpu_lowering, gpu_sparse.rocm_csr_matmat),
platform='rocm')