mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[ROCm]: Lower sparse(some) ops correctly for ROCm
-Lower coo_spmv, coo_spmm, csr_spmv and csr_spmm correctly for ROCm
This commit is contained in:
parent
0017b3a6d3
commit
ebe7fb1765
@ -17,6 +17,8 @@ In general, these primitives are not meant to be used directly, but rather
|
||||
are used internally in GPU translation rules of higher-level primitives.
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
|
||||
from jax import core
|
||||
from jax._src import dispatch
|
||||
from jax._src.interpreters import mlir
|
||||
@ -52,9 +54,9 @@ def _coo_spmv_abstract_eval(data, row, col, x, *, transpose, shape):
|
||||
shape=shape[1:] if transpose else shape[:1],
|
||||
dtype=x.dtype)
|
||||
|
||||
def _coo_spmv_gpu_lowering(ctx, data, row, col, x, *, transpose, shape):
|
||||
def _coo_spmv_gpu_lowering(coo_spmv_hlo, ctx, data, row, col, x, *, transpose, shape):
|
||||
data_aval, row_aval, _, x_aval = ctx.avals_in
|
||||
return [gpu_sparse.cuda_coo_matvec(
|
||||
return [coo_spmv_hlo(
|
||||
data, row, col, x,
|
||||
shape=shape,
|
||||
transpose=transpose,
|
||||
@ -65,9 +67,15 @@ def _coo_spmv_gpu_lowering(ctx, data, row, col, x, *, transpose, shape):
|
||||
coo_spmv_p.def_abstract_eval(_coo_spmv_abstract_eval)
|
||||
dispatch.simple_impl(coo_spmv_p)
|
||||
if gpu_sparse.cuda_is_supported:
|
||||
mlir.register_lowering(coo_spmv_p, _coo_spmv_gpu_lowering, platform='cuda')
|
||||
mlir.register_lowering(
|
||||
coo_spmv_p,
|
||||
partial(_coo_spmv_gpu_lowering, gpu_sparse.cuda_coo_matvec),
|
||||
platform='cuda')
|
||||
if gpu_sparse.rocm_is_supported:
|
||||
mlir.register_lowering(coo_spmv_p, _coo_spmv_gpu_lowering, platform='rocm')
|
||||
mlir.register_lowering(
|
||||
coo_spmv_p,
|
||||
partial(_coo_spmv_gpu_lowering, gpu_sparse.rocm_coo_matvec),
|
||||
platform='rocm')
|
||||
|
||||
|
||||
# coo_spmm_p
|
||||
@ -95,9 +103,9 @@ def _coo_spmm_abstract_eval(data, row, col, x, *, transpose, shape):
|
||||
shape=(shape[1] if transpose else shape[0], x.shape[1]),
|
||||
dtype=x.dtype)
|
||||
|
||||
def _coo_spmm_gpu_lowering(ctx, data, row, col, x, *, transpose, shape):
|
||||
def _coo_spmm_gpu_lowering(coo_spmm_hlo, ctx, data, row, col, x, *, transpose, shape):
|
||||
data_aval, row_aval, _, x_aval = ctx.avals_in
|
||||
return [gpu_sparse.cuda_coo_matmat(
|
||||
return [coo_spmm_hlo(
|
||||
data, row, col, x,
|
||||
shape=shape,
|
||||
transpose=transpose,
|
||||
@ -108,9 +116,15 @@ def _coo_spmm_gpu_lowering(ctx, data, row, col, x, *, transpose, shape):
|
||||
coo_spmm_p.def_abstract_eval(_coo_spmm_abstract_eval)
|
||||
dispatch.simple_impl(coo_spmm_p)
|
||||
if gpu_sparse.cuda_is_supported:
|
||||
mlir.register_lowering(coo_spmm_p, _coo_spmm_gpu_lowering, platform='cuda')
|
||||
mlir.register_lowering(
|
||||
coo_spmm_p,
|
||||
partial(_coo_spmm_gpu_lowering, gpu_sparse.cuda_coo_matmat),
|
||||
platform='cuda')
|
||||
if gpu_sparse.rocm_is_supported:
|
||||
mlir.register_lowering(coo_spmm_p, _coo_spmm_gpu_lowering, platform='rocm')
|
||||
mlir.register_lowering(
|
||||
coo_spmm_p,
|
||||
partial(_coo_spmm_gpu_lowering, gpu_sparse.rocm_coo_matmat),
|
||||
platform='rocm')
|
||||
|
||||
# csr_spmv_p
|
||||
# This is an internal-only primitive that calls into cusparse csr SpMV.
|
||||
@ -137,9 +151,9 @@ def _csr_spmv_abstract_eval(data, indices, indptr, x, *, transpose, shape):
|
||||
shape=shape[1:] if transpose else shape[:1],
|
||||
dtype=x.dtype)
|
||||
|
||||
def _csr_spmv_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape):
|
||||
def _csr_spmv_gpu_lowering(csr_spmv_hlo, ctx, data, indices, indptr, x, *, transpose, shape):
|
||||
data_aval, indices_aval, _, x_aval = ctx.avals_in
|
||||
return [gpu_sparse.cuda_csr_matvec(
|
||||
return [csr_spmv_hlo(
|
||||
data, indices, indptr, x,
|
||||
shape=shape,
|
||||
transpose=transpose,
|
||||
@ -150,9 +164,15 @@ def _csr_spmv_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape):
|
||||
csr_spmv_p.def_abstract_eval(_csr_spmv_abstract_eval)
|
||||
dispatch.simple_impl(csr_spmv_p)
|
||||
if gpu_sparse.cuda_is_supported:
|
||||
mlir.register_lowering(csr_spmv_p, _csr_spmv_gpu_lowering, platform='cuda')
|
||||
mlir.register_lowering(
|
||||
csr_spmv_p,
|
||||
partial(_csr_spmv_gpu_lowering, gpu_sparse.cuda_csr_matvec),
|
||||
platform='cuda')
|
||||
if gpu_sparse.rocm_is_supported:
|
||||
mlir.register_lowering(csr_spmv_p, _csr_spmv_gpu_lowering, platform='rocm')
|
||||
mlir.register_lowering(
|
||||
csr_spmv_p,
|
||||
partial(_csr_spmv_gpu_lowering, gpu_sparse.rocm_csr_matvec),
|
||||
platform='rocm')
|
||||
|
||||
|
||||
# csr_spmm_p
|
||||
@ -180,9 +200,9 @@ def _csr_spmm_abstract_eval(data, indices, indptr, x, *, transpose, shape):
|
||||
shape=(shape[1] if transpose else shape[0], x.shape[1]),
|
||||
dtype=x.dtype)
|
||||
|
||||
def _csr_spmm_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape):
|
||||
def _csr_spmm_gpu_lowering(csr_spmm_hlo, ctx, data, indices, indptr, x, *, transpose, shape):
|
||||
data_aval, indices_aval, _, x_aval = ctx.avals_in
|
||||
return [gpu_sparse.cuda_csr_matmat(
|
||||
return [csr_spmm_hlo(
|
||||
data, indices, indptr, x,
|
||||
shape=shape,
|
||||
transpose=transpose,
|
||||
@ -193,6 +213,12 @@ def _csr_spmm_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape):
|
||||
csr_spmm_p.def_abstract_eval(_csr_spmm_abstract_eval)
|
||||
dispatch.simple_impl(csr_spmm_p)
|
||||
if gpu_sparse.cuda_is_supported:
|
||||
mlir.register_lowering(csr_spmm_p, _csr_spmm_gpu_lowering, platform='cuda')
|
||||
mlir.register_lowering(
|
||||
csr_spmm_p,
|
||||
partial(_csr_spmm_gpu_lowering, gpu_sparse.cuda_csr_matmat),
|
||||
platform='cuda')
|
||||
if gpu_sparse.rocm_is_supported:
|
||||
mlir.register_lowering(csr_spmm_p, _csr_spmm_gpu_lowering, platform='rocm')
|
||||
mlir.register_lowering(
|
||||
csr_spmm_p,
|
||||
partial(_csr_spmm_gpu_lowering, gpu_sparse.rocm_csr_matmat),
|
||||
platform='rocm')
|
||||
|
Loading…
x
Reference in New Issue
Block a user