Mirror of https://github.com/ROCm/jax.git

commit 45b871950e (parent 9027101737)

    Fix a number of minor problems in the ROCM build.

    Change in preparation for adding more presubmits for AMD ROCM.

    PiperOrigin-RevId: 667766343
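The recurring fix below is error propagation: GPU launchers that used to return void now return gpuError_t, and call sites convert the result into an absl::Status. A minimal sketch of the pattern, assuming jaxlib's existing JAX_AS_STATUS / JAX_RETURN_IF_ERROR macros and a hypothetical LaunchSomeKernel launcher:

#include "absl/status/status.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"  // JAX_AS_STATUS, JAX_RETURN_IF_ERROR
#include "jaxlib/gpu/vendor.h"              // gpuError_t, gpuStream_t aliases

// Hypothetical launcher: previously `void`, now it surfaces the launch result.
gpuError_t LaunchSomeKernel(gpuStream_t stream);

absl::Status SomeKernelImpl(gpuStream_t stream) {
  // Convert the gpuError_t into an absl::Status and early-return on failure
  // instead of silently dropping the error code.
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(LaunchSomeKernel(stream)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));
  return absl::OkStatus();
}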
@@ -295,6 +295,7 @@ cc_library(
         "//jaxlib:kernel_helpers",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@local_config_cuda//cuda:cuda_headers",
         "@xla//xla/service:custom_call_status",

@@ -323,6 +324,7 @@ pybind_extension(
         ":cuda_gpu_kernel_helpers",
         ":cuda_vendor",
         ":cusparse_kernels",
+        "//jaxlib:absl_status_casters",
         "//jaxlib:kernel_nanobind_helpers",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base",
@@ -40,7 +40,8 @@ absl::Status CholeskyUpdateImpl(gpuStream_t stream, void** buffers,
   auto s = UnpackDescriptor<CholeskyUpdateDescriptor>(opaque, opaque_len);
   JAX_RETURN_IF_ERROR(s.status());
   const CholeskyUpdateDescriptor& d = **s;
-  LaunchCholeskyUpdateKernel(stream, buffers, d);
+  JAX_RETURN_IF_ERROR(
+      JAX_AS_STATUS(LaunchCholeskyUpdateKernel(stream, buffers, d)));
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));
   return absl::OkStatus();
 }

@@ -98,8 +99,8 @@ ffi::Error CholeskyUpdateFfiImpl(gpuStream_t stream, ffi::AnyBuffer matrix_in,
                                    gpuMemcpyDeviceToDevice, stream)));
   }
   for (auto n = 0; n < batch; ++n) {
-    LaunchCholeskyUpdateFfiKernel(stream, matrix, vector, size,
-                                  is_single_precision);
+    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(LaunchCholeskyUpdateFfiKernel(
+        stream, matrix, vector, size, is_single_precision)));
     FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
   }
   return ffi::Error::Success();
@@ -67,8 +67,9 @@ __global__ void CholeskyUpdateKernel(T* rMatrix, T* uVector, int nSize) {
 }  // namespace

 template <typename T>
-void LaunchCholeskyUpdateKernelBody(gpuStream_t stream, void** buffers,
-                                    int grid_dim, int block_dim, int nSize) {
+gpuError_t LaunchCholeskyUpdateKernelBody(gpuStream_t stream, void** buffers,
+                                          int grid_dim, int block_dim,
+                                          int nSize) {
   T* rMatrix = reinterpret_cast<T*>(buffers[2]);
   T* uVector = reinterpret_cast<T*>(buffers[3]);

@@ -77,37 +78,38 @@ void LaunchCholeskyUpdateKernelBody(gpuStream_t stream, void** buffers,
       reinterpret_cast<void*>(&uVector),
       reinterpret_cast<void*>(&nSize),
   };
-  gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel<T>, grid_dim,
+  return gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel<T>, grid_dim,
                              block_dim, arg_ptrs,
                              /*dynamic_shared_mem_bytes=*/0, stream);
 }

-void LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers,
+gpuError_t LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers,
                                 CholeskyUpdateDescriptor descriptor) {
   int nSize = descriptor.matrix_size;
   LinalgType type = descriptor.linalg_type;

   int dev = 0;
   gpuDeviceProp deviceProp;
-  gpuGetDeviceProperties(&deviceProp, dev);
+  gpuError_t err = gpuGetDeviceProperties(&deviceProp, dev);
+  if (err != gpuSuccess) {
+    return err;
+  }

   int block_dim = deviceProp.maxThreadsPerBlock;
   int grid_dim = deviceProp.multiProcessorCount;

   switch (type) {
     case LinalgType::F64:
-      LaunchCholeskyUpdateKernelBody<double>(stream, buffers, grid_dim,
+      return LaunchCholeskyUpdateKernelBody<double>(stream, buffers, grid_dim,
                                              block_dim, nSize);
-      break;
     case LinalgType::F32:
-      LaunchCholeskyUpdateKernelBody<float>(stream, buffers, grid_dim,
+      return LaunchCholeskyUpdateKernelBody<float>(stream, buffers, grid_dim,
                                             block_dim, nSize);
-      break;
   }
 }

 template <typename T>
-void LaunchCholeskyUpdateFfiKernelBody(gpuStream_t stream, void* matrix,
+gpuError_t LaunchCholeskyUpdateFfiKernelBody(gpuStream_t stream, void* matrix,
                                        void* vector, int grid_dim,
                                        int block_dim, int nSize) {
   T* rMatrix = reinterpret_cast<T*>(matrix);

@@ -118,26 +120,30 @@ void LaunchCholeskyUpdateFfiKernelBody(gpuStream_t stream, void* matrix,
       reinterpret_cast<void*>(&uVector),
       reinterpret_cast<void*>(&nSize),
   };
-  gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel<T>, grid_dim,
+  return gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel<T>, grid_dim,
                              block_dim, arg_ptrs,
                              /*dynamic_shared_mem_bytes=*/0, stream);
 }

-void LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix,
+gpuError_t LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix,
                                    void* vector, int size,
                                    bool is_single_precision) {
   int dev = 0;
   gpuDeviceProp deviceProp;
-  gpuGetDeviceProperties(&deviceProp, dev);
+  gpuError_t err = gpuGetDeviceProperties(&deviceProp, dev);
+  if (err != gpuSuccess) {
+    return err;
+  }

   int block_dim = deviceProp.maxThreadsPerBlock;
   int grid_dim = deviceProp.multiProcessorCount;

   if (is_single_precision) {
-    LaunchCholeskyUpdateFfiKernelBody<float>(stream, matrix, vector, grid_dim,
-                                             block_dim, size);
+    return LaunchCholeskyUpdateFfiKernelBody<float>(stream, matrix, vector,
+                                                    grid_dim, block_dim, size);
   } else {
-    LaunchCholeskyUpdateFfiKernelBody<double>(stream, matrix, vector, grid_dim,
-                                              block_dim, size);
+    return LaunchCholeskyUpdateFfiKernelBody<double>(stream, matrix, vector,
+                                                     grid_dim, block_dim, size);
   }
 }
@@ -36,13 +36,13 @@ struct CholeskyUpdateDescriptor {
   std::int64_t matrix_size;  // leading dim (N) for a square (NxN)matrix
 };

-void LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers,
+gpuError_t LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers,
                                 CholeskyUpdateDescriptor descriptor);

 void CholeskyUpdate(gpuStream_t stream, void** buffers, const char* opaque,
                     size_t opaque_len, XlaCustomCallStatus* status);

-void LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix,
+gpuError_t LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix,
                                    void* vector, int size,
                                    bool is_single_precision);
 XLA_FFI_DECLARE_HANDLER_SYMBOL(CholeskyUpdateFfi);
@@ -23,8 +23,8 @@ limitations under the License.

 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
-#include "jaxlib/gpu/solver_handle_pool.h"
 #include "jaxlib/gpu/gpu_kernel_helpers.h"
+#include "jaxlib/gpu/solver_handle_pool.h"
 #include "jaxlib/gpu/vendor.h"
 #include "jaxlib/kernel_helpers.h"
 #include "xla/service/custom_call_status.h"

@@ -421,9 +421,9 @@ static absl::Status Syevd_(gpuStream_t stream, void** buffers,
   int output_idx = 1;  // with static shapes buffers[1] is the first output
   if (d.batch == -1) {
     // the batch is passed as a second operand
-    gpuMemcpyAsync((void*)&batch,
-                   reinterpret_cast<const std::int64_t*>(buffers[1]),
-                   sizeof(batch), gpuMemcpyDeviceToHost, stream);
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
+        (void*)&batch, reinterpret_cast<const std::int64_t*>(buffers[1]),
+        sizeof(batch), gpuMemcpyDeviceToHost, stream)));
     JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
     output_idx = 2;
   }
@@ -61,6 +61,17 @@ inline absl::StatusOr<T*> AllocateWorkspace(ffi::ScratchAllocator& scratch,
     return impl<gpuDoubleComplex>(__VA_ARGS__); \
   }

+#define SOLVER_BLAS_DISPATCH_IMPL(impl, ...)        \
+  if (dataType == ffi::F32) {                       \
+    return impl<float>(__VA_ARGS__);                \
+  } else if (dataType == ffi::F64) {                \
+    return impl<double>(__VA_ARGS__);               \
+  } else if (dataType == ffi::C64) {                \
+    return impl<gpublasComplex>(__VA_ARGS__);       \
+  } else if (dataType == ffi::C128) {               \
+    return impl<gpublasDoubleComplex>(__VA_ARGS__); \
+  }
+
 // LU decomposition: getrf

 namespace {

@@ -189,8 +200,8 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
       ipiv->dimensions(), {batch, std::min(rows, cols)}, "ipiv", "getrf"));
   FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "getrf"));
   if (batch > 1 && rows == cols && rows / batch <= 128) {
-    SOLVER_DISPATCH_IMPL(GetrfBatchedImpl, batch, cols, stream, scratch, a, out,
-                         ipiv, info);
+    SOLVER_BLAS_DISPATCH_IMPL(GetrfBatchedImpl, batch, cols, stream, scratch, a,
+                              out, ipiv, info);
   } else {
     SOLVER_DISPATCH_IMPL(GetrfImpl, batch, rows, cols, stream, scratch, a, out,
                          ipiv, info);

@@ -345,8 +356,8 @@ ffi::Error GeqrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
   FFI_RETURN_IF_ERROR(CheckShape(
       tau->dimensions(), {batch, std::min(rows, cols)}, "tau", "geqrf"));
   if (batch > 1 && rows / batch <= 128 && cols / batch <= 128) {
-    SOLVER_DISPATCH_IMPL(GeqrfBatchedImpl, batch, rows, cols, stream, scratch,
-                         a, out, tau);
+    SOLVER_BLAS_DISPATCH_IMPL(GeqrfBatchedImpl, batch, rows, cols, stream,
+                              scratch, a, out, tau);
   } else {
     SOLVER_DISPATCH_IMPL(GeqrfImpl, batch, rows, cols, stream, scratch, a, out,
                          tau);
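The new SOLVER_BLAS_DISPATCH_IMPL mirrors the existing SOLVER_DISPATCH_IMPL but instantiates the batched code paths with the BLAS library's complex types (gpublasComplex, gpublasDoubleComplex) rather than the solver ones, since the batched getrf/geqrf routines come from cuBLAS/hipBLAS. A rough sketch of how such a dispatch macro is consumed, with an illustrative impl template (not the real getrf signature):

// Sketch only: `dataType` must be in scope, as in GetrfDispatch above.
template <typename T>
ffi::Error GetrfBatchedSketch(int64_t batch, int64_t cols) {
  // ... would call the gpublas*getrfBatched entry point for T here ...
  return ffi::Error::Success();
}

ffi::Error DispatchSketch(ffi::DataType dataType, int64_t batch, int64_t cols) {
  SOLVER_BLAS_DISPATCH_IMPL(GetrfBatchedSketch, batch, cols);
  // Reached only for element types not handled by the macro.
  return ffi::Error(ffi::ErrorCode::kInvalidArgument, "unsupported dtype");
}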
@@ -26,6 +26,7 @@ limitations under the License.
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/str_format.h"
 #include "absl/synchronization/mutex.h"
+#include "jaxlib/absl_status_casters.h"
 #include "jaxlib/gpu/gpu_kernel_helpers.h"
 #include "jaxlib/gpu/sparse_kernels.h"
 #include "jaxlib/gpu/vendor.h"

@@ -57,12 +58,17 @@ gpusparseIndexType_t DtypeToCuSparseIndexType(const dtype& np_type) {
 gpuDataType DtypeToCudaDataType(const dtype& np_type) {
   static auto* types =
       new absl::flat_hash_map<std::pair<char, int>, gpuDataType>({
-          {{'f', 2}, GPU_R_16F}, {{'c', 4}, GPU_C_16F}, {{'f', 4}, GPU_R_32F},
-          {{'c', 8}, GPU_C_32F}, {{'f', 8}, GPU_R_64F},
+          {{'f', 2}, GPU_R_16F},
+          {{'c', 4}, GPU_C_16F},
+          {{'f', 4}, GPU_R_32F},
+          {{'c', 8}, GPU_C_32F},
+          {{'f', 8}, GPU_R_64F},
           {{'c', 16}, GPU_C_64F},
 #ifdef JAX_GPU_CUDA
-          {{'i', 1}, CUDA_R_8I}, {{'u', 1}, CUDA_R_8U},
-          {{'i', 4}, CUDA_R_32I}, {{'u', 4}, CUDA_R_32U},
+          {{'i', 1}, CUDA_R_8I},
+          {{'u', 1}, CUDA_R_8U},
+          {{'i', 4}, CUDA_R_32I},
+          {{'u', 4}, CUDA_R_32U},
 #if JAX_GPU_HAVE_SPARSE
           {{'V', 2}, CUDA_R_16BF},
 #endif  // JAX_GPU_HAVE_SPARSE

@@ -78,9 +84,8 @@ gpuDataType DtypeToCudaDataType(const dtype& np_type) {
 }
 // Returns the descriptor for a Sparse matrix.
 SparseMatDescriptor BuildSparseMatDescriptor(const dtype& data_dtype,
-                                             const dtype& index_dtype,
-                                             int rows, int cols, int nnz,
-                                             int batch_count,
+                                             const dtype& index_dtype, int rows,
+                                             int cols, int nnz, int batch_count,
                                              int batch_stride) {
   gpuDataType value_type = DtypeToCudaDataType(data_dtype);
   gpusparseIndexType_t index_type = DtypeToCuSparseIndexType(index_dtype);

@@ -89,16 +94,15 @@ SparseMatDescriptor BuildSparseMatDescriptor(const dtype& data_dtype,
 }

 // Returns the descriptor for a Dense matrix.
-DenseMatDescriptor BuildDenseMatDescriptor(const dtype& data_dtype,
-                                           int rows, int cols, int batch_count,
+DenseMatDescriptor BuildDenseMatDescriptor(const dtype& data_dtype, int rows,
+                                           int cols, int batch_count,
                                            int batch_stride) {
   gpuDataType value_type = DtypeToCudaDataType(data_dtype);
   return DenseMatDescriptor{value_type, rows, cols, batch_count, batch_stride};
 }

 // Returns the descriptor for a Dense vector.
-DenseVecDescriptor BuildDenseVecDescriptor(const dtype& data_dtype,
-                                           int size) {
+DenseVecDescriptor BuildDenseVecDescriptor(const dtype& data_dtype, int size) {
   gpuDataType value_type = DtypeToCudaDataType(data_dtype);
   return DenseVecDescriptor{value_type, size};
 }

@@ -107,9 +111,10 @@ DenseVecDescriptor BuildDenseVecDescriptor(const dtype& data_dtype,
 // CsrToDense: Convert CSR matrix to dense matrix

 // Returns the descriptor for a Sparse matrix.
-std::pair<size_t, nb::bytes> BuildCsrToDenseDescriptor(
-    const dtype& data_dtype, const dtype& index_dtype, int rows,
-    int cols, int nnz) {
+std::pair<size_t, nb::bytes> BuildCsrToDenseDescriptor(const dtype& data_dtype,
+                                                       const dtype& index_dtype,
+                                                       int rows, int cols,
+                                                       int nnz) {
   auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;

@@ -185,8 +190,8 @@ void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque,

 // Returns the descriptor for a CsrFromDense operation.
 std::pair<size_t, nb::bytes> BuildCsrFromDenseDescriptor(
-    const dtype& data_dtype, const dtype& index_dtype, int rows,
-    int cols, int nnz) {
+    const dtype& data_dtype, const dtype& index_dtype, int rows, int cols,
+    int nnz) {
   auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;

@@ -261,9 +266,8 @@ void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque,

 // Returns the descriptor for a CsrMatvec operation.
 std::pair<size_t, nb::bytes> BuildCsrMatvecDescriptor(
-    const dtype& data_dtype, const dtype& x_dtype,
-    const dtype& compute_dtype, const dtype& index_dtype, int rows,
-    int cols, int nnz, bool transpose) {
+    const dtype& data_dtype, const dtype& x_dtype, const dtype& compute_dtype,
+    const dtype& index_dtype, int rows, int cols, int nnz, bool transpose) {
   auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;

@@ -292,7 +296,7 @@ std::pair<size_t, nb::bytes> BuildCsrMatvecDescriptor(
   JAX_THROW_IF_ERROR(
       JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
   size_t buffer_size;
-  SparseConst alpha = ConstOne(y.type);
+  SparseConst alpha = ValueOrThrow(ConstOne(y.type));
   SparseConst beta = ConstZero(y.type);
   JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize(
       handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,

@@ -309,9 +313,9 @@ std::pair<size_t, nb::bytes> BuildCsrMatvecDescriptor(

 // Returns the descriptor for a CsrMatmat operation.
 std::pair<size_t, nb::bytes> BuildCsrMatmatDescriptor(
-    const dtype& data_dtype, const dtype& b_dtype,
-    const dtype& compute_dtype, const dtype& index_dtype, int rows,
-    int cols, int BCcols, int nnz, bool transpose) {
+    const dtype& data_dtype, const dtype& b_dtype, const dtype& compute_dtype,
+    const dtype& index_dtype, int rows, int cols, int BCcols, int nnz,
+    bool transpose) {
   auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;

@@ -344,7 +348,7 @@ std::pair<size_t, nb::bytes> BuildCsrMatmatDescriptor(
       JAX_AS_STATUS(gpusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
                                          empty, C.type, GPUSPARSE_ORDER_ROW)));
   size_t buffer_size;
-  SparseConst alpha = ConstOne(C.type);
+  SparseConst alpha = ValueOrThrow(ConstOne(C.type));
   SparseConst beta = ConstZero(C.type);
   JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize(
       handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,

@@ -360,9 +364,10 @@ std::pair<size_t, nb::bytes> BuildCsrMatmatDescriptor(
 // CooToDense: Convert COO matrix to dense matrix

 // Returns the descriptor for a CooToDense operation.
-std::pair<size_t, nb::bytes> BuildCooToDenseDescriptor(
-    const dtype& data_dtype, const dtype& index_dtype, int rows,
-    int cols, int nnz) {
+std::pair<size_t, nb::bytes> BuildCooToDenseDescriptor(const dtype& data_dtype,
+                                                       const dtype& index_dtype,
+                                                       int rows, int cols,
+                                                       int nnz) {
   auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;

@@ -398,8 +403,8 @@ std::pair<size_t, nb::bytes> BuildCooToDenseDescriptor(

 // Returns the descriptor for a CooFromDense operation.
 std::pair<size_t, nb::bytes> BuildCooFromDenseDescriptor(
-    const dtype& data_dtype, const dtype& index_dtype, int rows,
-    int cols, int nnz) {
+    const dtype& data_dtype, const dtype& index_dtype, int rows, int cols,
+    int nnz) {
   auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;

@@ -434,9 +439,8 @@ std::pair<size_t, nb::bytes> BuildCooFromDenseDescriptor(

 // Returns the descriptor for a CooMatvec operation.
 std::pair<size_t, nb::bytes> BuildCooMatvecDescriptor(
-    const dtype& data_dtype, const dtype& x_dtype,
-    const dtype& compute_dtype, const dtype& index_dtype, int rows,
-    int cols, int nnz, bool transpose) {
+    const dtype& data_dtype, const dtype& x_dtype, const dtype& compute_dtype,
+    const dtype& index_dtype, int rows, int cols, int nnz, bool transpose) {
   auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;

@@ -465,7 +469,7 @@ std::pair<size_t, nb::bytes> BuildCooMatvecDescriptor(
   JAX_THROW_IF_ERROR(
       JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
   size_t buffer_size;
-  SparseConst alpha = ConstOne(y.type);
+  SparseConst alpha = ValueOrThrow(ConstOne(y.type));
   SparseConst beta = ConstZero(y.type);
   JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize(
       handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,

@@ -482,10 +486,10 @@ std::pair<size_t, nb::bytes> BuildCooMatvecDescriptor(

 // Returns the descriptor for a CooMatmat operation.
 std::pair<size_t, nb::bytes> BuildCooMatmatDescriptor(
-    const dtype& data_dtype, const dtype& b_dtype,
-    const dtype& compute_dtype, const dtype& index_dtype, int rows,
-    int cols, int BCcols, int nnz, bool transpose, int batch_count,
-    int lhs_batch_stride, int rhs_batch_stride) {
+    const dtype& data_dtype, const dtype& b_dtype, const dtype& compute_dtype,
+    const dtype& index_dtype, int rows, int cols, int BCcols, int nnz,
+    bool transpose, int batch_count, int lhs_batch_stride,
+    int rhs_batch_stride) {
   // Three batch modes are supported, C_i = A_i B, C_i = A B_i, and
   // Ci = A_i B_i, where `i` denotes the batch dimension.
   // All three matrices A, B, and C must have the same batch count.

@@ -535,7 +539,7 @@ std::pair<size_t, nb::bytes> BuildCooMatmatDescriptor(
   JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDnMatSetStridedBatch(
       mat_c, /*batchCount=*/batch_count, /*batchStride=*/C.batch_stride)));
   size_t buffer_size;
-  SparseConst alpha = ConstOne(C.type);
+  SparseConst alpha = ValueOrThrow(ConstOne(C.type));
   SparseConst beta = ConstZero(C.type);
   JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize(
       handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
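ValueOrThrow comes from jaxlib/absl_status_casters.h (hence the new BUILD dependencies above): it unwraps an absl::StatusOr<T> and throws on error, which fits these descriptor builders because they run under nanobind and the exception surfaces to Python. A simplified stand-in for the idiom:

#include <stdexcept>
#include <string>
#include <utility>

#include "absl/status/statusor.h"

// Simplified sketch of a throwing StatusOr unwrapper; the real jaxlib helper
// also preserves the status details for the Python binding layer.
template <typename T>
T ValueOrThrowSketch(absl::StatusOr<T> status_or) {
  if (!status_or.ok()) {
    throw std::runtime_error(std::string(status_or.status().message()));
  }
  return *std::move(status_or);
}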
@@ -23,6 +23,7 @@ limitations under the License.

 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
 #include "absl/synchronization/mutex.h"
 #include "jaxlib/gpu/gpu_kernel_helpers.h"
 #include "jaxlib/gpu/vendor.h"

@@ -58,7 +59,7 @@ SparseConst ConstZero(gpuDataType type) {
   return c;
 }

-SparseConst ConstOne(gpuDataType type) {
+absl::StatusOr<SparseConst> ConstOne(gpuDataType type) {
   SparseConst c;
   std::memset(&c, 0, sizeof(c));
   switch (type) {

@@ -138,6 +139,9 @@ SparseConst ConstOne(gpuDataType type) {
     case GPU_C_64F:
       c.f64[0] = 1.0;
       break;
+    default:
+      return absl::InvalidArgumentError(
+          absl::StrCat("Unsupported data type: ", type));
   }
   return c;
 }

@@ -248,7 +252,7 @@ static absl::Status CsrMatvec_(gpuStream_t stream, void** buffers,
   // are sufficient for basic matvec operations.
   // Note that, contrary to cusparse docs, alpha and beta must be host pointers
   // or else the operation will segfault.
-  SparseConst alpha = ConstOne(d.y.type);
+  JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.y.type));
   SparseConst beta = ConstZero(d.y.type);

   gpusparseSpMatDescr_t mat_a = 0;

@@ -305,7 +309,7 @@ static absl::Status CsrMatmat_(gpuStream_t stream, void** buffers,
   // are sufficient for basic matvec operations.
   // Note that, contrary to cusparse docs, alpha and beta must be host pointers
   // or else the operation will segfault.
-  SparseConst alpha = ConstOne(d.C.type);
+  JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.C.type));
   SparseConst beta = ConstZero(d.C.type);

   gpusparseSpMatDescr_t mat_a = 0;

@@ -446,7 +450,7 @@ static absl::Status CooMatvec_(gpuStream_t stream, void** buffers,
   // are sufficient for basic matvec operations.
   // Note that, contrary to cusparse docs, alpha and beta must be host pointers
   // or else the operation will segfault.
-  SparseConst alpha = ConstOne(d.y.type);
+  JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.y.type));
   SparseConst beta = ConstZero(d.y.type);

   gpusparseSpMatDescr_t mat_a = 0;

@@ -502,7 +506,7 @@ static absl::Status CooMatmat_(gpuStream_t stream, void** buffers,
   // are sufficient for basic matvec operations.
   // Note that, contrary to cusparse docs, alpha and beta must be host pointers
   // or else the operation will segfault.
-  SparseConst alpha = ConstOne(d.C.type);
+  JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.C.type));
   SparseConst beta = ConstZero(d.C.type);

   gpusparseSpMatDescr_t mat_a = 0;

@@ -581,8 +585,8 @@ static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers,
         gpuMemcpyAsync(X, B, B_bytes, gpuMemcpyDeviceToDevice, stream)));
   }
   for (int i = 0; i < batch; ++i) {
-    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(computeGtsv2(
-        handle.get(), m, n, dl, d, du, X, ldb, buffer)));
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
+        computeGtsv2(handle.get(), m, n, dl, d, du, X, ldb, buffer)));
     dl += m;
     d += m;
     du += m;
@@ -51,7 +51,7 @@ union SparseConst {
 };

 SparseConst ConstZero(gpuDataType type);
-SparseConst ConstOne(gpuDataType type);
+absl::StatusOr<SparseConst> ConstOne(gpuDataType type);

 struct SparseMatDescriptor {
   gpuDataType value_type;
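Inside the absl::Status-returning kernels, the same StatusOr is consumed with JAX_ASSIGN_OR_RETURN rather than a throw. A simplified sketch of what such a macro expands to (the real jaxlib macro generates a unique temporary name, so it is not limited to one use per scope):

#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Simplified expansion: evaluate the StatusOr once, early-return its status
// on failure, otherwise bind the unwrapped value to the declaration.
#define ASSIGN_OR_RETURN_SKETCH(decl, expr)             \
  auto statusor_tmp = (expr);                           \
  if (!statusor_tmp.ok()) return statusor_tmp.status(); \
  decl = *std::move(statusor_tmp)

absl::Status ConsumeSketch(absl::StatusOr<int> maybe_value) {
  ASSIGN_OR_RETURN_SKETCH(int value, std::move(maybe_value));
  (void)value;  // the unwrapped value is usable from here on
  return absl::OkStatus();
}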
@@ -34,7 +34,11 @@

 #ifdef JAX_GPU_CUDA
 #include "xla/stream_executor/cuda/cuda_asm_compiler.h"
-#endif
+#endif  // JAX_GPU_CUDA
+
+#ifdef JAX_GPU_HIP
+#include "tsl/platform/env.h"
+#endif  // JAX_GPU_HIP

 #define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr))

@@ -44,7 +48,12 @@ namespace {
 constexpr float kBenchmarkTimeMillis = 10.;

 struct gpuModuleDeleter {
-  void operator()(gpuModule_t module) { gpuModuleUnload(module); }
+  void operator()(gpuModule_t module) {
+    absl::Status status = JAX_AS_STATUS(gpuModuleUnload(module));
+    if (!status.ok()) {
+      LOG(WARNING) << "Failed to unload GPU module: " << status;
+    }
+  }
 };

 using OwnedGPUmodule =

@@ -52,11 +61,11 @@ using OwnedGPUmodule =

 absl::StatusOr<gpuDevice_t> GetStreamDevice(gpuStream_t stream) {
   gpuDevice_t device;
-  gpuContext_t context;
 #ifdef JAX_GPU_HIP
   int device_id = gpuGetStreamDeviceId(stream);
   GPU_RETURN_IF_ERROR(gpuDeviceGet(&device, device_id));
 #else  // JAX_GPU_CUDA
+  gpuContext_t context;
   GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context));
   GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
   absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); };

@@ -210,7 +219,12 @@ class ModuleImage {
   }

   GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
-  absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); };
+  absl::Cleanup ctx_restorer = [] {
+    absl::Status status = JAX_AS_STATUS(gpuCtxPopCurrent(nullptr));
+    if (!status.ok()) {
+      LOG(WARNING) << "Failed to pop GPU context: " << status;
+    }
+  };

   gpuModule_t module;
   GPU_RETURN_IF_ERROR(gpuModuleLoadData(&module, module_image_.data()));
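A destructor or absl::Cleanup callback has no caller to return a Status to, so the new code logs failures at WARNING instead of discarding them. For context, the OwnedGPUmodule alias truncated above pairs the deleter with a std::unique_ptr; a sketch of that wiring under the same gpuModule_t / JAX_AS_STATUS aliases:

#include <memory>
#include <type_traits>

#include "absl/log/log.h"
#include "absl/status/status.h"

// Sketch: gpuModule_t is a pointer typedef (CUmodule / hipModule_t), so the
// unique_ptr element type is the pointee and the deleter receives the handle.
struct ModuleDeleterSketch {
  void operator()(gpuModule_t module) {
    absl::Status status = JAX_AS_STATUS(gpuModuleUnload(module));
    if (!status.ok()) {
      LOG(WARNING) << "Failed to unload GPU module: " << status;
    }
  }
};

using OwnedModuleSketch =
    std::unique_ptr<std::remove_pointer_t<gpuModule_t>, ModuleDeleterSketch>;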
@@ -22,18 +22,20 @@ limitations under the License.

 #if defined(JAX_GPU_CUDA)

-#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cooperative_groups.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cuComplex.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cublas_v2.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cuda.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cuda_fp8.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cuda_runtime_api.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cufft.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cusolverDn.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cusolver_common.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cusparse.h"  // IWYU pragma: export
-#include "third_party/gpus/cudnn/cudnn.h"  // IWYU pragma: export
+// IWYU pragma: begin_exports
+#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"
+#include "third_party/gpus/cuda/include/cooperative_groups.h"
+#include "third_party/gpus/cuda/include/cuComplex.h"
+#include "third_party/gpus/cuda/include/cublas_v2.h"
+#include "third_party/gpus/cuda/include/cuda.h"
+#include "third_party/gpus/cuda/include/cuda_fp8.h"
+#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
+#include "third_party/gpus/cuda/include/cufft.h"
+#include "third_party/gpus/cuda/include/cusolverDn.h"
+#include "third_party/gpus/cuda/include/cusolver_common.h"
+#include "third_party/gpus/cuda/include/cusparse.h"
+#include "third_party/gpus/cudnn/cudnn.h"
+// IWYU pragma: end_exports

 #if CUDA_VERSION < 11080
 #error "JAX requires CUDA 11.8 or newer."

@@ -305,11 +307,13 @@ constexpr uint32_t kNumThreadsPerWarp = 32;

 #elif defined(JAX_GPU_HIP)

-#include "rocm/include/hip/amd_detail/amd_hip_cooperative_groups.h"
+// IWYU pragma: begin_exports
+#include "rocm/include/hip/hip_cooperative_groups.h"
 #include "rocm/include/hip/hip_runtime_api.h"
 #include "rocm/include/hipblas/hipblas.h"
 #include "rocm/include/hipsolver/hipsolver.h"
 #include "rocm/include/hipsparse/hipsparse.h"
+// IWYU pragma: end_exports

 #define JAX_GPU_NAMESPACE hip
 #define JAX_GPU_PREFIX "hip"
@@ -14,6 +14,7 @@

 # AMD HIP kernels

+load("@rules_python//python:defs.bzl", "py_library")
 load(
     "//jaxlib:jax.bzl",
     "if_rocm_is_configured",

@@ -23,7 +24,10 @@ load(

 licenses(["notice"])

-package(default_visibility = ["//:__subpackages__"])
+package(
+    default_applicable_licenses = [],
+    default_visibility = ["//:__subpackages__"],
+)

 cc_library(
     name = "hip_vendor",

@@ -203,6 +207,7 @@ pybind_extension(
         ":hipsolver_kernels_ffi",
         "//jaxlib:kernel_nanobind_helpers",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:str_format",
        "@local_config_rocm//rocm:hipblas",
        "@local_config_rocm//rocm:hipsolver",

@@ -223,6 +228,7 @@ cc_library(
         "//jaxlib:kernel_helpers",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@local_config_rocm//rocm:hipsparse",
         "@local_config_rocm//rocm:rocm_headers",

@@ -243,6 +249,7 @@ pybind_extension(
         ":hip_gpu_kernel_helpers",
         ":hip_vendor",
         ":hipsparse_kernels",
+        "//jaxlib:absl_status_casters",
         "//jaxlib:kernel_nanobind_helpers",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base",

@@ -277,6 +284,7 @@ cc_library(
         "@local_config_rocm//rocm:rocm_headers",
         "@xla//xla/ffi/api:c_api",
         "@xla//xla/ffi/api:ffi",
         "@xla//xla/service:custom_call_status",
     ],
 )

@@ -376,14 +384,15 @@ cc_library(
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/cleanup",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/synchronization",
         "@tsl//tsl/platform:env",
         "@xla//xla/service:custom_call_status",
         "@xla//xla/stream_executor/gpu:asm_compiler",
         "@xla//xla/tsl/util:env_var",
     ],
 )
@@ -142,8 +142,8 @@ NB_MODULE(rocm_plugin_extension, m) {
             HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
             reinterpret_cast<hipDeviceptr_t>(data_ptr));
         if (result != hipSuccess) {
-          LOG(FATAL) << "Not able to get the device_ordinal for ptr: " << data_ptr
-                     << ". Error: " << ToString(result);
+          LOG(FATAL) << "Not able to get the device_ordinal for ptr: "
+                     << data_ptr << ". Error: " << ToString(result);
         }
         return device_ordinal;
       },