mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Remove some code to support older CUDA and CUSPARSE versions.
The minimum CUDA version supported by JAX is CUDA 11.8, which ships with CUSPARSE 11.7.5. PiperOrigin-RevId: 616892230
This commit is contained in:
parent
0308aad076
commit
c2bbf9c577
@ -26,32 +26,21 @@ limitations under the License.
|
||||
#include "third_party/gpus/cuda/include/cuComplex.h" // IWYU pragma: export
|
||||
#include "third_party/gpus/cuda/include/cublas_v2.h" // IWYU pragma: export
|
||||
#include "third_party/gpus/cuda/include/cuda.h" // IWYU pragma: export
|
||||
#include "third_party/gpus/cuda/include/cuda_fp8.h" // IWYU pragma: export
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h" // IWYU pragma: export
|
||||
#include "third_party/gpus/cuda/include/cufft.h" // IWYU pragma: export
|
||||
#include "third_party/gpus/cuda/include/cusolverDn.h" // IWYU pragma: export
|
||||
#include "third_party/gpus/cuda/include/cusparse.h" // IWYU pragma: export
|
||||
#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: export
|
||||
|
||||
// Some sparse functionality is only available in CUSPARSE 11.3 or newer.
|
||||
#define JAX_GPU_HAVE_SPARSE (CUSPARSE_VERSION >= 11300)
|
||||
#if CUDA_VERSION < 11080
|
||||
#error "JAX requires CUDA 11.8 or newer."
|
||||
#endif // CUDA_VERSION < 11080
|
||||
|
||||
#define JAX_GPU_HAVE_SPARSE 1
|
||||
|
||||
// CUDA-11.8 introduces FP8 E4M3/E5M2 types.
|
||||
#define JAX_GPU_HAVE_FP8 (CUDA_VERSION >= 11080)
|
||||
|
||||
#if JAX_GPU_HAVE_FP8
|
||||
#include "third_party/gpus/cuda/include/cuda_fp8.h"
|
||||
#endif
|
||||
|
||||
// cuSPARSE generic APIs are not supported on Windows until 11.0
|
||||
// cusparseIndexType_t is used in very limited scope so manually define will
|
||||
// workaround compiling issue without harm.
|
||||
#if defined(_WIN32) && (CUSPARSE_VERSION < 11000)
|
||||
typedef enum {
|
||||
CUSPARSE_INDEX_16U = 1,
|
||||
CUSPARSE_INDEX_32I = 2,
|
||||
CUSPARSE_INDEX_64I = 3
|
||||
} cusparseIndexType_t;
|
||||
#endif
|
||||
#define JAX_GPU_HAVE_FP8 1
|
||||
|
||||
#define JAX_GPU_NAMESPACE cuda
|
||||
#define JAX_GPU_PREFIX "cu"
|
||||
@ -232,7 +221,6 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
|
||||
// provide deterministic (bit-wise) results for each run. These indexing modes
|
||||
// are fully supported (both row- and column-major inputs) in CUSPARSE 11.7.1
|
||||
// and newer (which was released as part of CUDA 11.8)
|
||||
#if CUSPARSE_VERSION > 11700
|
||||
#define GPUSPARSE_SPMV_COO_ALG CUSPARSE_SPMV_COO_ALG2
|
||||
#define GPUSPARSE_SPMV_CSR_ALG CUSPARSE_SPMV_CSR_ALG2
|
||||
#define GPUSPARSE_SPMM_COO_ALG CUSPARSE_SPMM_COO_ALG2
|
||||
@ -242,12 +230,7 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
|
||||
// does not cover all cases and silently fell back to other algorithms for cases
|
||||
// it did not cover. CUDA 12.2.1 removed the fallback behavior.
|
||||
#define GPUSPARSE_SPMM_CSR_ALG CUSPARSE_SPMM_ALG_DEFAULT
|
||||
#else
|
||||
#define GPUSPARSE_SPMV_COO_ALG CUSPARSE_MV_ALG_DEFAULT
|
||||
#define GPUSPARSE_SPMV_CSR_ALG CUSPARSE_MV_ALG_DEFAULT
|
||||
#define GPUSPARSE_SPMM_COO_ALG CUSPARSE_SPMM_ALG_DEFAULT
|
||||
#define GPUSPARSE_SPMM_CSR_ALG CUSPARSE_SPMM_ALG_DEFAULT
|
||||
#endif
|
||||
|
||||
#define GPUSPARSE_OPERATION_NON_TRANSPOSE CUSPARSE_OPERATION_NON_TRANSPOSE
|
||||
#define GPUSPARSE_OPERATION_TRANSPOSE CUSPARSE_OPERATION_TRANSPOSE
|
||||
#define GPUSPARSE_ORDER_ROW CUSPARSE_ORDER_ROW
|
||||
@ -542,7 +525,7 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
|
||||
#define gpuThreadExchangeStreamCaptureMode hipThreadExchangeStreamCaptureMode
|
||||
#define gpuStreamCreate hipStreamCreateWithFlags
|
||||
#define gpuStreamDestroy hipStreamDestroy
|
||||
#define gpuStreamIsCapturing hipStreamIsCapturing
|
||||
#define gpuStreamIsCapturing hipStreamIsCapturing
|
||||
|
||||
#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR \
|
||||
hipDeviceAttributeComputeCapabilityMajor
|
||||
|
Loading…
x
Reference in New Issue
Block a user