Fix a number of minor problems in the ROCM build.

Change in preparation for adding more presubmits for AMD ROCM.

PiperOrigin-RevId: 667766343
Peter Hawkins 2024-08-26 17:03:27 -07:00 committed by jax authors
parent 9027101737
commit 45b871950e
13 changed files with 173 additions and 118 deletions

View File

@@ -295,6 +295,7 @@ cc_library(
         "//jaxlib:kernel_helpers",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@local_config_cuda//cuda:cuda_headers",
         "@xla//xla/service:custom_call_status",
@@ -323,6 +324,7 @@ pybind_extension(
         ":cuda_gpu_kernel_helpers",
         ":cuda_vendor",
         ":cusparse_kernels",
+        "//jaxlib:absl_status_casters",
         "//jaxlib:kernel_nanobind_helpers",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base",

View File

@@ -40,7 +40,8 @@ absl::Status CholeskyUpdateImpl(gpuStream_t stream, void** buffers,
   auto s = UnpackDescriptor<CholeskyUpdateDescriptor>(opaque, opaque_len);
   JAX_RETURN_IF_ERROR(s.status());
   const CholeskyUpdateDescriptor& d = **s;
-  LaunchCholeskyUpdateKernel(stream, buffers, d);
+  JAX_RETURN_IF_ERROR(
+      JAX_AS_STATUS(LaunchCholeskyUpdateKernel(stream, buffers, d)));
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));
   return absl::OkStatus();
 }
@@ -98,8 +99,8 @@ ffi::Error CholeskyUpdateFfiImpl(gpuStream_t stream, ffi::AnyBuffer matrix_in,
                                  gpuMemcpyDeviceToDevice, stream)));
   }
   for (auto n = 0; n < batch; ++n) {
-    LaunchCholeskyUpdateFfiKernel(stream, matrix, vector, size,
-                                  is_single_precision);
+    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(LaunchCholeskyUpdateFfiKernel(
+        stream, matrix, vector, size, is_single_precision)));
     FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
   }
   return ffi::Error::Success();
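
The pattern adopted above (and in the other kernels in this commit): launch helpers return the raw gpuError_t instead of discarding it, and callers fold it into a status via JAX_AS_STATUS. A minimal self-contained sketch of the same idea against the plain CUDA runtime, where ToStatus and CopyThenCheck are hypothetical stand-ins for jaxlib's actual helpers:

#include <cuda_runtime_api.h>

#include "absl/status/status.h"

// Hypothetical stand-in for jaxlib's JAX_AS_STATUS: map a runtime error
// code to an absl::Status carrying the driver's error string.
absl::Status ToStatus(cudaError_t err) {
  if (err == cudaSuccess) return absl::OkStatus();
  return absl::InternalError(cudaGetErrorString(err));
}

// Because the launcher now surfaces its error code, the caller can report
// failures instead of silently dropping them.
absl::Status CopyThenCheck(void* dst, const void* src, size_t n,
                           cudaStream_t stream) {
  absl::Status s =
      ToStatus(cudaMemcpyAsync(dst, src, n, cudaMemcpyDeviceToDevice, stream));
  if (!s.ok()) return s;
  return ToStatus(cudaGetLastError());  // also catch async launch errors
}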

View File

@@ -67,8 +67,9 @@ __global__ void CholeskyUpdateKernel(T* rMatrix, T* uVector, int nSize) {
 
 }  // namespace
 
 template <typename T>
-void LaunchCholeskyUpdateKernelBody(gpuStream_t stream, void** buffers,
-                                    int grid_dim, int block_dim, int nSize) {
+gpuError_t LaunchCholeskyUpdateKernelBody(gpuStream_t stream, void** buffers,
+                                          int grid_dim, int block_dim,
+                                          int nSize) {
   T* rMatrix = reinterpret_cast<T*>(buffers[2]);
   T* uVector = reinterpret_cast<T*>(buffers[3]);
@@ -77,37 +78,38 @@ void LaunchCholeskyUpdateKernelBody(gpuStream_t stream, void** buffers,
       reinterpret_cast<void*>(&uVector),
       reinterpret_cast<void*>(&nSize),
   };
-  gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel<T>, grid_dim,
-                             block_dim, arg_ptrs,
-                             /*dynamic_shared_mem_bytes=*/0, stream);
+  return gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel<T>, grid_dim,
+                                    block_dim, arg_ptrs,
+                                    /*dynamic_shared_mem_bytes=*/0, stream);
 }
 
-void LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers,
-                                CholeskyUpdateDescriptor descriptor) {
+gpuError_t LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers,
+                                      CholeskyUpdateDescriptor descriptor) {
   int nSize = descriptor.matrix_size;
   LinalgType type = descriptor.linalg_type;
 
   int dev = 0;
   gpuDeviceProp deviceProp;
-  gpuGetDeviceProperties(&deviceProp, dev);
+  gpuError_t err = gpuGetDeviceProperties(&deviceProp, dev);
+  if (err != gpuSuccess) {
+    return err;
+  }
   int block_dim = deviceProp.maxThreadsPerBlock;
   int grid_dim = deviceProp.multiProcessorCount;
 
   switch (type) {
     case LinalgType::F64:
-      LaunchCholeskyUpdateKernelBody<double>(stream, buffers, grid_dim,
-                                             block_dim, nSize);
-      break;
+      return LaunchCholeskyUpdateKernelBody<double>(stream, buffers, grid_dim,
+                                                    block_dim, nSize);
     case LinalgType::F32:
-      LaunchCholeskyUpdateKernelBody<float>(stream, buffers, grid_dim,
-                                            block_dim, nSize);
-      break;
+      return LaunchCholeskyUpdateKernelBody<float>(stream, buffers, grid_dim,
+                                                   block_dim, nSize);
   }
 }
 
 template <typename T>
-void LaunchCholeskyUpdateFfiKernelBody(gpuStream_t stream, void* matrix,
-                                       void* vector, int grid_dim,
-                                       int block_dim, int nSize) {
+gpuError_t LaunchCholeskyUpdateFfiKernelBody(gpuStream_t stream, void* matrix,
+                                             void* vector, int grid_dim,
+                                             int block_dim, int nSize) {
   T* rMatrix = reinterpret_cast<T*>(matrix);
@@ -118,26 +120,30 @@ void LaunchCholeskyUpdateFfiKernelBody(gpuStream_t stream, void* matrix,
       reinterpret_cast<void*>(&uVector),
       reinterpret_cast<void*>(&nSize),
   };
-  gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel<T>, grid_dim,
-                             block_dim, arg_ptrs,
-                             /*dynamic_shared_mem_bytes=*/0, stream);
+  return gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel<T>, grid_dim,
+                                    block_dim, arg_ptrs,
+                                    /*dynamic_shared_mem_bytes=*/0, stream);
 }
 
-void LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix,
-                                   void* vector, int size,
-                                   bool is_single_precision) {
+gpuError_t LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix,
+                                         void* vector, int size,
+                                         bool is_single_precision) {
   int dev = 0;
   gpuDeviceProp deviceProp;
-  gpuGetDeviceProperties(&deviceProp, dev);
+  gpuError_t err = gpuGetDeviceProperties(&deviceProp, dev);
+  if (err != gpuSuccess) {
+    return err;
+  }
   int block_dim = deviceProp.maxThreadsPerBlock;
   int grid_dim = deviceProp.multiProcessorCount;
 
   if (is_single_precision) {
-    LaunchCholeskyUpdateFfiKernelBody<float>(stream, matrix, vector, grid_dim,
-                                             block_dim, size);
+    return LaunchCholeskyUpdateFfiKernelBody<float>(stream, matrix, vector,
+                                                    grid_dim, block_dim, size);
   } else {
-    LaunchCholeskyUpdateFfiKernelBody<double>(stream, matrix, vector, grid_dim,
-                                              block_dim, size);
+    return LaunchCholeskyUpdateFfiKernelBody<double>(stream, matrix, vector,
+                                                     grid_dim, block_dim, size);
   }
 }

View File

@@ -36,13 +36,13 @@ struct CholeskyUpdateDescriptor {
   std::int64_t matrix_size;  // leading dim (N) for a square (NxN)matrix
 };
 
-void LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers,
-                                CholeskyUpdateDescriptor descriptor);
+gpuError_t LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers,
+                                      CholeskyUpdateDescriptor descriptor);
 
 void CholeskyUpdate(gpuStream_t stream, void** buffers, const char* opaque,
                     size_t opaque_len, XlaCustomCallStatus* status);
 
-void LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix,
-                                   void* vector, int size,
-                                   bool is_single_precision);
+gpuError_t LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix,
+                                         void* vector, int size,
+                                         bool is_single_precision);
 
 XLA_FFI_DECLARE_HANDLER_SYMBOL(CholeskyUpdateFfi);

View File

@@ -23,8 +23,8 @@ limitations under the License.
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
-#include "jaxlib/gpu/solver_handle_pool.h"
 #include "jaxlib/gpu/gpu_kernel_helpers.h"
+#include "jaxlib/gpu/solver_handle_pool.h"
 #include "jaxlib/gpu/vendor.h"
 #include "jaxlib/kernel_helpers.h"
 #include "xla/service/custom_call_status.h"
@@ -421,9 +421,9 @@ static absl::Status Syevd_(gpuStream_t stream, void** buffers,
   int output_idx = 1;  // with static shapes buffers[1] is the first output
   if (d.batch == -1) {
     // the batch is passed as a second operand
-    gpuMemcpyAsync((void*)&batch,
-                   reinterpret_cast<const std::int64_t*>(buffers[1]),
-                   sizeof(batch), gpuMemcpyDeviceToHost, stream);
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
+        (void*)&batch, reinterpret_cast<const std::int64_t*>(buffers[1]),
+        sizeof(batch), gpuMemcpyDeviceToHost, stream)));
     JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
     output_idx = 2;
   }

View File

@@ -61,6 +61,17 @@ inline absl::StatusOr<T*> AllocateWorkspace(ffi::ScratchAllocator& scratch,
     return impl<gpuDoubleComplex>(__VA_ARGS__); \
   }
 
+#define SOLVER_BLAS_DISPATCH_IMPL(impl, ...)        \
+  if (dataType == ffi::F32) {                       \
+    return impl<float>(__VA_ARGS__);                \
+  } else if (dataType == ffi::F64) {                \
+    return impl<double>(__VA_ARGS__);               \
+  } else if (dataType == ffi::C64) {                \
+    return impl<gpublasComplex>(__VA_ARGS__);       \
+  } else if (dataType == ffi::C128) {               \
+    return impl<gpublasDoubleComplex>(__VA_ARGS__); \
+  }
+
 // LU decomposition: getrf
 
 namespace {
@@ -189,8 +200,8 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
       ipiv->dimensions(), {batch, std::min(rows, cols)}, "ipiv", "getrf"));
   FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "getrf"));
   if (batch > 1 && rows == cols && rows / batch <= 128) {
-    SOLVER_DISPATCH_IMPL(GetrfBatchedImpl, batch, cols, stream, scratch, a, out,
-                         ipiv, info);
+    SOLVER_BLAS_DISPATCH_IMPL(GetrfBatchedImpl, batch, cols, stream, scratch, a,
+                              out, ipiv, info);
   } else {
     SOLVER_DISPATCH_IMPL(GetrfImpl, batch, rows, cols, stream, scratch, a, out,
                          ipiv, info);
@@ -345,8 +356,8 @@ ffi::Error GeqrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
   FFI_RETURN_IF_ERROR(CheckShape(
       tau->dimensions(), {batch, std::min(rows, cols)}, "tau", "geqrf"));
   if (batch > 1 && rows / batch <= 128 && cols / batch <= 128) {
-    SOLVER_DISPATCH_IMPL(GeqrfBatchedImpl, batch, rows, cols, stream, scratch,
-                         a, out, tau);
+    SOLVER_BLAS_DISPATCH_IMPL(GeqrfBatchedImpl, batch, rows, cols, stream,
+                              scratch, a, out, tau);
   } else {
     SOLVER_DISPATCH_IMPL(GeqrfImpl, batch, rows, cols, stream, scratch, a, out,
                          tau);

View File

@@ -26,6 +26,7 @@ limitations under the License.
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/str_format.h"
 #include "absl/synchronization/mutex.h"
+#include "jaxlib/absl_status_casters.h"
 #include "jaxlib/gpu/gpu_kernel_helpers.h"
 #include "jaxlib/gpu/sparse_kernels.h"
 #include "jaxlib/gpu/vendor.h"
@@ -57,12 +58,17 @@ gpusparseIndexType_t DtypeToCuSparseIndexType(const dtype& np_type) {
 gpuDataType DtypeToCudaDataType(const dtype& np_type) {
   static auto* types =
       new absl::flat_hash_map<std::pair<char, int>, gpuDataType>({
-          {{'f', 2}, GPU_R_16F}, {{'c', 4}, GPU_C_16F}, {{'f', 4}, GPU_R_32F},
-          {{'c', 8}, GPU_C_32F}, {{'f', 8}, GPU_R_64F},
+          {{'f', 2}, GPU_R_16F},
+          {{'c', 4}, GPU_C_16F},
+          {{'f', 4}, GPU_R_32F},
+          {{'c', 8}, GPU_C_32F},
+          {{'f', 8}, GPU_R_64F},
           {{'c', 16}, GPU_C_64F},
 #ifdef JAX_GPU_CUDA
-          {{'i', 1}, CUDA_R_8I}, {{'u', 1}, CUDA_R_8U},
-          {{'i', 4}, CUDA_R_32I}, {{'u', 4}, CUDA_R_32U},
+          {{'i', 1}, CUDA_R_8I},
+          {{'u', 1}, CUDA_R_8U},
+          {{'i', 4}, CUDA_R_32I},
+          {{'u', 4}, CUDA_R_32U},
 #if JAX_GPU_HAVE_SPARSE
           {{'V', 2}, CUDA_R_16BF},
 #endif  // JAX_GPU_HAVE_SPARSE
@@ -78,9 +84,8 @@ gpuDataType DtypeToCudaDataType(const dtype& np_type) {
 }
 
 // Returns the descriptor for a Sparse matrix.
 SparseMatDescriptor BuildSparseMatDescriptor(const dtype& data_dtype,
-                                             const dtype& index_dtype,
-                                             int rows, int cols, int nnz,
-                                             int batch_count,
+                                             const dtype& index_dtype, int rows,
+                                             int cols, int nnz, int batch_count,
                                              int batch_stride) {
   gpuDataType value_type = DtypeToCudaDataType(data_dtype);
   gpusparseIndexType_t index_type = DtypeToCuSparseIndexType(index_dtype);
@@ -89,16 +94,15 @@ SparseMatDescriptor BuildSparseMatDescriptor(const dtype& data_dtype,
 }
 
 // Returns the descriptor for a Dense matrix.
-DenseMatDescriptor BuildDenseMatDescriptor(const dtype& data_dtype,
-                                           int rows, int cols, int batch_count,
+DenseMatDescriptor BuildDenseMatDescriptor(const dtype& data_dtype, int rows,
+                                           int cols, int batch_count,
                                            int batch_stride) {
   gpuDataType value_type = DtypeToCudaDataType(data_dtype);
   return DenseMatDescriptor{value_type, rows, cols, batch_count, batch_stride};
 }
 
 // Returns the descriptor for a Dense vector.
-DenseVecDescriptor BuildDenseVecDescriptor(const dtype& data_dtype,
-                                           int size) {
+DenseVecDescriptor BuildDenseVecDescriptor(const dtype& data_dtype, int size) {
   gpuDataType value_type = DtypeToCudaDataType(data_dtype);
   return DenseVecDescriptor{value_type, size};
 }
@@ -107,9 +111,10 @@ DenseVecDescriptor BuildDenseVecDescriptor(const dtype& data_dtype,
 
 // CsrToDense: Convert CSR matrix to dense matrix
 
 // Returns the descriptor for a Sparse matrix.
-std::pair<size_t, nb::bytes> BuildCsrToDenseDescriptor(
-    const dtype& data_dtype, const dtype& index_dtype, int rows,
-    int cols, int nnz) {
+std::pair<size_t, nb::bytes> BuildCsrToDenseDescriptor(const dtype& data_dtype,
+                                                       const dtype& index_dtype,
+                                                       int rows, int cols,
+                                                       int nnz) {
   auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
@@ -185,8 +190,8 @@ void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque,
 
 // Returns the descriptor for a CsrFromDense operation.
 std::pair<size_t, nb::bytes> BuildCsrFromDenseDescriptor(
-    const dtype& data_dtype, const dtype& index_dtype, int rows,
-    int cols, int nnz) {
+    const dtype& data_dtype, const dtype& index_dtype, int rows, int cols,
+    int nnz) {
   auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
@@ -261,9 +266,8 @@ void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque,
 
 // Returns the descriptor for a CsrMatvec operation.
 std::pair<size_t, nb::bytes> BuildCsrMatvecDescriptor(
-    const dtype& data_dtype, const dtype& x_dtype,
-    const dtype& compute_dtype, const dtype& index_dtype, int rows,
-    int cols, int nnz, bool transpose) {
+    const dtype& data_dtype, const dtype& x_dtype, const dtype& compute_dtype,
+    const dtype& index_dtype, int rows, int cols, int nnz, bool transpose) {
   auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
@@ -292,7 +296,7 @@ std::pair<size_t, nb::bytes> BuildCsrMatvecDescriptor(
   JAX_THROW_IF_ERROR(
       JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
   size_t buffer_size;
-  SparseConst alpha = ConstOne(y.type);
+  SparseConst alpha = ValueOrThrow(ConstOne(y.type));
   SparseConst beta = ConstZero(y.type);
   JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize(
       handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
@@ -309,9 +313,9 @@ std::pair<size_t, nb::bytes> BuildCsrMatvecDescriptor(
 
 // Returns the descriptor for a CsrMatmat operation.
 std::pair<size_t, nb::bytes> BuildCsrMatmatDescriptor(
-    const dtype& data_dtype, const dtype& b_dtype,
-    const dtype& compute_dtype, const dtype& index_dtype, int rows,
-    int cols, int BCcols, int nnz, bool transpose) {
+    const dtype& data_dtype, const dtype& b_dtype, const dtype& compute_dtype,
+    const dtype& index_dtype, int rows, int cols, int BCcols, int nnz,
+    bool transpose) {
   auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
@@ -344,7 +348,7 @@ std::pair<size_t, nb::bytes> BuildCsrMatmatDescriptor(
       JAX_AS_STATUS(gpusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
                                          empty, C.type, GPUSPARSE_ORDER_ROW)));
   size_t buffer_size;
-  SparseConst alpha = ConstOne(C.type);
+  SparseConst alpha = ValueOrThrow(ConstOne(C.type));
   SparseConst beta = ConstZero(C.type);
   JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize(
       handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
@@ -360,9 +364,10 @@ std::pair<size_t, nb::bytes> BuildCsrMatmatDescriptor(
 
 // CooToDense: Convert COO matrix to dense matrix
 
 // Returns the descriptor for a CooToDense operation.
-std::pair<size_t, nb::bytes> BuildCooToDenseDescriptor(
-    const dtype& data_dtype, const dtype& index_dtype, int rows,
-    int cols, int nnz) {
+std::pair<size_t, nb::bytes> BuildCooToDenseDescriptor(const dtype& data_dtype,
+                                                       const dtype& index_dtype,
+                                                       int rows, int cols,
+                                                       int nnz) {
   auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
@@ -398,8 +403,8 @@ std::pair<size_t, nb::bytes> BuildCooToDenseDescriptor(
 
 // Returns the descriptor for a CooFromDense operation.
 std::pair<size_t, nb::bytes> BuildCooFromDenseDescriptor(
-    const dtype& data_dtype, const dtype& index_dtype, int rows,
-    int cols, int nnz) {
+    const dtype& data_dtype, const dtype& index_dtype, int rows, int cols,
+    int nnz) {
   auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
@@ -434,9 +439,8 @@ std::pair<size_t, nb::bytes> BuildCooFromDenseDescriptor(
 
 // Returns the descriptor for a CooMatvec operation.
 std::pair<size_t, nb::bytes> BuildCooMatvecDescriptor(
-    const dtype& data_dtype, const dtype& x_dtype,
-    const dtype& compute_dtype, const dtype& index_dtype, int rows,
-    int cols, int nnz, bool transpose) {
+    const dtype& data_dtype, const dtype& x_dtype, const dtype& compute_dtype,
+    const dtype& index_dtype, int rows, int cols, int nnz, bool transpose) {
   auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
@@ -465,7 +469,7 @@ std::pair<size_t, nb::bytes> BuildCooMatvecDescriptor(
   JAX_THROW_IF_ERROR(
       JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
   size_t buffer_size;
-  SparseConst alpha = ConstOne(y.type);
+  SparseConst alpha = ValueOrThrow(ConstOne(y.type));
   SparseConst beta = ConstZero(y.type);
   JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize(
       handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
@@ -482,10 +486,10 @@ std::pair<size_t, nb::bytes> BuildCooMatvecDescriptor(
 
 // Returns the descriptor for a CooMatmat operation.
 std::pair<size_t, nb::bytes> BuildCooMatmatDescriptor(
-    const dtype& data_dtype, const dtype& b_dtype,
-    const dtype& compute_dtype, const dtype& index_dtype, int rows,
-    int cols, int BCcols, int nnz, bool transpose, int batch_count,
-    int lhs_batch_stride, int rhs_batch_stride) {
+    const dtype& data_dtype, const dtype& b_dtype, const dtype& compute_dtype,
+    const dtype& index_dtype, int rows, int cols, int BCcols, int nnz,
+    bool transpose, int batch_count, int lhs_batch_stride,
+    int rhs_batch_stride) {
   // Three batch modes are supported, C_i = A_i B, C_i = A B_i, and
   // Ci = A_i B_i, where `i` denotes the batch dimension.
   // All three matrices A, B, and C must have the same batch count.
@@ -535,7 +539,7 @@ std::pair<size_t, nb::bytes> BuildCooMatmatDescriptor(
   JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDnMatSetStridedBatch(
       mat_c, /*batchCount=*/batch_count, /*batchStride=*/C.batch_stride)));
   size_t buffer_size;
-  SparseConst alpha = ConstOne(C.type);
+  SparseConst alpha = ValueOrThrow(ConstOne(C.type));
   SparseConst beta = ConstZero(C.type);
   JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize(
       handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,

View File

@@ -23,6 +23,7 @@ limitations under the License.
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
+#include "absl/synchronization/mutex.h"
 #include "jaxlib/gpu/gpu_kernel_helpers.h"
 #include "jaxlib/gpu/vendor.h"
@@ -58,7 +59,7 @@ SparseConst ConstZero(gpuDataType type) {
   return c;
 }
 
-SparseConst ConstOne(gpuDataType type) {
+absl::StatusOr<SparseConst> ConstOne(gpuDataType type) {
   SparseConst c;
   std::memset(&c, 0, sizeof(c));
   switch (type) {
@@ -138,6 +139,9 @@ SparseConst ConstOne(gpuDataType type) {
     case GPU_C_64F:
       c.f64[0] = 1.0;
       break;
+    default:
+      return absl::InvalidArgumentError(
+          absl::StrCat("Unsupported data type: ", type));
   }
   return c;
 }
@@ -248,7 +252,7 @@ static absl::Status CsrMatvec_(gpuStream_t stream, void** buffers,
   // are sufficient for basic matvec operations.
   // Note that, contrary to cusparse docs, alpha and beta must be host pointers
   // or else the operation will segfault.
-  SparseConst alpha = ConstOne(d.y.type);
+  JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.y.type));
   SparseConst beta = ConstZero(d.y.type);
 
   gpusparseSpMatDescr_t mat_a = 0;
@@ -305,7 +309,7 @@ static absl::Status CsrMatmat_(gpuStream_t stream, void** buffers,
   // are sufficient for basic matvec operations.
   // Note that, contrary to cusparse docs, alpha and beta must be host pointers
   // or else the operation will segfault.
-  SparseConst alpha = ConstOne(d.C.type);
+  JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.C.type));
   SparseConst beta = ConstZero(d.C.type);
 
   gpusparseSpMatDescr_t mat_a = 0;
@@ -446,7 +450,7 @@ static absl::Status CooMatvec_(gpuStream_t stream, void** buffers,
   // are sufficient for basic matvec operations.
   // Note that, contrary to cusparse docs, alpha and beta must be host pointers
   // or else the operation will segfault.
-  SparseConst alpha = ConstOne(d.y.type);
+  JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.y.type));
   SparseConst beta = ConstZero(d.y.type);
 
   gpusparseSpMatDescr_t mat_a = 0;
@@ -502,7 +506,7 @@ static absl::Status CooMatmat_(gpuStream_t stream, void** buffers,
   // are sufficient for basic matvec operations.
   // Note that, contrary to cusparse docs, alpha and beta must be host pointers
  // or else the operation will segfault.
-  SparseConst alpha = ConstOne(d.C.type);
+  JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.C.type));
   SparseConst beta = ConstZero(d.C.type);
 
   gpusparseSpMatDescr_t mat_a = 0;
@@ -581,8 +585,8 @@ static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers,
         gpuMemcpyAsync(X, B, B_bytes, gpuMemcpyDeviceToDevice, stream)));
   }
   for (int i = 0; i < batch; ++i) {
-    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(computeGtsv2(
-        handle.get(), m, n, dl, d, du, X, ldb, buffer)));
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
+        computeGtsv2(handle.get(), m, n, dl, d, du, X, ldb, buffer)));
     dl += m;
     d += m;
     du += m;
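
ConstOne becoming fallible shows the two absl::StatusOr consumption styles used in this commit: status-returning kernels unwrap with JAX_ASSIGN_OR_RETURN, while the throwing descriptor builders use ValueOrThrow from jaxlib/absl_status_casters.h. A minimal sketch of both styles, with OneFor and the Use* functions as hypothetical stand-ins:

#include <stdexcept>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// A fallible constant builder, analogous to ConstOne after this change.
absl::StatusOr<double> OneFor(int type_tag) {
  if (type_tag != 0) {
    return absl::InvalidArgumentError("unsupported data type");
  }
  return 1.0;
}

// Status-returning caller: propagate the error (jaxlib wraps this pattern
// in its JAX_ASSIGN_OR_RETURN macro).
absl::Status UseInStatusContext(int type_tag) {
  absl::StatusOr<double> alpha = OneFor(type_tag);
  if (!alpha.ok()) return alpha.status();
  // ... use *alpha ...
  return absl::OkStatus();
}

// Throwing caller: convert the error into an exception, as jaxlib's
// ValueOrThrow does for the nanobind-facing descriptor builders.
double UseInThrowingContext(int type_tag) {
  absl::StatusOr<double> alpha = OneFor(type_tag);
  if (!alpha.ok()) {
    throw std::runtime_error(std::string(alpha.status().message()));
  }
  return *alpha;
}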

View File

@@ -51,7 +51,7 @@ union SparseConst {
 };
 
 SparseConst ConstZero(gpuDataType type);
-SparseConst ConstOne(gpuDataType type);
+absl::StatusOr<SparseConst> ConstOne(gpuDataType type);
 
 struct SparseMatDescriptor {
   gpuDataType value_type;
View File

@@ -34,7 +34,11 @@
 
 #ifdef JAX_GPU_CUDA
 #include "xla/stream_executor/cuda/cuda_asm_compiler.h"
-#endif
+#endif  // JAX_GPU_CUDA
+
+#ifdef JAX_GPU_HIP
+#include "tsl/platform/env.h"
+#endif  // JAX_GPU_HIP
 
 #define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr))
@@ -44,7 +48,12 @@ namespace {
 constexpr float kBenchmarkTimeMillis = 10.;
 
 struct gpuModuleDeleter {
-  void operator()(gpuModule_t module) { gpuModuleUnload(module); }
+  void operator()(gpuModule_t module) {
+    absl::Status status = JAX_AS_STATUS(gpuModuleUnload(module));
+    if (!status.ok()) {
+      LOG(WARNING) << "Failed to unload GPU module: " << status;
+    }
+  }
 };
 
 using OwnedGPUmodule =
@@ -52,11 +61,11 @@ using OwnedGPUmodule =
 
 absl::StatusOr<gpuDevice_t> GetStreamDevice(gpuStream_t stream) {
   gpuDevice_t device;
-  gpuContext_t context;
 #ifdef JAX_GPU_HIP
   int device_id = gpuGetStreamDeviceId(stream);
   GPU_RETURN_IF_ERROR(gpuDeviceGet(&device, device_id));
 #else  // JAX_GPU_CUDA
+  gpuContext_t context;
   GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context));
   GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
   absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); };
@@ -210,7 +219,12 @@ class ModuleImage {
    }
 
     GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
-    absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); };
+    absl::Cleanup ctx_restorer = [] {
+      absl::Status status = JAX_AS_STATUS(gpuCtxPopCurrent(nullptr));
+      if (!status.ok()) {
+        LOG(WARNING) << "Failed to pop GPU context: " << status;
+      }
+    };
 
     gpuModule_t module;
     GPU_RETURN_IF_ERROR(gpuModuleLoadData(&module, module_image_.data()));
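
Both changes above apply the same rule: teardown code (a deleter, an absl::Cleanup) has no channel to propagate a status, so failures are logged rather than silently discarded. A small self-contained sketch of the pattern, with ReleaseHandle as a hypothetical stand-in for gpuModuleUnload or gpuCtxPopCurrent:

#include "absl/cleanup/cleanup.h"
#include "absl/log/log.h"
#include "absl/status/status.h"

// Hypothetical fallible teardown call.
absl::Status ReleaseHandle(int handle) {
  return handle >= 0 ? absl::OkStatus() : absl::InternalError("bad handle");
}

void Demo(int handle) {
  // Destructors and cleanups cannot return a status, so log on failure
  // instead of dropping the error on the floor.
  absl::Cleanup restore = [handle] {
    absl::Status status = ReleaseHandle(handle);
    if (!status.ok()) {
      LOG(WARNING) << "Failed to release handle: " << status;
    }
  };
  // ... work that relies on the handle ...
}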

View File

@@ -22,18 +22,20 @@ limitations under the License.
 
 #if defined(JAX_GPU_CUDA)
 
-#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cooperative_groups.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cuComplex.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cublas_v2.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cuda.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cuda_fp8.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cuda_runtime_api.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cufft.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cusolverDn.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cusolver_common.h"  // IWYU pragma: export
-#include "third_party/gpus/cuda/include/cusparse.h"  // IWYU pragma: export
-#include "third_party/gpus/cudnn/cudnn.h"  // IWYU pragma: export
+// IWYU pragma: begin_exports
+#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"
+#include "third_party/gpus/cuda/include/cooperative_groups.h"
+#include "third_party/gpus/cuda/include/cuComplex.h"
+#include "third_party/gpus/cuda/include/cublas_v2.h"
+#include "third_party/gpus/cuda/include/cuda.h"
+#include "third_party/gpus/cuda/include/cuda_fp8.h"
+#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
+#include "third_party/gpus/cuda/include/cufft.h"
+#include "third_party/gpus/cuda/include/cusolverDn.h"
+#include "third_party/gpus/cuda/include/cusolver_common.h"
+#include "third_party/gpus/cuda/include/cusparse.h"
+#include "third_party/gpus/cudnn/cudnn.h"
+// IWYU pragma: end_exports
 
 #if CUDA_VERSION < 11080
 #error "JAX requires CUDA 11.8 or newer."
@@ -305,11 +307,13 @@ constexpr uint32_t kNumThreadsPerWarp = 32;
 #elif defined(JAX_GPU_HIP)
 
 #include "rocm/include/hip/amd_detail/amd_hip_cooperative_groups.h"
+// IWYU pragma: begin_exports
 #include "rocm/include/hip/hip_cooperative_groups.h"
 #include "rocm/include/hip/hip_runtime_api.h"
 #include "rocm/include/hipblas/hipblas.h"
 #include "rocm/include/hipsolver/hipsolver.h"
 #include "rocm/include/hipsparse/hipsparse.h"
+// IWYU pragma: end_exports
 
 #define JAX_GPU_NAMESPACE hip
 #define JAX_GPU_PREFIX "hip"

View File

@@ -14,6 +14,7 @@
 # AMD HIP kernels
 
+load("@rules_python//python:defs.bzl", "py_library")
 load(
     "//jaxlib:jax.bzl",
     "if_rocm_is_configured",
@@ -23,7 +24,10 @@ load(
 
 licenses(["notice"])
 
-package(default_visibility = ["//:__subpackages__"])
+package(
+    default_applicable_licenses = [],
+    default_visibility = ["//:__subpackages__"],
+)
 
 cc_library(
     name = "hip_vendor",
@@ -203,6 +207,7 @@ pybind_extension(
         ":hipsolver_kernels_ffi",
         "//jaxlib:kernel_nanobind_helpers",
        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:str_format",
         "@local_config_rocm//rocm:hipblas",
         "@local_config_rocm//rocm:hipsolver",
@@ -223,6 +228,7 @@ cc_library(
         "//jaxlib:kernel_helpers",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@local_config_rocm//rocm:hipsparse",
         "@local_config_rocm//rocm:rocm_headers",
@@ -243,6 +249,7 @@ pybind_extension(
         ":hip_gpu_kernel_helpers",
         ":hip_vendor",
         ":hipsparse_kernels",
+        "//jaxlib:absl_status_casters",
         "//jaxlib:kernel_nanobind_helpers",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base",
@@ -277,6 +284,7 @@ cc_library(
         "@local_config_rocm//rocm:rocm_headers",
         "@xla//xla/ffi/api:c_api",
         "@xla//xla/ffi/api:ffi",
+        "@xla//xla/service:custom_call_status",
     ],
 )
@@ -376,14 +384,15 @@ cc_library(
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/cleanup",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/synchronization",
+        "@tsl//tsl/platform:env",
         "@xla//xla/service:custom_call_status",
-        "@xla//xla/stream_executor/gpu:asm_compiler",
         "@xla//xla/tsl/util:env_var",
     ],
 )

View File

@@ -142,8 +142,8 @@ NB_MODULE(rocm_plugin_extension, m) {
               HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
               reinterpret_cast<hipDeviceptr_t>(data_ptr));
           if (result != hipSuccess) {
-            LOG(FATAL) << "Not able to get the device_ordinal for ptr: " << data_ptr
-                       << ". Error: " << ToString(result);
+            LOG(FATAL) << "Not able to get the device_ordinal for ptr: "
+                       << data_ptr << ". Error: " << ToString(result);
           }
           return device_ordinal;
         },