mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #14444 from jakevdp:fix-csr-lowering
PiperOrigin-RevId: 509241948
This commit is contained in:
commit
83b7ba2ba0
@ -637,8 +637,10 @@ def _bcsr_dot_general_gpu_lowering(
|
||||
|
||||
if rhs_aval.ndim == 1:
|
||||
dot_general_fn = csr_matvec_lowering
|
||||
x_dtype = 'x_dtype'
|
||||
elif rhs_aval.ndim == 2:
|
||||
dot_general_fn = csr_matmat_lowering
|
||||
x_dtype = 'B_dtype'
|
||||
if rhs_contract[0] == 1:
|
||||
rhs = hlo.TransposeOp(
|
||||
rhs, permutation=mlir.dense_int_elements([1, 0])).result
|
||||
@ -649,7 +651,7 @@ def _bcsr_dot_general_gpu_lowering(
|
||||
shape=lhs_spinfo.shape, transpose=False,
|
||||
data_dtype=lhs_data_aval.dtype,
|
||||
index_dtype=lhs_indices_aval.dtype,
|
||||
x_dtype=rhs_aval.dtype)]
|
||||
**{x_dtype: rhs_aval.dtype})]
|
||||
|
||||
_bcsr_dot_general_default_lowering = mlir.lower_fun(
|
||||
_bcsr_dot_general_impl, multiple_results=False)
|
||||
|
@ -516,7 +516,7 @@ def _csr_matmat_gpu_lowering(csr_matmat_hlo, ctx, data, indices, indptr, B, *,
|
||||
return [csr_matmat_hlo(
|
||||
data, indices, indptr, B, shape=shape, transpose=transpose,
|
||||
index_dtype=indices_aval.dtype, data_dtype=data_aval.dtype,
|
||||
x_dtype=B_aval.dtype)]
|
||||
B_dtype=B_aval.dtype)]
|
||||
|
||||
|
||||
def _csr_matmat_jvp_left(data_dot, data, indices, indptr, B, *, shape, transpose):
|
||||
|
@ -158,7 +158,7 @@ rocm_csr_matvec = partial(_csr_matvec_hlo, "hip", _hipsparse)
|
||||
|
||||
def _csr_matmat_hlo(platform, gpu_sparse, data, indices, indptr, B, *, shape,
|
||||
transpose=False, compute_dtype=None, compute_type=None,
|
||||
index_dtype, data_dtype, x_dtype):
|
||||
index_dtype, data_dtype, B_dtype):
|
||||
"""CSR from dense matrix."""
|
||||
data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape)
|
||||
rows, cols = shape
|
||||
@ -170,7 +170,7 @@ def _csr_matmat_hlo(platform, gpu_sparse, data, indices, indptr, B, *, shape,
|
||||
compute_type = data_type
|
||||
|
||||
buffer_size, opaque = gpu_sparse.build_csr_matmat_descriptor(
|
||||
data_dtype, x_dtype, compute_dtype, index_dtype,
|
||||
data_dtype, B_dtype, compute_dtype, index_dtype,
|
||||
rows, cols, Ccols, nnz, transpose)
|
||||
out_size = cols if transpose else rows
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user