mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[ROCm] Add rocm support in nm.py lowering code
This commit is contained in:
parent
9e1a4e8d74
commit
d2ab42de3d
@ -181,6 +181,9 @@ dispatch.simple_impl(nm_spmm_p)
|
||||
if gpu_sparse.cuda_is_supported:
|
||||
mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="cuda")
|
||||
|
||||
if gpu_sparse.rocm_is_supported:
|
||||
mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="ROCM")
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# nm_pack
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user