mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Update Eigh kernel on GPU to use 64-bit interface when it is available.
Part of https://github.com/jax-ml/jax/issues/23413 PiperOrigin-RevId: 684546802
This commit is contained in:
parent
9d44d72339
commit
6625a2b3ed
@ -21,6 +21,14 @@ limitations under the License.
|
||||
#include <optional>
|
||||
#include <string_view>
|
||||
|
||||
#if JAX_GPU_HAVE_64_BIT
|
||||
#include <cstddef>
|
||||
#endif
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
#include <limits>
|
||||
#endif
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
@ -33,14 +41,6 @@ limitations under the License.
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "xla/ffi/api/ffi.h"
|
||||
|
||||
#if JAX_GPU_64_BIT
|
||||
#include <cstddef>
|
||||
#endif
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
#include <limits>
|
||||
#endif
|
||||
|
||||
#define JAX_FFI_RETURN_IF_GPU_ERROR(...) \
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__))
|
||||
|
||||
@ -64,6 +64,28 @@ inline absl::StatusOr<T*> AllocateWorkspace(ffi::ScratchAllocator& scratch,
|
||||
return static_cast<T*>(maybe_workspace.value());
|
||||
}
|
||||
|
||||
#if JAX_GPU_HAVE_64_BIT
|
||||
|
||||
// Map an FFI buffer element type to the appropriate GPU solver type.
|
||||
inline absl::StatusOr<gpuDataType> SolverDataType(ffi::DataType dataType,
|
||||
std::string_view func) {
|
||||
switch (dataType) {
|
||||
case ffi::F32:
|
||||
return GPU_R_32F;
|
||||
case ffi::F64:
|
||||
return GPU_R_64F;
|
||||
case ffi::C64:
|
||||
return GPU_C_32F;
|
||||
case ffi::C128:
|
||||
return GPU_C_64F;
|
||||
default:
|
||||
return absl::InvalidArgumentError(absl::StrFormat(
|
||||
"Unsupported dtype %s in %s", absl::FormatStreamed(dataType), func));
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#define SOLVER_DISPATCH_IMPL(impl, ...) \
|
||||
switch (dataType) { \
|
||||
case ffi::F32: \
|
||||
@ -392,11 +414,74 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch,
|
||||
// dispatches dynamically to both syevd and syevj depending on the problem
|
||||
// size and the algorithm selected by the user via the `algorithm` attribute.
|
||||
|
||||
#if JAX_GPU_HAVE_64_BIT
|
||||
|
||||
ffi::Error Syevd64Impl(int64_t batch, int64_t n, gpuStream_t stream,
|
||||
ffi::ScratchAllocator& scratch, bool lower,
|
||||
ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
|
||||
ffi::Result<ffi::AnyBuffer> w,
|
||||
ffi::Result<ffi::Buffer<ffi::S32>> info) {
|
||||
FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
|
||||
auto dataType = a.element_type();
|
||||
FFI_ASSIGN_OR_RETURN(auto aType, SolverDataType(dataType, "syevd"));
|
||||
FFI_ASSIGN_OR_RETURN(auto wType, SolverDataType(w->element_type(), "syevd"));
|
||||
|
||||
gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR;
|
||||
gpusolverFillMode_t uplo =
|
||||
lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;
|
||||
|
||||
gpusolverDnParams_t params;
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateParams(¶ms));
|
||||
std::unique_ptr<gpusolverDnParams, void (*)(gpusolverDnParams_t)>
|
||||
params_cleanup(
|
||||
params, [](gpusolverDnParams_t p) { gpusolverDnDestroyParams(p); });
|
||||
|
||||
size_t workspaceInBytesOnDevice, workspaceInBytesOnHost;
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd_bufferSize(
|
||||
handle.get(), params, jobz, uplo, n, aType, /*a=*/nullptr, n, wType,
|
||||
/*w=*/nullptr, aType, &workspaceInBytesOnDevice,
|
||||
&workspaceInBytesOnHost));
|
||||
|
||||
auto maybe_workspace = scratch.Allocate(workspaceInBytesOnDevice);
|
||||
if (!maybe_workspace.has_value()) {
|
||||
return ffi::Error(ffi::ErrorCode::kResourceExhausted,
|
||||
"Unable to allocate device workspace for syevd");
|
||||
}
|
||||
auto workspaceOnDevice = maybe_workspace.value();
|
||||
auto workspaceOnHost =
|
||||
std::unique_ptr<char[]>(new char[workspaceInBytesOnHost]);
|
||||
|
||||
const char* a_data = static_cast<const char*>(a.untyped_data());
|
||||
char* out_data = static_cast<char*>(out->untyped_data());
|
||||
char* w_data = static_cast<char*>(w->untyped_data());
|
||||
int* info_data = info->typed_data();
|
||||
if (a_data != out_data) {
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
|
||||
size_t out_step = n * n * ffi::ByteWidth(dataType);
|
||||
size_t w_step = n * ffi::ByteWidth(ffi::ToReal(dataType));
|
||||
|
||||
for (auto i = 0; i < batch; ++i) {
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd(
|
||||
handle.get(), params, jobz, uplo, n, aType, out_data, n, wType, w_data,
|
||||
aType, workspaceOnDevice, workspaceInBytesOnDevice,
|
||||
workspaceOnHost.get(), workspaceInBytesOnHost, info_data));
|
||||
out_data += out_step;
|
||||
w_data += w_step;
|
||||
++info_data;
|
||||
}
|
||||
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream,
|
||||
ffi::ScratchAllocator& scratch, SyevdAlgorithm algorithm,
|
||||
bool lower, ffi::AnyBuffer a,
|
||||
ffi::Result<ffi::AnyBuffer> out,
|
||||
ffi::ScratchAllocator& scratch, bool lower,
|
||||
ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
|
||||
ffi::Result<ffi::AnyBuffer> w,
|
||||
ffi::Result<ffi::Buffer<ffi::S32>> info) {
|
||||
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(size));
|
||||
@ -408,59 +493,84 @@ ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream,
|
||||
|
||||
auto a_data = static_cast<T*>(a.untyped_data());
|
||||
auto out_data = static_cast<T*>(out->untyped_data());
|
||||
auto w_data = static_cast<typename solver::RealType<T>::value*>(w->untyped_data());
|
||||
auto w_data =
|
||||
static_cast<typename solver::RealType<T>::value*>(w->untyped_data());
|
||||
auto info_data = info->typed_data();
|
||||
if (a_data != out_data) {
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
if (algorithm == SyevdAlgorithm::kJacobi ||
|
||||
(algorithm == SyevdAlgorithm::kDefault && size <= 32)) {
|
||||
gpuSyevjInfo_t params;
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateSyevjInfo(¶ms));
|
||||
std::unique_ptr<gpuSyevjInfo, void (*)(gpuSyevjInfo_t)> params_cleanup(
|
||||
params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); });
|
||||
|
||||
if (batch == 1) {
|
||||
FFI_ASSIGN_OR_RETURN(int lwork, solver::SyevjBufferSize<T>(
|
||||
handle.get(), jobz, uplo, n, params));
|
||||
FFI_ASSIGN_OR_RETURN(auto workspace,
|
||||
AllocateWorkspace<T>(scratch, lwork, "syevj"));
|
||||
FFI_RETURN_IF_ERROR_STATUS(solver::Syevj<T>(handle.get(), jobz, uplo, n,
|
||||
out_data, w_data, workspace,
|
||||
lwork, info_data, params));
|
||||
} else {
|
||||
FFI_ASSIGN_OR_RETURN(
|
||||
int lwork, solver::SyevjBatchedBufferSize<T>(handle.get(), jobz, uplo,
|
||||
n, params, batch));
|
||||
FFI_ASSIGN_OR_RETURN(
|
||||
auto workspace,
|
||||
AllocateWorkspace<T>(scratch, lwork, "syevj_batched"));
|
||||
FFI_RETURN_IF_ERROR_STATUS(
|
||||
solver::SyevjBatched<T>(handle.get(), jobz, uplo, n, out_data, w_data,
|
||||
workspace, lwork, info_data, params, batch));
|
||||
}
|
||||
FFI_ASSIGN_OR_RETURN(int lwork,
|
||||
solver::SyevdBufferSize<T>(handle.get(), jobz, uplo, n));
|
||||
FFI_ASSIGN_OR_RETURN(auto workspace,
|
||||
AllocateWorkspace<T>(scratch, lwork, "syevd"));
|
||||
int out_step = n * n;
|
||||
for (auto i = 0; i < batch; ++i) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(solver::Syevd<T>(handle.get(), jobz, uplo, n,
|
||||
out_data, w_data, workspace,
|
||||
lwork, info_data));
|
||||
out_data += out_step;
|
||||
w_data += n;
|
||||
++info_data;
|
||||
}
|
||||
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ffi::Error SyevdjImpl(int64_t batch, int64_t size, gpuStream_t stream,
|
||||
ffi::ScratchAllocator& scratch, bool lower,
|
||||
ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
|
||||
ffi::Result<ffi::AnyBuffer> w,
|
||||
ffi::Result<ffi::Buffer<ffi::S32>> info) {
|
||||
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(size));
|
||||
FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
|
||||
|
||||
gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR;
|
||||
gpusolverFillMode_t uplo =
|
||||
lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;
|
||||
|
||||
auto a_data = static_cast<T*>(a.untyped_data());
|
||||
auto out_data = static_cast<T*>(out->untyped_data());
|
||||
auto w_data =
|
||||
static_cast<typename solver::RealType<T>::value*>(w->untyped_data());
|
||||
auto info_data = info->typed_data();
|
||||
if (a_data != out_data) {
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
|
||||
gpuSyevjInfo_t params;
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateSyevjInfo(¶ms));
|
||||
std::unique_ptr<gpuSyevjInfo, void (*)(gpuSyevjInfo_t)> params_cleanup(
|
||||
params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); });
|
||||
|
||||
if (batch == 1) {
|
||||
FFI_ASSIGN_OR_RETURN(int lwork, solver::SyevjBufferSize<T>(
|
||||
handle.get(), jobz, uplo, n, params));
|
||||
FFI_ASSIGN_OR_RETURN(auto workspace,
|
||||
AllocateWorkspace<T>(scratch, lwork, "syevj"));
|
||||
FFI_RETURN_IF_ERROR_STATUS(solver::Syevj<T>(handle.get(), jobz, uplo, n,
|
||||
out_data, w_data, workspace,
|
||||
lwork, info_data, params));
|
||||
} else {
|
||||
FFI_ASSIGN_OR_RETURN(
|
||||
int lwork, solver::SyevdBufferSize<T>(handle.get(), jobz, uplo, n));
|
||||
int lwork, solver::SyevjBatchedBufferSize<T>(handle.get(), jobz, uplo,
|
||||
n, params, batch));
|
||||
FFI_ASSIGN_OR_RETURN(auto workspace,
|
||||
AllocateWorkspace<T>(scratch, lwork, "syevd"));
|
||||
int out_step = n * n;
|
||||
for (auto i = 0; i < batch; ++i) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(solver::Syevd<T>(handle.get(), jobz, uplo, n,
|
||||
out_data, w_data, workspace,
|
||||
lwork, info_data));
|
||||
out_data += out_step;
|
||||
w_data += n;
|
||||
++info_data;
|
||||
}
|
||||
AllocateWorkspace<T>(scratch, lwork, "syevj_batched"));
|
||||
FFI_RETURN_IF_ERROR_STATUS(
|
||||
solver::SyevjBatched<T>(handle.get(), jobz, uplo, n, out_data, w_data,
|
||||
workspace, lwork, info_data, params, batch));
|
||||
}
|
||||
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
|
||||
SyevdAlgorithm algorithm, bool lower,
|
||||
ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
|
||||
SyevdAlgorithm algorithm, bool lower, ffi::AnyBuffer a,
|
||||
ffi::Result<ffi::AnyBuffer> out,
|
||||
ffi::Result<ffi::AnyBuffer> w,
|
||||
ffi::Result<ffi::Buffer<ffi::S32>> info) {
|
||||
auto dataType = a.element_type();
|
||||
@ -479,8 +589,18 @@ ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
|
||||
CheckShape(out->dimensions(), {batch, rows, cols}, "out", "syevd"));
|
||||
FFI_RETURN_IF_ERROR(CheckShape(w->dimensions(), {batch, cols}, "w", "syevd"));
|
||||
FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "syevd"));
|
||||
SOLVER_DISPATCH_IMPL(SyevdImpl, batch, cols, stream, scratch, algorithm,
|
||||
lower, a, out, w, info);
|
||||
if (algorithm == SyevdAlgorithm::kJacobi ||
|
||||
(algorithm == SyevdAlgorithm::kDefault && cols <= 32)) {
|
||||
SOLVER_DISPATCH_IMPL(SyevdjImpl, batch, cols, stream, scratch, lower, a,
|
||||
out, w, info);
|
||||
} else {
|
||||
#if JAX_GPU_HAVE_64_BIT
|
||||
return Syevd64Impl(batch, cols, stream, scratch, lower, a, out, w, info);
|
||||
#else
|
||||
SOLVER_DISPATCH_IMPL(SyevdImpl, batch, cols, stream, scratch, lower, a, out,
|
||||
w, info);
|
||||
#endif
|
||||
}
|
||||
return ffi::Error::InvalidArgument(absl::StrFormat(
|
||||
"Unsupported dtype %s in syevd", absl::FormatStreamed(dataType)));
|
||||
}
|
||||
@ -577,7 +697,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(SyrkFfi, SyrkDispatch,
|
||||
|
||||
// Singular Value Decomposition: gesvd
|
||||
|
||||
#if JAX_GPU_64_BIT
|
||||
#if JAX_GPU_HAVE_64_BIT
|
||||
|
||||
ffi::Error Gesvd64Impl(int64_t batch, int64_t m, int64_t n, gpuStream_t stream,
|
||||
ffi::ScratchAllocator& scratch, bool full_matrices,
|
||||
@ -589,30 +709,9 @@ ffi::Error Gesvd64Impl(int64_t batch, int64_t m, int64_t n, gpuStream_t stream,
|
||||
ffi::Result<ffi::Buffer<ffi::S32>> info) {
|
||||
FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
|
||||
signed char job = compute_uv ? (full_matrices ? 'A' : 'S') : 'N';
|
||||
|
||||
auto dataType = a.element_type();
|
||||
gpuDataType aType, sType;
|
||||
switch (dataType) {
|
||||
case ffi::F32:
|
||||
aType = GPU_R_32F;
|
||||
sType = GPU_R_32F;
|
||||
break;
|
||||
case ffi::F64:
|
||||
aType = GPU_R_64F;
|
||||
sType = GPU_R_64F;
|
||||
break;
|
||||
case ffi::C64:
|
||||
aType = GPU_C_32F;
|
||||
sType = GPU_R_32F;
|
||||
break;
|
||||
case ffi::C128:
|
||||
aType = GPU_C_64F;
|
||||
sType = GPU_R_64F;
|
||||
break;
|
||||
default:
|
||||
return ffi::Error::InvalidArgument(absl::StrFormat(
|
||||
"Unsupported dtype %s in gesvd", absl::FormatStreamed(dataType)));
|
||||
}
|
||||
FFI_ASSIGN_OR_RETURN(auto aType, SolverDataType(dataType, "syevd"));
|
||||
FFI_ASSIGN_OR_RETURN(auto sType, SolverDataType(s->element_type(), "syevd"));
|
||||
|
||||
gpusolverDnParams_t params;
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateParams(¶ms));
|
||||
@ -692,7 +791,8 @@ ffi::Error GesvdImpl(int64_t batch, int64_t rows, int64_t cols,
|
||||
AllocateWorkspace<T>(scratch, lwork, "gesvd"));
|
||||
auto a_data = static_cast<T*>(a.untyped_data());
|
||||
auto out_data = static_cast<T*>(out->untyped_data());
|
||||
auto s_data = static_cast<typename solver::RealType<T>::value*>(s->untyped_data());
|
||||
auto s_data =
|
||||
static_cast<typename solver::RealType<T>::value*>(s->untyped_data());
|
||||
auto u_data = compute_uv ? static_cast<T*>(u->untyped_data()) : nullptr;
|
||||
auto vt_data = compute_uv ? static_cast<T*>(vt->untyped_data()) : nullptr;
|
||||
auto info_data = info->typed_data();
|
||||
@ -717,7 +817,7 @@ ffi::Error GesvdImpl(int64_t batch, int64_t rows, int64_t cols,
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
#endif // JAX_GPU_64_BIT
|
||||
#endif // JAX_GPU_HAVE_64_BIT
|
||||
|
||||
ffi::Error GesvdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
|
||||
bool full_matrices, bool compute_uv, bool transposed,
|
||||
@ -763,7 +863,7 @@ ffi::Error GesvdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
|
||||
}
|
||||
FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "gesvd"));
|
||||
|
||||
#if JAX_GPU_64_BIT
|
||||
#if JAX_GPU_HAVE_64_BIT
|
||||
return Gesvd64Impl(batch, m, n, stream, scratch, full_matrices, compute_uv, a,
|
||||
out, s, u, vt, info);
|
||||
#else
|
||||
|
@ -332,7 +332,7 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
|
||||
#define gpuGetDeviceProperties cudaGetDeviceProperties
|
||||
#define gpuLaunchCooperativeKernel cudaLaunchCooperativeKernel
|
||||
|
||||
#define JAX_GPU_64_BIT 1
|
||||
#define JAX_GPU_HAVE_64_BIT 1
|
||||
|
||||
#define GPU_R_32F CUDA_R_32F
|
||||
#define GPU_R_64F CUDA_R_64F
|
||||
@ -345,6 +345,8 @@ typedef cusolverDnParams_t gpusolverDnParams_t;
|
||||
#define gpusolverDnCreateParams cusolverDnCreateParams
|
||||
#define gpusolverDnDestroyParams cusolverDnDestroyParams
|
||||
|
||||
#define gpusolverDnXsyevd_bufferSize cusolverDnXsyevd_bufferSize
|
||||
#define gpusolverDnXsyevd cusolverDnXsyevd
|
||||
#define gpusolverDnXgesvd_bufferSize cusolverDnXgesvd_bufferSize
|
||||
#define gpusolverDnXgesvd cusolverDnXgesvd
|
||||
|
||||
@ -368,7 +370,7 @@ constexpr uint32_t kNumThreadsPerWarp = 32;
|
||||
#define JAX_GPU_PREFIX "hip"
|
||||
|
||||
#define JAX_GPU_HAVE_SPARSE 1
|
||||
#define JAX_GPU_64_BIT 0
|
||||
#define JAX_GPU_HAVE_64_BIT 0
|
||||
#define JAX_GPU_HAVE_FP8 0
|
||||
|
||||
typedef hipFloatComplex gpuComplex;
|
||||
|
Loading…
x
Reference in New Issue
Block a user