Update Eigh kernel on GPU to use 64-bit interface when it is available.

Part of https://github.com/jax-ml/jax/issues/23413

PiperOrigin-RevId: 684546802
This commit is contained in:
Dan Foreman-Mackey 2024-10-10 12:58:53 -07:00 committed by jax authors
parent 9d44d72339
commit 6625a2b3ed
2 changed files with 183 additions and 81 deletions

View File

@ -21,6 +21,14 @@ limitations under the License.
#include <optional>
#include <string_view>
#if JAX_GPU_HAVE_64_BIT
#include <cstddef>
#endif
#ifdef JAX_GPU_CUDA
#include <limits>
#endif
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
@ -33,14 +41,6 @@ limitations under the License.
#include "jaxlib/gpu/vendor.h"
#include "xla/ffi/api/ffi.h"
#if JAX_GPU_64_BIT
#include <cstddef>
#endif
#ifdef JAX_GPU_CUDA
#include <limits>
#endif
#define JAX_FFI_RETURN_IF_GPU_ERROR(...) \
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__))
@ -64,6 +64,28 @@ inline absl::StatusOr<T*> AllocateWorkspace(ffi::ScratchAllocator& scratch,
return static_cast<T*>(maybe_workspace.value());
}
#if JAX_GPU_HAVE_64_BIT
// Translates the element type of an FFI buffer into the corresponding
// gpuDataType enumerator understood by the 64-bit GPU solver API.
//
// Only F32, F64, C64 and C128 are handled; any other dtype yields an
// InvalidArgument status. `func` is used solely to label that error message
// with the name of the calling kernel.
inline absl::StatusOr<gpuDataType> SolverDataType(ffi::DataType dataType,
                                                  std::string_view func) {
  if (dataType == ffi::F32) {
    return GPU_R_32F;
  }
  if (dataType == ffi::F64) {
    return GPU_R_64F;
  }
  if (dataType == ffi::C64) {
    return GPU_C_32F;
  }
  if (dataType == ffi::C128) {
    return GPU_C_64F;
  }
  return absl::InvalidArgumentError(absl::StrFormat(
      "Unsupported dtype %s in %s", absl::FormatStreamed(dataType), func));
}
#endif
#define SOLVER_DISPATCH_IMPL(impl, ...) \
switch (dataType) { \
case ffi::F32: \
@ -392,11 +414,74 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch,
// dispatches dynamically to both syevd and syevj depending on the problem
// size and the algorithm selected by the user via the `algorithm` attribute.
#if JAX_GPU_HAVE_64_BIT
// Eigendecomposition kernel built on the 64-bit solver entry points
// (gpusolverDnXsyevd), so the matrix dimension `n` is passed through as an
// int64_t instead of being narrowed to `int`.
//
// Arguments:
//   batch   - number of n-by-n matrices stored back to back in `a` / `out`.
//   n       - matrix dimension (also used as the leading dimension).
//   stream  - stream used for the solver handle and the input copy.
//   scratch - allocator for the device-side workspace.
//   lower   - selects GPUSOLVER_FILL_MODE_LOWER, otherwise ..._UPPER.
//   a       - input matrices; copied into `out` unless they alias.
//   out     - solved in place by the solver, one matrix at a time.
//   w       - per-matrix output vector of `n` real values.
//   info    - one int32 status slot per matrix, filled in by the solver.
//
// Returns the first GPU/solver error encountered, kResourceExhausted when the
// device workspace cannot be allocated, and Success otherwise.
ffi::Error Syevd64Impl(int64_t batch, int64_t n, gpuStream_t stream,
                       ffi::ScratchAllocator& scratch, bool lower,
                       ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
                       ffi::Result<ffi::AnyBuffer> w,
                       ffi::Result<ffi::Buffer<ffi::S32>> info) {
  FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
  auto dataType = a.element_type();
  FFI_ASSIGN_OR_RETURN(auto aType, SolverDataType(dataType, "syevd"));
  FFI_ASSIGN_OR_RETURN(auto wType, SolverDataType(w->element_type(), "syevd"));

  gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR;
  gpusolverFillMode_t uplo =
      lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;

  // The params object is destroyed by `params_cleanup` on every return path,
  // including the early returns hidden inside the *_RETURN_IF_* macros below.
  gpusolverDnParams_t params;
  JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateParams(&params));
  std::unique_ptr<gpusolverDnParams, void (*)(gpusolverDnParams_t)>
      params_cleanup(
          params, [](gpusolverDnParams_t p) { gpusolverDnDestroyParams(p); });

  // Query both workspace sizes once; the same buffers are reused for every
  // matrix in the batch since all of them share `n` and the dtypes.
  size_t workspaceInBytesOnDevice, workspaceInBytesOnHost;
  JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd_bufferSize(
      handle.get(), params, jobz, uplo, n, aType, /*a=*/nullptr, n, wType,
      /*w=*/nullptr, aType, &workspaceInBytesOnDevice,
      &workspaceInBytesOnHost));

  auto maybe_workspace = scratch.Allocate(workspaceInBytesOnDevice);
  if (!maybe_workspace.has_value()) {
    return ffi::Error(ffi::ErrorCode::kResourceExhausted,
                      "Unable to allocate device workspace for syevd");
  }
  auto workspaceOnDevice = maybe_workspace.value();
  auto workspaceOnHost =
      std::unique_ptr<char[]>(new char[workspaceInBytesOnHost]);

  // Buffers are walked as raw bytes because the element type is only known at
  // run time; the per-matrix strides below are therefore expressed in bytes.
  const char* a_data = static_cast<const char*>(a.untyped_data());
  char* out_data = static_cast<char*>(out->untyped_data());
  char* w_data = static_cast<char*>(w->untyped_data());
  int* info_data = info->typed_data();
  if (a_data != out_data) {
    JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
  }

  size_t out_step = n * n * ffi::ByteWidth(dataType);
  size_t w_step = n * ffi::ByteWidth(ffi::ToReal(dataType));

  // The counter is int64_t to match `batch`; an `int` counter would overflow
  // (undefined behaviour) for batch counts above INT_MAX.
  for (int64_t i = 0; i < batch; ++i) {
    JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd(
        handle.get(), params, jobz, uplo, n, aType, out_data, n, wType, w_data,
        aType, workspaceOnDevice, workspaceInBytesOnDevice,
        workspaceOnHost.get(), workspaceInBytesOnHost, info_data));
    out_data += out_step;
    w_data += w_step;
    ++info_data;
  }
  return ffi::Error::Success();
}
#endif
template <typename T>
ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream,
ffi::ScratchAllocator& scratch, SyevdAlgorithm algorithm,
bool lower, ffi::AnyBuffer a,
ffi::Result<ffi::AnyBuffer> out,
ffi::ScratchAllocator& scratch, bool lower,
ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
ffi::Result<ffi::AnyBuffer> w,
ffi::Result<ffi::Buffer<ffi::S32>> info) {
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(size));
@ -408,59 +493,84 @@ ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream,
auto a_data = static_cast<T*>(a.untyped_data());
auto out_data = static_cast<T*>(out->untyped_data());
auto w_data = static_cast<typename solver::RealType<T>::value*>(w->untyped_data());
auto w_data =
static_cast<typename solver::RealType<T>::value*>(w->untyped_data());
auto info_data = info->typed_data();
if (a_data != out_data) {
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
}
if (algorithm == SyevdAlgorithm::kJacobi ||
(algorithm == SyevdAlgorithm::kDefault && size <= 32)) {
gpuSyevjInfo_t params;
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateSyevjInfo(&params));
std::unique_ptr<gpuSyevjInfo, void (*)(gpuSyevjInfo_t)> params_cleanup(
params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); });
if (batch == 1) {
FFI_ASSIGN_OR_RETURN(int lwork, solver::SyevjBufferSize<T>(
handle.get(), jobz, uplo, n, params));
FFI_ASSIGN_OR_RETURN(auto workspace,
AllocateWorkspace<T>(scratch, lwork, "syevj"));
FFI_RETURN_IF_ERROR_STATUS(solver::Syevj<T>(handle.get(), jobz, uplo, n,
out_data, w_data, workspace,
lwork, info_data, params));
} else {
FFI_ASSIGN_OR_RETURN(
int lwork, solver::SyevjBatchedBufferSize<T>(handle.get(), jobz, uplo,
n, params, batch));
FFI_ASSIGN_OR_RETURN(
auto workspace,
AllocateWorkspace<T>(scratch, lwork, "syevj_batched"));
FFI_RETURN_IF_ERROR_STATUS(
solver::SyevjBatched<T>(handle.get(), jobz, uplo, n, out_data, w_data,
workspace, lwork, info_data, params, batch));
}
FFI_ASSIGN_OR_RETURN(int lwork,
solver::SyevdBufferSize<T>(handle.get(), jobz, uplo, n));
FFI_ASSIGN_OR_RETURN(auto workspace,
AllocateWorkspace<T>(scratch, lwork, "syevd"));
int out_step = n * n;
for (auto i = 0; i < batch; ++i) {
FFI_RETURN_IF_ERROR_STATUS(solver::Syevd<T>(handle.get(), jobz, uplo, n,
out_data, w_data, workspace,
lwork, info_data));
out_data += out_step;
w_data += n;
++info_data;
}
return ffi::Error::Success();
}
// Eigendecomposition kernel using the Jacobi solvers (syevj / syevjBatched).
//
// `a` is copied into `out` unless the two buffers alias, and the solver then
// works in place on `out`, writing its real-valued results to `w` and one
// int32 status per matrix to `info`. A single matrix (batch == 1) goes through
// solver::Syevj; every other batch size is handed to solver::SyevjBatched in a
// single call.
//
// `size` is narrowed to `int` because these solver wrappers take 32-bit
// dimensions; an error is returned if it does not fit.
template <typename T>
ffi::Error SyevdjImpl(int64_t batch, int64_t size, gpuStream_t stream,
                      ffi::ScratchAllocator& scratch, bool lower,
                      ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
                      ffi::Result<ffi::AnyBuffer> w,
                      ffi::Result<ffi::Buffer<ffi::S32>> info) {
  FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(size));
  FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));

  gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR;
  gpusolverFillMode_t uplo =
      lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;

  auto a_data = static_cast<T*>(a.untyped_data());
  auto out_data = static_cast<T*>(out->untyped_data());
  auto w_data =
      static_cast<typename solver::RealType<T>::value*>(w->untyped_data());
  auto info_data = info->typed_data();
  if (a_data != out_data) {
    JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
  }

  // The syevj info object is released by `params_cleanup` on every return
  // path, including the early returns inside the macros below.
  gpuSyevjInfo_t params;
  JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateSyevjInfo(&params));
  std::unique_ptr<gpuSyevjInfo, void (*)(gpuSyevjInfo_t)> params_cleanup(
      params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); });

  if (batch == 1) {
    FFI_ASSIGN_OR_RETURN(int lwork, solver::SyevjBufferSize<T>(
                                        handle.get(), jobz, uplo, n, params));
    FFI_ASSIGN_OR_RETURN(auto workspace,
                         AllocateWorkspace<T>(scratch, lwork, "syevj"));
    FFI_RETURN_IF_ERROR_STATUS(solver::Syevj<T>(handle.get(), jobz, uplo, n,
                                                out_data, w_data, workspace,
                                                lwork, info_data, params));
  } else {
    // Only the batched Jacobi path belongs here; the per-matrix syevd loop
    // lives in SyevdImpl / Syevd64Impl.
    FFI_ASSIGN_OR_RETURN(
        int lwork, solver::SyevjBatchedBufferSize<T>(handle.get(), jobz, uplo,
                                                     n, params, batch));
    FFI_ASSIGN_OR_RETURN(
        auto workspace,
        AllocateWorkspace<T>(scratch, lwork, "syevj_batched"));
    FFI_RETURN_IF_ERROR_STATUS(
        solver::SyevjBatched<T>(handle.get(), jobz, uplo, n, out_data, w_data,
                                workspace, lwork, info_data, params, batch));
  }
  return ffi::Error::Success();
}
ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
SyevdAlgorithm algorithm, bool lower,
ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
SyevdAlgorithm algorithm, bool lower, ffi::AnyBuffer a,
ffi::Result<ffi::AnyBuffer> out,
ffi::Result<ffi::AnyBuffer> w,
ffi::Result<ffi::Buffer<ffi::S32>> info) {
auto dataType = a.element_type();
@ -479,8 +589,18 @@ ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
CheckShape(out->dimensions(), {batch, rows, cols}, "out", "syevd"));
FFI_RETURN_IF_ERROR(CheckShape(w->dimensions(), {batch, cols}, "w", "syevd"));
FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "syevd"));
SOLVER_DISPATCH_IMPL(SyevdImpl, batch, cols, stream, scratch, algorithm,
lower, a, out, w, info);
if (algorithm == SyevdAlgorithm::kJacobi ||
(algorithm == SyevdAlgorithm::kDefault && cols <= 32)) {
SOLVER_DISPATCH_IMPL(SyevdjImpl, batch, cols, stream, scratch, lower, a,
out, w, info);
} else {
#if JAX_GPU_HAVE_64_BIT
return Syevd64Impl(batch, cols, stream, scratch, lower, a, out, w, info);
#else
SOLVER_DISPATCH_IMPL(SyevdImpl, batch, cols, stream, scratch, lower, a, out,
w, info);
#endif
}
return ffi::Error::InvalidArgument(absl::StrFormat(
"Unsupported dtype %s in syevd", absl::FormatStreamed(dataType)));
}
@ -577,7 +697,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(SyrkFfi, SyrkDispatch,
// Singular Value Decomposition: gesvd
#if JAX_GPU_64_BIT
#if JAX_GPU_HAVE_64_BIT
ffi::Error Gesvd64Impl(int64_t batch, int64_t m, int64_t n, gpuStream_t stream,
ffi::ScratchAllocator& scratch, bool full_matrices,
@ -589,30 +709,9 @@ ffi::Error Gesvd64Impl(int64_t batch, int64_t m, int64_t n, gpuStream_t stream,
ffi::Result<ffi::Buffer<ffi::S32>> info) {
FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
signed char job = compute_uv ? (full_matrices ? 'A' : 'S') : 'N';
auto dataType = a.element_type();
gpuDataType aType, sType;
switch (dataType) {
case ffi::F32:
aType = GPU_R_32F;
sType = GPU_R_32F;
break;
case ffi::F64:
aType = GPU_R_64F;
sType = GPU_R_64F;
break;
case ffi::C64:
aType = GPU_C_32F;
sType = GPU_R_32F;
break;
case ffi::C128:
aType = GPU_C_64F;
sType = GPU_R_64F;
break;
default:
return ffi::Error::InvalidArgument(absl::StrFormat(
"Unsupported dtype %s in gesvd", absl::FormatStreamed(dataType)));
}
// The label passed to SolverDataType only appears in the "Unsupported dtype"
// error text, so it must name this kernel (gesvd), not syevd.
FFI_ASSIGN_OR_RETURN(auto aType, SolverDataType(dataType, "gesvd"));
FFI_ASSIGN_OR_RETURN(auto sType, SolverDataType(s->element_type(), "gesvd"));
gpusolverDnParams_t params;
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateParams(&params));
@ -692,7 +791,8 @@ ffi::Error GesvdImpl(int64_t batch, int64_t rows, int64_t cols,
AllocateWorkspace<T>(scratch, lwork, "gesvd"));
auto a_data = static_cast<T*>(a.untyped_data());
auto out_data = static_cast<T*>(out->untyped_data());
auto s_data = static_cast<typename solver::RealType<T>::value*>(s->untyped_data());
auto s_data =
static_cast<typename solver::RealType<T>::value*>(s->untyped_data());
auto u_data = compute_uv ? static_cast<T*>(u->untyped_data()) : nullptr;
auto vt_data = compute_uv ? static_cast<T*>(vt->untyped_data()) : nullptr;
auto info_data = info->typed_data();
@ -717,7 +817,7 @@ ffi::Error GesvdImpl(int64_t batch, int64_t rows, int64_t cols,
return ffi::Error::Success();
}
#endif // JAX_GPU_64_BIT
#endif // JAX_GPU_HAVE_64_BIT
ffi::Error GesvdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
bool full_matrices, bool compute_uv, bool transposed,
@ -763,7 +863,7 @@ ffi::Error GesvdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
}
FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "gesvd"));
#if JAX_GPU_64_BIT
#if JAX_GPU_HAVE_64_BIT
return Gesvd64Impl(batch, m, n, stream, scratch, full_matrices, compute_uv, a,
out, s, u, vt, info);
#else

