Remove code that existed only to support older CUDA and CUSPARSE versions.

The minimum CUDA version supported by JAX is CUDA 11.8, which ships with CUSPARSE 11.7.5.

PiperOrigin-RevId: 616892230
This commit is contained in:
Peter Hawkins 2024-03-18 11:24:14 -07:00 committed by jax authors
parent 0308aad076
commit c2bbf9c577

View File

@ -26,32 +26,21 @@ limitations under the License.
#include "third_party/gpus/cuda/include/cuComplex.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cublas_v2.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cuda.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cuda_fp8.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cuda_runtime_api.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cufft.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cusolverDn.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cusparse.h" // IWYU pragma: export
#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: export
// Some sparse functionality is only available in CUSPARSE 11.3 or newer.
#define JAX_GPU_HAVE_SPARSE (CUSPARSE_VERSION >= 11300)
#if CUDA_VERSION < 11080
#error "JAX requires CUDA 11.8 or newer."
#endif // CUDA_VERSION < 11080
#define JAX_GPU_HAVE_SPARSE 1
// CUDA-11.8 introduces FP8 E4M3/E5M2 types.
#define JAX_GPU_HAVE_FP8 (CUDA_VERSION >= 11080)
#if JAX_GPU_HAVE_FP8
#include "third_party/gpus/cuda/include/cuda_fp8.h"
#endif
// The cuSPARSE generic APIs are not supported on Windows until 11.0.
// cusparseIndexType_t is used only in a very limited scope, so defining it
// manually here works around the resulting compile error without harm.
#if defined(_WIN32) && (CUSPARSE_VERSION < 11000)
// Stand-in definition of cusparseIndexType_t, compiled only on Windows when
// CUSPARSE_VERSION is below 11000.
// NOTE(review): the enumerator values presumably mirror those declared in
// newer cusparse.h headers -- confirm before relying on them.
typedef enum {
CUSPARSE_INDEX_16U = 1,
CUSPARSE_INDEX_32I = 2,
CUSPARSE_INDEX_64I = 3
} cusparseIndexType_t;
#endif
#define JAX_GPU_HAVE_FP8 1
#define JAX_GPU_NAMESPACE cuda
#define JAX_GPU_PREFIX "cu"
@ -232,7 +221,6 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
// provide deterministic (bit-wise) results for each run. These indexing modes
// are fully supported (both row- and column-major inputs) in CUSPARSE 11.7.1
// and newer (which was released as part of CUDA 11.8)
#if CUSPARSE_VERSION > 11700
#define GPUSPARSE_SPMV_COO_ALG CUSPARSE_SPMV_COO_ALG2
#define GPUSPARSE_SPMV_CSR_ALG CUSPARSE_SPMV_CSR_ALG2
#define GPUSPARSE_SPMM_COO_ALG CUSPARSE_SPMM_COO_ALG2
@ -242,12 +230,7 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
// does not cover all cases and silently fell back to other algorithms for cases
// it did not cover. CUDA 12.2.1 removed the fallback behavior.
#define GPUSPARSE_SPMM_CSR_ALG CUSPARSE_SPMM_ALG_DEFAULT
#else
#define GPUSPARSE_SPMV_COO_ALG CUSPARSE_MV_ALG_DEFAULT
#define GPUSPARSE_SPMV_CSR_ALG CUSPARSE_MV_ALG_DEFAULT
#define GPUSPARSE_SPMM_COO_ALG CUSPARSE_SPMM_ALG_DEFAULT
#define GPUSPARSE_SPMM_CSR_ALG CUSPARSE_SPMM_ALG_DEFAULT
#endif
#define GPUSPARSE_OPERATION_NON_TRANSPOSE CUSPARSE_OPERATION_NON_TRANSPOSE
#define GPUSPARSE_OPERATION_TRANSPOSE CUSPARSE_OPERATION_TRANSPOSE
#define GPUSPARSE_ORDER_ROW CUSPARSE_ORDER_ROW
@ -542,7 +525,7 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
#define gpuThreadExchangeStreamCaptureMode hipThreadExchangeStreamCaptureMode
#define gpuStreamCreate hipStreamCreateWithFlags
#define gpuStreamDestroy hipStreamDestroy
#define gpuStreamIsCapturing hipStreamIsCapturing
#define gpuStreamIsCapturing hipStreamIsCapturing
#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR \
hipDeviceAttributeComputeCapabilityMajor