Merge pull request #14444 from jakevdp:fix-csr-lowering

PiperOrigin-RevId: 509241948
This commit is contained in:
jax authors 2023-02-13 08:59:37 -08:00
commit 83b7ba2ba0
3 changed files with 6 additions and 4 deletions

View File

@ -637,8 +637,10 @@ def _bcsr_dot_general_gpu_lowering(
if rhs_aval.ndim == 1:
dot_general_fn = csr_matvec_lowering
x_dtype = 'x_dtype'
elif rhs_aval.ndim == 2:
dot_general_fn = csr_matmat_lowering
x_dtype = 'B_dtype'
if rhs_contract[0] == 1:
rhs = hlo.TransposeOp(
rhs, permutation=mlir.dense_int_elements([1, 0])).result
@ -649,7 +651,7 @@ def _bcsr_dot_general_gpu_lowering(
shape=lhs_spinfo.shape, transpose=False,
data_dtype=lhs_data_aval.dtype,
index_dtype=lhs_indices_aval.dtype,
x_dtype=rhs_aval.dtype)]
**{x_dtype: rhs_aval.dtype})]
_bcsr_dot_general_default_lowering = mlir.lower_fun(
_bcsr_dot_general_impl, multiple_results=False)

View File

@ -516,7 +516,7 @@ def _csr_matmat_gpu_lowering(csr_matmat_hlo, ctx, data, indices, indptr, B, *,
return [csr_matmat_hlo(
data, indices, indptr, B, shape=shape, transpose=transpose,
index_dtype=indices_aval.dtype, data_dtype=data_aval.dtype,
x_dtype=B_aval.dtype)]
B_dtype=B_aval.dtype)]
def _csr_matmat_jvp_left(data_dot, data, indices, indptr, B, *, shape, transpose):

View File

@ -158,7 +158,7 @@ rocm_csr_matvec = partial(_csr_matvec_hlo, "hip", _hipsparse)
def _csr_matmat_hlo(platform, gpu_sparse, data, indices, indptr, B, *, shape,
transpose=False, compute_dtype=None, compute_type=None,
index_dtype, data_dtype, x_dtype):
index_dtype, data_dtype, B_dtype):
"""CSR from dense matrix."""
data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape)
rows, cols = shape
@ -170,7 +170,7 @@ def _csr_matmat_hlo(platform, gpu_sparse, data, indices, indptr, B, *, shape,
compute_type = data_type
buffer_size, opaque = gpu_sparse.build_csr_matmat_descriptor(
data_dtype, x_dtype, compute_dtype, index_dtype,
data_dtype, B_dtype, compute_dtype, index_dtype,
rows, cols, Ccols, nnz, transpose)
out_size = cols if transpose else rows