Align dummy pointers passed to cusparse to 16 bytes

Fixes alignment errors from Cusparse 12.2.

PiperOrigin-RevId: 560793586
This commit is contained in:
Peter Hawkins 2023-08-28 12:55:31 -07:00 committed by jax authors
parent b09bef7793
commit 34010a9e4a

View File

@ -122,7 +122,7 @@ std::pair<size_t, nb::bytes> BuildCsrToDenseDescriptor(
// buffer_size does not reference these pointers, but does error on NULL.
// TODO(jakevdp): check whether this is documented.
int val = 0;
int val alignas(16) = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
@ -198,7 +198,7 @@ std::pair<size_t, nb::bytes> BuildCsrFromDenseDescriptor(
gpusparseSpMatDescr_t mat_b = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
int val alignas(16) = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
&mat_a, d.rows, d.cols,
@ -282,7 +282,7 @@ std::pair<size_t, nb::bytes> BuildCsrMatvecDescriptor(
: GPUSPARSE_OPERATION_NON_TRANSPOSE;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
int val alignas(16) = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
@ -332,7 +332,7 @@ std::pair<size_t, nb::bytes> BuildCsrMatmatDescriptor(
gpusparseDnMatDescr_t mat_c = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
int val alignas(16) = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
@ -374,7 +374,7 @@ std::pair<size_t, nb::bytes> BuildCooToDenseDescriptor(
gpusparseDnMatDescr_t mat_b = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
int val alignas(16) = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
@ -411,7 +411,7 @@ std::pair<size_t, nb::bytes> BuildCooFromDenseDescriptor(
gpusparseSpMatDescr_t mat_b = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
int val alignas(16) = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
&mat_a, d.rows, d.cols,
@ -455,7 +455,7 @@ std::pair<size_t, nb::bytes> BuildCooMatvecDescriptor(
: GPUSPARSE_OPERATION_NON_TRANSPOSE;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
int val alignas(16) = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
@ -517,7 +517,7 @@ std::pair<size_t, nb::bytes> BuildCooMatmatDescriptor(
gpusparseDnMatDescr_t mat_c = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
int val alignas(16) = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,