View File

@ -332,7 +332,7 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
#define gpuGetDeviceProperties cudaGetDeviceProperties
#define gpuLaunchCooperativeKernel cudaLaunchCooperativeKernel
#define JAX_GPU_64_BIT 1
#define JAX_GPU_HAVE_64_BIT 1
#define GPU_R_32F CUDA_R_32F
#define GPU_R_64F CUDA_R_64F
@ -345,6 +345,8 @@ typedef cusolverDnParams_t gpusolverDnParams_t;
#define gpusolverDnCreateParams cusolverDnCreateParams
#define gpusolverDnDestroyParams cusolverDnDestroyParams
#define gpusolverDnXsyevd_bufferSize cusolverDnXsyevd_bufferSize
#define gpusolverDnXsyevd cusolverDnXsyevd
#define gpusolverDnXgesvd_bufferSize cusolverDnXgesvd_bufferSize
#define gpusolverDnXgesvd cusolverDnXgesvd
@ -368,7 +370,7 @@ constexpr uint32_t kNumThreadsPerWarp = 32;
#define JAX_GPU_PREFIX "hip"
#define JAX_GPU_HAVE_SPARSE 1
#define JAX_GPU_64_BIT 0
#define JAX_GPU_HAVE_64_BIT 0
#define JAX_GPU_HAVE_FP8 0
typedef hipFloatComplex gpuComplex;