From 6625a2b3edfb76fe563937874208251129d23409 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 10 Oct 2024 12:58:53 -0700 Subject: [PATCH] Update Eigh kernel on GPU to use 64-bit interface when it is available. Part of https://github.com/jax-ml/jax/issues/23413 PiperOrigin-RevId: 684546802 --- jaxlib/gpu/solver_kernels_ffi.cc | 258 +++++++++++++++++++++---------- jaxlib/gpu/vendor.h | 6 +- 2 files changed, 183 insertions(+), 81 deletions(-) diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 32cd97565..7852da4bc 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -21,6 +21,14 @@ limitations under the License. #include #include +#if JAX_GPU_HAVE_64_BIT +#include +#endif + +#ifdef JAX_GPU_CUDA +#include +#endif + #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" @@ -33,14 +41,6 @@ limitations under the License. #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" -#if JAX_GPU_64_BIT -#include -#endif - -#ifdef JAX_GPU_CUDA -#include -#endif - #define JAX_FFI_RETURN_IF_GPU_ERROR(...) \ FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__)) @@ -64,6 +64,28 @@ inline absl::StatusOr AllocateWorkspace(ffi::ScratchAllocator& scratch, return static_cast(maybe_workspace.value()); } +#if JAX_GPU_HAVE_64_BIT + +// Map an FFI buffer element type to the appropriate GPU solver type. +inline absl::StatusOr SolverDataType(ffi::DataType dataType, + std::string_view func) { + switch (dataType) { + case ffi::F32: + return GPU_R_32F; + case ffi::F64: + return GPU_R_64F; + case ffi::C64: + return GPU_C_32F; + case ffi::C128: + return GPU_C_64F; + default: + return absl::InvalidArgumentError(absl::StrFormat( + "Unsupported dtype %s in %s", absl::FormatStreamed(dataType), func)); + } +} + +#endif + #define SOLVER_DISPATCH_IMPL(impl, ...) 
\ switch (dataType) { \ case ffi::F32: \ @@ -392,11 +414,74 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch, // dispatches dynamically to both syevd and syevj depending on the problem // size and the algorithm selected by the user via the `algorithm` attribute. +#if JAX_GPU_HAVE_64_BIT + +ffi::Error Syevd64Impl(int64_t batch, int64_t n, gpuStream_t stream, + ffi::ScratchAllocator& scratch, bool lower, + ffi::AnyBuffer a, ffi::Result out, + ffi::Result w, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + auto dataType = a.element_type(); + FFI_ASSIGN_OR_RETURN(auto aType, SolverDataType(dataType, "syevd")); + FFI_ASSIGN_OR_RETURN(auto wType, SolverDataType(w->element_type(), "syevd")); + + gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; + gpusolverFillMode_t uplo = + lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; + + gpusolverDnParams_t params; + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateParams(&params)); + std::unique_ptr + params_cleanup( + params, [](gpusolverDnParams_t p) { gpusolverDnDestroyParams(p); }); + + size_t workspaceInBytesOnDevice, workspaceInBytesOnHost; + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd_bufferSize( + handle.get(), params, jobz, uplo, n, aType, /*a=*/nullptr, n, wType, + /*w=*/nullptr, aType, &workspaceInBytesOnDevice, + &workspaceInBytesOnHost)); + + auto maybe_workspace = scratch.Allocate(workspaceInBytesOnDevice); + if (!maybe_workspace.has_value()) { + return ffi::Error(ffi::ErrorCode::kResourceExhausted, + "Unable to allocate device workspace for syevd"); + } + auto workspaceOnDevice = maybe_workspace.value(); + auto workspaceOnHost = + std::unique_ptr(new char[workspaceInBytesOnHost]); + + const char* a_data = static_cast(a.untyped_data()); + char* out_data = static_cast(out->untyped_data()); + char* w_data = static_cast(w->untyped_data()); + int* info_data = info->typed_data(); + if (a_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( 
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + size_t out_step = n * n * ffi::ByteWidth(dataType); + size_t w_step = n * ffi::ByteWidth(ffi::ToReal(dataType)); + + for (auto i = 0; i < batch; ++i) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd( + handle.get(), params, jobz, uplo, n, aType, out_data, n, wType, w_data, + aType, workspaceOnDevice, workspaceInBytesOnDevice, + workspaceOnHost.get(), workspaceInBytesOnHost, info_data)); + out_data += out_step; + w_data += w_step; + ++info_data; + } + + return ffi::Error::Success(); +} + +#endif + template ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream, - ffi::ScratchAllocator& scratch, SyevdAlgorithm algorithm, - bool lower, ffi::AnyBuffer a, - ffi::Result out, + ffi::ScratchAllocator& scratch, bool lower, + ffi::AnyBuffer a, ffi::Result out, ffi::Result w, ffi::Result> info) { FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(size)); @@ -408,59 +493,84 @@ ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream, auto a_data = static_cast(a.untyped_data()); auto out_data = static_cast(out->untyped_data()); - auto w_data = static_cast::value*>(w->untyped_data()); + auto w_data = + static_cast::value*>(w->untyped_data()); auto info_data = info->typed_data(); if (a_data != out_data) { JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } - if (algorithm == SyevdAlgorithm::kJacobi || - (algorithm == SyevdAlgorithm::kDefault && size <= 32)) { - gpuSyevjInfo_t params; - JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateSyevjInfo(¶ms)); - std::unique_ptr params_cleanup( - params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); }); - if (batch == 1) { - FFI_ASSIGN_OR_RETURN(int lwork, solver::SyevjBufferSize( - handle.get(), jobz, uplo, n, params)); - FFI_ASSIGN_OR_RETURN(auto workspace, - AllocateWorkspace(scratch, lwork, "syevj")); - FFI_RETURN_IF_ERROR_STATUS(solver::Syevj(handle.get(), jobz, uplo, 
n, - out_data, w_data, workspace, - lwork, info_data, params)); - } else { - FFI_ASSIGN_OR_RETURN( - int lwork, solver::SyevjBatchedBufferSize(handle.get(), jobz, uplo, - n, params, batch)); - FFI_ASSIGN_OR_RETURN( - auto workspace, - AllocateWorkspace(scratch, lwork, "syevj_batched")); - FFI_RETURN_IF_ERROR_STATUS( - solver::SyevjBatched(handle.get(), jobz, uplo, n, out_data, w_data, - workspace, lwork, info_data, params, batch)); - } + FFI_ASSIGN_OR_RETURN(int lwork, + solver::SyevdBufferSize(handle.get(), jobz, uplo, n)); + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "syevd")); + int out_step = n * n; + for (auto i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS(solver::Syevd(handle.get(), jobz, uplo, n, + out_data, w_data, workspace, + lwork, info_data)); + out_data += out_step; + w_data += n; + ++info_data; + } + + return ffi::Error::Success(); +} + +template +ffi::Error SyevdjImpl(int64_t batch, int64_t size, gpuStream_t stream, + ffi::ScratchAllocator& scratch, bool lower, + ffi::AnyBuffer a, ffi::Result out, + ffi::Result w, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(size)); + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + + gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; + gpusolverFillMode_t uplo = + lower ? 
GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; + + auto a_data = static_cast(a.untyped_data()); + auto out_data = static_cast(out->untyped_data()); + auto w_data = + static_cast::value*>(w->untyped_data()); + auto info_data = info->typed_data(); + if (a_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + gpuSyevjInfo_t params; + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateSyevjInfo(&params)); + std::unique_ptr params_cleanup( + params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); }); + + if (batch == 1) { + FFI_ASSIGN_OR_RETURN(int lwork, solver::SyevjBufferSize( + handle.get(), jobz, uplo, n, params)); + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "syevj")); + FFI_RETURN_IF_ERROR_STATUS(solver::Syevj(handle.get(), jobz, uplo, n, + out_data, w_data, workspace, + lwork, info_data, params)); } else { FFI_ASSIGN_OR_RETURN( - int lwork, solver::SyevdBufferSize(handle.get(), jobz, uplo, n)); + int lwork, solver::SyevjBatchedBufferSize(handle.get(), jobz, uplo, + n, params, batch)); FFI_ASSIGN_OR_RETURN(auto workspace, - AllocateWorkspace(scratch, lwork, "syevd")); - int out_step = n * n; - for (auto i = 0; i < batch; ++i) { - FFI_RETURN_IF_ERROR_STATUS(solver::Syevd(handle.get(), jobz, uplo, n, - out_data, w_data, workspace, - lwork, info_data)); - out_data += out_step; - w_data += n; - ++info_data; - } + AllocateWorkspace(scratch, lwork, "syevj_batched")); + FFI_RETURN_IF_ERROR_STATUS( + solver::SyevjBatched(handle.get(), jobz, uplo, n, out_data, w_data, + workspace, lwork, info_data, params, batch)); } + return ffi::Error::Success(); } ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, - SyevdAlgorithm algorithm, bool lower, - ffi::AnyBuffer a, ffi::Result out, + SyevdAlgorithm algorithm, bool lower, ffi::AnyBuffer a, + ffi::Result out, ffi::Result w, ffi::Result> info) { auto dataType = a.element_type(); @@ 
-479,8 +589,18 @@ ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, CheckShape(out->dimensions(), {batch, rows, cols}, "out", "syevd")); FFI_RETURN_IF_ERROR(CheckShape(w->dimensions(), {batch, cols}, "w", "syevd")); FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "syevd")); - SOLVER_DISPATCH_IMPL(SyevdImpl, batch, cols, stream, scratch, algorithm, - lower, a, out, w, info); + if (algorithm == SyevdAlgorithm::kJacobi || + (algorithm == SyevdAlgorithm::kDefault && cols <= 32)) { + SOLVER_DISPATCH_IMPL(SyevdjImpl, batch, cols, stream, scratch, lower, a, + out, w, info); + } else { +#if JAX_GPU_HAVE_64_BIT + return Syevd64Impl(batch, cols, stream, scratch, lower, a, out, w, info); +#else + SOLVER_DISPATCH_IMPL(SyevdImpl, batch, cols, stream, scratch, lower, a, out, + w, info); +#endif + } return ffi::Error::InvalidArgument(absl::StrFormat( "Unsupported dtype %s in syevd", absl::FormatStreamed(dataType))); } @@ -577,7 +697,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(SyrkFfi, SyrkDispatch, // Singular Value Decomposition: gesvd -#if JAX_GPU_64_BIT +#if JAX_GPU_HAVE_64_BIT ffi::Error Gesvd64Impl(int64_t batch, int64_t m, int64_t n, gpuStream_t stream, ffi::ScratchAllocator& scratch, bool full_matrices, @@ -589,30 +709,10 @@ ffi::Error Gesvd64Impl(int64_t batch, int64_t m, int64_t n, gpuStream_t stream, ffi::Result> info) { FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); signed char job = compute_uv ? (full_matrices ? 
'A' : 'S') : 'N'; auto dataType = a.element_type(); - gpuDataType aType, sType; - switch (dataType) { - case ffi::F32: - aType = GPU_R_32F; - sType = GPU_R_32F; - break; - case ffi::F64: - aType = GPU_R_64F; - sType = GPU_R_64F; - break; - case ffi::C64: - aType = GPU_C_32F; - sType = GPU_R_32F; - break; - case ffi::C128: - aType = GPU_C_64F; - sType = GPU_R_64F; - break; - default: - return ffi::Error::InvalidArgument(absl::StrFormat( - "Unsupported dtype %s in gesvd", absl::FormatStreamed(dataType))); - } + FFI_ASSIGN_OR_RETURN(auto aType, SolverDataType(dataType, "gesvd")); + FFI_ASSIGN_OR_RETURN(auto sType, SolverDataType(s->element_type(), "gesvd")); gpusolverDnParams_t params; JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateParams(&params)); @@ -692,7 +791,8 @@ ffi::Error GesvdImpl(int64_t batch, int64_t rows, int64_t cols, AllocateWorkspace(scratch, lwork, "gesvd")); auto a_data = static_cast(a.untyped_data()); auto out_data = static_cast(out->untyped_data()); - auto s_data = static_cast::value*>(s->untyped_data()); + auto s_data = + static_cast::value*>(s->untyped_data()); auto u_data = compute_uv ? static_cast(u->untyped_data()) : nullptr; auto vt_data = compute_uv ? 
static_cast(vt->untyped_data()) : nullptr; auto info_data = info->typed_data(); @@ -717,7 +817,7 @@ ffi::Error GesvdImpl(int64_t batch, int64_t rows, int64_t cols, return ffi::Error::Success(); } -#endif // JAX_GPU_64_BIT +#endif // JAX_GPU_HAVE_64_BIT ffi::Error GesvdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, bool full_matrices, bool compute_uv, bool transposed, @@ -763,7 +863,7 @@ ffi::Error GesvdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, } FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "gesvd")); -#if JAX_GPU_64_BIT +#if JAX_GPU_HAVE_64_BIT return Gesvd64Impl(batch, m, n, stream, scratch, full_matrices, compute_uv, a, out, s, u, vt, info); #else diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index fa247b08b..648580f08 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -332,7 +332,7 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpuGetDeviceProperties cudaGetDeviceProperties #define gpuLaunchCooperativeKernel cudaLaunchCooperativeKernel -#define JAX_GPU_64_BIT 1 +#define JAX_GPU_HAVE_64_BIT 1 #define GPU_R_32F CUDA_R_32F #define GPU_R_64F CUDA_R_64F @@ -345,6 +345,8 @@ typedef cusolverDnParams_t gpusolverDnParams_t; #define gpusolverDnCreateParams cusolverDnCreateParams #define gpusolverDnDestroyParams cusolverDnDestroyParams +#define gpusolverDnXsyevd_bufferSize cusolverDnXsyevd_bufferSize +#define gpusolverDnXsyevd cusolverDnXsyevd #define gpusolverDnXgesvd_bufferSize cusolverDnXgesvd_bufferSize #define gpusolverDnXgesvd cusolverDnXgesvd @@ -368,7 +370,7 @@ constexpr uint32_t kNumThreadsPerWarp = 32; #define JAX_GPU_PREFIX "hip" #define JAX_GPU_HAVE_SPARSE 1 -#define JAX_GPU_64_BIT 0 +#define JAX_GPU_HAVE_64_BIT 0 #define JAX_GPU_HAVE_FP8 0 typedef hipFloatComplex gpuComplex;