Fix typo: JAX_CUSPARSE_11030 -> JAX_CUSPARSE_11300

This is a silly typo, but it's been annoying me for months

PiperOrigin-RevId: 432078590
This commit is contained in:
Jake VanderPlas 2022-03-02 18:31:49 -08:00 committed by jax authors
parent d369501417
commit 3403054b33
4 changed files with 17 additions and 17 deletions

View File

@ -64,7 +64,7 @@ cudaDataType DtypeToCudaDataType(const py::dtype& np_type) {
{{'c', 16}, CUDA_C_64F}, {{'i', 1}, CUDA_R_8I},
{{'u', 1}, CUDA_R_8U}, {{'i', 4}, CUDA_R_32I},
{{'u', 4}, CUDA_R_32U},
#if JAX_CUSPARSE_11030
#if JAX_CUSPARSE_11300
{{'V', 2}, CUDA_R_16BF},
#endif
});
@ -98,7 +98,7 @@ DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype,
return DenseVecDescriptor{value_type, size};
}
#if JAX_CUSPARSE_11030
#if JAX_CUSPARSE_11300
// CsrToDense: Convert CSR matrix to dense matrix
// Returns the descriptor for a Sparse matrix.
@ -513,7 +513,7 @@ std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
return {buffer_size, PackDescriptor(CooMatmatDescriptor{A, B, C, op_A})};
}
#endif // if JAX_CUSPARSE_11030
#endif // if JAX_CUSPARSE_11300
py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) {
return PackDescriptor(Gtsv2Descriptor{m, n, ldb});
@ -541,7 +541,7 @@ size_t Gtsv2BufferSizeF64(int m, int n, int ldb) {
py::dict Registrations() {
py::dict dict;
#if JAX_CUSPARSE_11030
#if JAX_CUSPARSE_11300
dict["cusparse_csr_todense"] = EncapsulateFunction(CsrToDense);
dict["cusparse_csr_fromdense"] = EncapsulateFunction(CsrFromDense);
dict["cusparse_csr_matvec"] = EncapsulateFunction(CsrMatvec);
@ -558,9 +558,9 @@ py::dict Registrations() {
}
PYBIND11_MODULE(_cusparse, m) {
m.attr("cusparse_supported") = py::bool_(JAX_CUSPARSE_11030);
m.attr("cusparse_supported") = py::bool_(JAX_CUSPARSE_11300);
m.def("registrations", &Registrations);
#if JAX_CUSPARSE_11030
#if JAX_CUSPARSE_11300
m.def("build_csr_todense_descriptor", &BuildCsrToDenseDescriptor);
m.def("build_csr_fromdense_descriptor", &BuildCsrFromDenseDescriptor);
m.def("build_csr_matvec_descriptor", &BuildCsrMatvecDescriptor);

View File

@ -74,7 +74,7 @@ CudaConst CudaOne(cudaDataType type) {
CudaConst c;
std::memset(&c, 0, sizeof(c));
switch (type) {
#if JAX_CUSPARSE_11030
#if JAX_CUSPARSE_11300
// TODO(jakevdp): 4I/4U here might break on big endian platforms.
case CUDA_R_4I:
case CUDA_C_4I:
@ -83,7 +83,7 @@ CudaConst CudaOne(cudaDataType type) {
case CUDA_C_8I:
c.i8[0] = 1;
break;
#if JAX_CUSPARSE_11030
#if JAX_CUSPARSE_11300
case CUDA_R_4U:
case CUDA_C_4U:
#endif
@ -91,7 +91,7 @@ CudaConst CudaOne(cudaDataType type) {
case CUDA_C_8U:
c.u8[0] = 1;
break;
#if JAX_CUSPARSE_11030
#if JAX_CUSPARSE_11300
case CUDA_R_16I:
case CUDA_C_16I:
c.i16[0] = 1;
@ -109,7 +109,7 @@ CudaConst CudaOne(cudaDataType type) {
case CUDA_C_32U:
c.u32[0] = 1;
break;
#if JAX_CUSPARSE_11030
#if JAX_CUSPARSE_11300
case CUDA_R_64I:
case CUDA_C_64I:
c.i64[0] = 1;
@ -124,7 +124,7 @@ CudaConst CudaOne(cudaDataType type) {
case CUDA_C_16F:
c.u16[0] = 0b11110000000000; // 1.0 in little-endian float16
break;
#if JAX_CUSPARSE_11030
#if JAX_CUSPARSE_11300
case CUDA_R_16BF:
case CUDA_C_16BF:
c.u16[0] = 0b11111110000000; // 1.0 in little-endian bfloat16
@ -142,7 +142,7 @@ CudaConst CudaOne(cudaDataType type) {
return c;
}
#if JAX_CUSPARSE_11030
#if JAX_CUSPARSE_11300
// CsrToDense: Convert CSR matrix to dense matrix
static absl::Status CsrToDense_(cudaStream_t stream, void** buffers,
@ -536,7 +536,7 @@ void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
s.message().length());
}
}
#endif // if JAX_CUSPARSE_11030
#endif // if JAX_CUSPARSE_11300
template <typename T, typename F>
static absl::Status gtsv2(F computeGtsv2, cudaStream_t stream, void** buffers,

View File

@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/custom_call_status.h"
// Some functionality defined here is only available in CUSPARSE 11.3 or newer.
#define JAX_CUSPARSE_11030 (CUSPARSE_VERSION >= 11300)
#define JAX_CUSPARSE_11300 (CUSPARSE_VERSION >= 11300)
namespace jax {
@ -73,7 +73,7 @@ struct DenseVecDescriptor {
int size;
};
#if JAX_CUSPARSE_11030
#if JAX_CUSPARSE_11300
// CsrToDense: Convert CSR matrix to dense matrix
void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque,
@ -137,7 +137,7 @@ struct CooMatmatDescriptor {
void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
#endif // if JAX_CUSPARSE_11030
#endif // if JAX_CUSPARSE_11300
struct Gtsv2Descriptor {
int m, n, ldb;

View File

@ -43,7 +43,7 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevj", Syevj, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvd", Gesvd, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvdj", Gesvdj, "CUDA");
#if JAX_CUSPARSE_11030
#if JAX_CUSPARSE_11300
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_todense", CsrToDense,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_fromdense", CsrFromDense,