mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix typo: JAX_CUSPARSE_11030 -> JAX_CUSPARSE_11300
This is a silly typo, but it's been annoying me for months PiperOrigin-RevId: 432078590
This commit is contained in:
parent
d369501417
commit
3403054b33
@ -64,7 +64,7 @@ cudaDataType DtypeToCudaDataType(const py::dtype& np_type) {
|
||||
{{'c', 16}, CUDA_C_64F}, {{'i', 1}, CUDA_R_8I},
|
||||
{{'u', 1}, CUDA_R_8U}, {{'i', 4}, CUDA_R_32I},
|
||||
{{'u', 4}, CUDA_R_32U},
|
||||
#if JAX_CUSPARSE_11030
|
||||
#if JAX_CUSPARSE_11300
|
||||
{{'V', 2}, CUDA_R_16BF},
|
||||
#endif
|
||||
});
|
||||
@ -98,7 +98,7 @@ DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype,
|
||||
return DenseVecDescriptor{value_type, size};
|
||||
}
|
||||
|
||||
#if JAX_CUSPARSE_11030
|
||||
#if JAX_CUSPARSE_11300
|
||||
// CsrToDense: Convert CSR matrix to dense matrix
|
||||
|
||||
// Returns the descriptor for a Sparse matrix.
|
||||
@ -513,7 +513,7 @@ std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
|
||||
return {buffer_size, PackDescriptor(CooMatmatDescriptor{A, B, C, op_A})};
|
||||
}
|
||||
|
||||
#endif // if JAX_CUSPARSE_11030
|
||||
#endif // if JAX_CUSPARSE_11300
|
||||
|
||||
py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) {
|
||||
return PackDescriptor(Gtsv2Descriptor{m, n, ldb});
|
||||
@ -541,7 +541,7 @@ size_t Gtsv2BufferSizeF64(int m, int n, int ldb) {
|
||||
|
||||
py::dict Registrations() {
|
||||
py::dict dict;
|
||||
#if JAX_CUSPARSE_11030
|
||||
#if JAX_CUSPARSE_11300
|
||||
dict["cusparse_csr_todense"] = EncapsulateFunction(CsrToDense);
|
||||
dict["cusparse_csr_fromdense"] = EncapsulateFunction(CsrFromDense);
|
||||
dict["cusparse_csr_matvec"] = EncapsulateFunction(CsrMatvec);
|
||||
@ -558,9 +558,9 @@ py::dict Registrations() {
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_cusparse, m) {
|
||||
m.attr("cusparse_supported") = py::bool_(JAX_CUSPARSE_11030);
|
||||
m.attr("cusparse_supported") = py::bool_(JAX_CUSPARSE_11300);
|
||||
m.def("registrations", &Registrations);
|
||||
#if JAX_CUSPARSE_11030
|
||||
#if JAX_CUSPARSE_11300
|
||||
m.def("build_csr_todense_descriptor", &BuildCsrToDenseDescriptor);
|
||||
m.def("build_csr_fromdense_descriptor", &BuildCsrFromDenseDescriptor);
|
||||
m.def("build_csr_matvec_descriptor", &BuildCsrMatvecDescriptor);
|
||||
|
@ -74,7 +74,7 @@ CudaConst CudaOne(cudaDataType type) {
|
||||
CudaConst c;
|
||||
std::memset(&c, 0, sizeof(c));
|
||||
switch (type) {
|
||||
#if JAX_CUSPARSE_11030
|
||||
#if JAX_CUSPARSE_11300
|
||||
// TODO(jakevdp): 4I/4U here might break on big endian platforms.
|
||||
case CUDA_R_4I:
|
||||
case CUDA_C_4I:
|
||||
@ -83,7 +83,7 @@ CudaConst CudaOne(cudaDataType type) {
|
||||
case CUDA_C_8I:
|
||||
c.i8[0] = 1;
|
||||
break;
|
||||
#if JAX_CUSPARSE_11030
|
||||
#if JAX_CUSPARSE_11300
|
||||
case CUDA_R_4U:
|
||||
case CUDA_C_4U:
|
||||
#endif
|
||||
@ -91,7 +91,7 @@ CudaConst CudaOne(cudaDataType type) {
|
||||
case CUDA_C_8U:
|
||||
c.u8[0] = 1;
|
||||
break;
|
||||
#if JAX_CUSPARSE_11030
|
||||
#if JAX_CUSPARSE_11300
|
||||
case CUDA_R_16I:
|
||||
case CUDA_C_16I:
|
||||
c.i16[0] = 1;
|
||||
@ -109,7 +109,7 @@ CudaConst CudaOne(cudaDataType type) {
|
||||
case CUDA_C_32U:
|
||||
c.u32[0] = 1;
|
||||
break;
|
||||
#if JAX_CUSPARSE_11030
|
||||
#if JAX_CUSPARSE_11300
|
||||
case CUDA_R_64I:
|
||||
case CUDA_C_64I:
|
||||
c.i64[0] = 1;
|
||||
@ -124,7 +124,7 @@ CudaConst CudaOne(cudaDataType type) {
|
||||
case CUDA_C_16F:
|
||||
c.u16[0] = 0b11110000000000; // 1.0 in little-endian float16
|
||||
break;
|
||||
#if JAX_CUSPARSE_11030
|
||||
#if JAX_CUSPARSE_11300
|
||||
case CUDA_R_16BF:
|
||||
case CUDA_C_16BF:
|
||||
c.u16[0] = 0b11111110000000; // 1.0 in little-endian bfloat16
|
||||
@ -142,7 +142,7 @@ CudaConst CudaOne(cudaDataType type) {
|
||||
return c;
|
||||
}
|
||||
|
||||
#if JAX_CUSPARSE_11030
|
||||
#if JAX_CUSPARSE_11300
|
||||
// CsrToDense: Convert CSR matrix to dense matrix
|
||||
|
||||
static absl::Status CsrToDense_(cudaStream_t stream, void** buffers,
|
||||
@ -536,7 +536,7 @@ void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
#endif // if JAX_CUSPARSE_11030
|
||||
#endif // if JAX_CUSPARSE_11300
|
||||
|
||||
template <typename T, typename F>
|
||||
static absl::Status gtsv2(F computeGtsv2, cudaStream_t stream, void** buffers,
|
||||
|
@ -31,7 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
// Some functionality defined here is only available in CUSPARSE 11.3 or newer.
|
||||
#define JAX_CUSPARSE_11030 (CUSPARSE_VERSION >= 11300)
|
||||
#define JAX_CUSPARSE_11300 (CUSPARSE_VERSION >= 11300)
|
||||
|
||||
namespace jax {
|
||||
|
||||
@ -73,7 +73,7 @@ struct DenseVecDescriptor {
|
||||
int size;
|
||||
};
|
||||
|
||||
#if JAX_CUSPARSE_11030
|
||||
#if JAX_CUSPARSE_11300
|
||||
// CsrToDense: Convert CSR matrix to dense matrix
|
||||
|
||||
void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
@ -137,7 +137,7 @@ struct CooMatmatDescriptor {
|
||||
|
||||
void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
#endif // if JAX_CUSPARSE_11030
|
||||
#endif // if JAX_CUSPARSE_11300
|
||||
|
||||
struct Gtsv2Descriptor {
|
||||
int m, n, ldb;
|
||||
|
@ -43,7 +43,7 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevj", Syevj, "CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvd", Gesvd, "CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvdj", Gesvdj, "CUDA");
|
||||
|
||||
#if JAX_CUSPARSE_11030
|
||||
#if JAX_CUSPARSE_11300
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_todense", CsrToDense,
|
||||
"CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_fromdense", CsrFromDense,
|
||||
|
Loading…
x
Reference in New Issue
Block a user