[ROCm] Add rocm support in nm.py lowering code

This commit is contained in:
Ruturaj4 2024-06-05 20:40:35 +00:00
parent 9e1a4e8d74
commit d2ab42de3d

View File

@ -181,6 +181,9 @@ dispatch.simple_impl(nm_spmm_p)
if gpu_sparse.cuda_is_supported:
mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="cuda")
if gpu_sparse.rocm_is_supported:
mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="ROCM")
# --------------------------------------------------------------------
# nm_pack