diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc
index c65ad088a..38936ee49 100644
--- a/jaxlib/gpu/solver.cc
+++ b/jaxlib/gpu/solver.cc
@@ -20,8 +20,8 @@ limitations under the License.
 #include "nanobind/nanobind.h"
 #include "nanobind/stl/pair.h"
 #include "absl/container/flat_hash_map.h"
-#include "absl/strings/str_format.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
 #include "jaxlib/gpu/gpu_kernel_helpers.h"
 #include "jaxlib/gpu/solver_handle_pool.h"
 #include "jaxlib/gpu/solver_kernels.h"
@@ -481,6 +481,11 @@ nb::dict Registrations() {
   dict[JAX_GPU_PREFIX "solver_orgqr_ffi"] = EncapsulateFfiHandler(OrgqrFfi);
   dict[JAX_GPU_PREFIX "solver_syevd_ffi"] = EncapsulateFfiHandler(SyevdFfi);
   dict[JAX_GPU_PREFIX "solver_syrk_ffi"] = EncapsulateFfiHandler(SyrkFfi);
+  dict[JAX_GPU_PREFIX "solver_gesvd_ffi"] = EncapsulateFfiHandler(GesvdFfi);
+
+#ifdef JAX_GPU_CUDA
+  dict[JAX_GPU_PREFIX "solver_gesvdj_ffi"] = EncapsulateFfiHandler(GesvdjFfi);
+#endif  // JAX_GPU_CUDA
 
   return dict;
 }
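The new dictionary keys above rely on C++ adjacent string-literal concatenation: JAX_GPU_PREFIX is a plain string literal (defined per backend in jaxlib/gpu/vendor.h, e.g. "hip" in the ROCm section shown later in this patch), so the keys resolve to names like "cusolver_gesvd_ffi" or "hipsolver_gesvd_ffi". A minimal, self-contained sketch of the same pattern; the prefix value and macro name here are assumptions for illustration only:

    #include <cstdio>

    #define MY_GPU_PREFIX "cu"  // assumption: a CUDA build; "hip" on ROCm

    int main() {
      // Adjacent literals concatenate at compile time:
      // "cu" "solver_gesvd_ffi" -> "cusolver_gesvd_ffi"
      std::printf("%s\n", MY_GPU_PREFIX "solver_gesvd_ffi");
    }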
diff --git a/jaxlib/gpu/solver_interface.cc b/jaxlib/gpu/solver_interface.cc
index 3c8282ec6..4d1af3c50 100644
--- a/jaxlib/gpu/solver_interface.cc
+++ b/jaxlib/gpu/solver_interface.cc
@@ -232,6 +232,91 @@ JAX_GPU_DEFINE_SYRK(gpublasComplex, gpublasCsyrk);
 JAX_GPU_DEFINE_SYRK(gpublasDoubleComplex, gpublasZsyrk);
 #undef JAX_GPU_DEFINE_SYRK
 
+// Singular Value Decomposition: gesvd
+
+#define JAX_GPU_DEFINE_GESVD(Type, Name)                                      \
+  template <>                                                                 \
+  absl::StatusOr<int> GesvdBufferSize<Type>(gpusolverDnHandle_t handle,       \
+                                            signed char job, int m, int n) {  \
+    int lwork;                                                                \
+    JAX_RETURN_IF_ERROR(                                                      \
+        JAX_AS_STATUS(Name##_bufferSize(handle, job, job, m, n, &lwork)));    \
+    return lwork;                                                             \
+  }                                                                           \
+                                                                              \
+  template <>                                                                 \
+  absl::Status Gesvd<Type>(gpusolverDnHandle_t handle, signed char job,       \
+                           int m, int n, Type *a,                             \
+                           RealType<Type>::value *s, Type *u, Type *vt,       \
+                           Type *workspace, int lwork, int *info) {           \
+    return JAX_AS_STATUS(Name(handle, job, job, m, n, a, m, s, u, m, vt, n,   \
+                              workspace, lwork, /*rwork=*/nullptr, info));    \
+  }
+
+JAX_GPU_DEFINE_GESVD(float, gpusolverDnSgesvd);
+JAX_GPU_DEFINE_GESVD(double, gpusolverDnDgesvd);
+JAX_GPU_DEFINE_GESVD(gpuComplex, gpusolverDnCgesvd);
+JAX_GPU_DEFINE_GESVD(gpuDoubleComplex, gpusolverDnZgesvd);
+#undef JAX_GPU_DEFINE_GESVD
+
+#ifdef JAX_GPU_CUDA
+
+#define JAX_GPU_DEFINE_GESVDJ(Type, Name)                                     \
+  template <>                                                                 \
+  absl::StatusOr<int> GesvdjBufferSize<Type>(                                 \
+      gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m,    \
+      int n, gpuGesvdjInfo_t params) {                                        \
+    int lwork;                                                                \
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(Name##_bufferSize(                      \
+        handle, job, econ, m, n, /*a=*/nullptr, /*lda=*/m, /*s=*/nullptr,     \
+        /*u=*/nullptr, /*ldu=*/m, /*v=*/nullptr, /*ldv=*/n, &lwork,           \
+        params)));                                                            \
+    return lwork;                                                             \
+  }                                                                           \
+                                                                              \
+  template <>                                                                 \
+  absl::Status Gesvdj<Type>(                                                  \
+      gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m,    \
+      int n, Type *a, RealType<Type>::value *s, Type *u, Type *v,             \
+      Type *workspace, int lwork, int *info, gpuGesvdjInfo_t params) {        \
+    return JAX_AS_STATUS(Name(handle, job, econ, m, n, a, m, s, u, m, v, n,   \
+                              workspace, lwork, info, params));               \
+  }
+
+JAX_GPU_DEFINE_GESVDJ(float, gpusolverDnSgesvdj);
+JAX_GPU_DEFINE_GESVDJ(double, gpusolverDnDgesvdj);
+JAX_GPU_DEFINE_GESVDJ(gpuComplex, gpusolverDnCgesvdj);
+JAX_GPU_DEFINE_GESVDJ(gpuDoubleComplex, gpusolverDnZgesvdj);
+#undef JAX_GPU_DEFINE_GESVDJ
+
+#define JAX_GPU_DEFINE_GESVDJ_BATCHED(Type, Name)                             \
+  template <>                                                                 \
+  absl::StatusOr<int> GesvdjBatchedBufferSize<Type>(                          \
+      gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n,       \
+      gpuGesvdjInfo_t params, int batch) {                                    \
+    int lwork;                                                                \
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(                                        \
+        Name##_bufferSize(handle, job, m, n, /*a=*/nullptr, /*lda=*/m,        \
+                          /*s=*/nullptr, /*u=*/nullptr, /*ldu=*/m,            \
+                          /*v=*/nullptr, /*ldv=*/n, &lwork, params,           \
+                          batch)));                                           \
+    return lwork;                                                             \
+  }                                                                           \
+                                                                              \
+  template <>                                                                 \
+  absl::Status GesvdjBatched<Type>(                                           \
+      gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n,       \
+      Type *a, RealType<Type>::value *s, Type *u, Type *v, Type *workspace,   \
+      int lwork, int *info, gpuGesvdjInfo_t params, int batch) {              \
+    return JAX_AS_STATUS(Name(handle, job, m, n, a, m, s, u, m, v, n,         \
+                              workspace, lwork, info, params, batch));        \
+  }
+
+JAX_GPU_DEFINE_GESVDJ_BATCHED(float, gpusolverDnSgesvdjBatched);
+JAX_GPU_DEFINE_GESVDJ_BATCHED(double, gpusolverDnDgesvdjBatched);
+JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuComplex, gpusolverDnCgesvdjBatched);
+JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuDoubleComplex, gpusolverDnZgesvdjBatched);
+#undef JAX_GPU_DEFINE_GESVDJ_BATCHED
+
+#endif  // JAX_GPU_CUDA
+
 }  // namespace solver
 }  // namespace JAX_GPU_NAMESPACE
 }  // namespace jax
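The JAX_GPU_DEFINE_* macros above stamp out full template specializations, one per supported element type, each forwarding to the matching type-suffixed cuSOLVER/hipSOLVER entry point. A minimal, self-contained sketch of that pattern with hypothetical names (Scale and DEFINE_SCALE are invented for illustration):

    #include <iostream>

    template <typename T>
    void Scale(T *data, int n);  // generic declaration; only specializations exist

    #define DEFINE_SCALE(Type)                          \
      template <>                                       \
      void Scale<Type>(Type *data, int n) {             \
        for (int i = 0; i < n; ++i) data[i] *= Type(2); \
      }

    DEFINE_SCALE(float)
    DEFINE_SCALE(double)
    #undef DEFINE_SCALE

    int main() {
      float x[3] = {1.0f, 2.0f, 3.0f};
      Scale<float>(x, 3);
      std::cout << x[0] << " " << x[1] << " " << x[2] << "\n";  // 2 4 6
    }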
diff --git a/jaxlib/gpu/solver_interface.h b/jaxlib/gpu/solver_interface.h
index 5072be984..336480e2e 100644
--- a/jaxlib/gpu/solver_interface.h
+++ b/jaxlib/gpu/solver_interface.h
@@ -165,6 +165,49 @@ JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syevd);
 JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syrk);
 #undef JAX_GPU_SOLVER_Syrk_ARGS
 
+// Singular Value Decomposition: gesvd
+
+#define JAX_GPU_SOLVER_GesvdBufferSize_ARGS(Type, ...) \
+  gpusolverDnHandle_t handle, signed char job, int m, int n
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GesvdBufferSize);
+#undef JAX_GPU_SOLVER_GesvdBufferSize_ARGS
+
+#define JAX_GPU_SOLVER_Gesvd_ARGS(Type, Real)                          \
+  gpusolverDnHandle_t handle, signed char job, int m, int n, Type *a,  \
+      Real *s, Type *u, Type *vt, Type *workspace, int lwork, int *info
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Gesvd);
+#undef JAX_GPU_SOLVER_Gesvd_ARGS
+
+#ifdef JAX_GPU_CUDA
+
+#define JAX_GPU_SOLVER_GesvdjBufferSize_ARGS(Type, ...)                 \
+  gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m,  \
+      int n, gpuGesvdjInfo_t params
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GesvdjBufferSize);
+#undef JAX_GPU_SOLVER_GesvdjBufferSize_ARGS
+
+#define JAX_GPU_SOLVER_Gesvdj_ARGS(Type, Real)                          \
+  gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m,  \
+      int n, Type *a, Real *s, Type *u, Type *v, Type *workspace,       \
+      int lwork, int *info, gpuGesvdjInfo_t params
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Gesvdj);
+#undef JAX_GPU_SOLVER_Gesvdj_ARGS
+
+#define JAX_GPU_SOLVER_GesvdjBatchedBufferSize_ARGS(Type, ...)       \
+  gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n,  \
+      gpuGesvdjInfo_t params, int batch
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GesvdjBatchedBufferSize);
+#undef JAX_GPU_SOLVER_GesvdjBatchedBufferSize_ARGS
+
+#define JAX_GPU_SOLVER_GesvdjBatched_ARGS(Type, Real)                  \
+  gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n,    \
+      Type *a, Real *s, Type *u, Type *v, Type *workspace, int lwork,  \
+      int *info, gpuGesvdjInfo_t params, int batch
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GesvdjBatched);
+#undef JAX_GPU_SOLVER_GesvdjBatched_ARGS
+
+#endif  // JAX_GPU_CUDA
+
 #undef JAX_GPU_SOLVER_EXPAND_DEFINITION
 
 }  // namespace solver
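For reference, the JAX_GPU_SOLVER_EXPAND_DEFINITION idiom above pairs a per-function _ARGS macro with a generic expander so each routine is declared once per signature. A hypothetical, compilable sketch of the idiom (all names invented; the real expander in this header may differ in detail):

    template <typename T>
    struct RealTraits { using value = T; };  // complex types would map to float/double

    #define EXPAND_DEFINITION(ReturnType, FunctionName) \
      template <typename T>                             \
      ReturnType FunctionName(                          \
          FunctionName##_ARGS(T, typename RealTraits<T>::value))

    #define Scale_ARGS(Type, Real) Type *data, Real factor, int n
    EXPAND_DEFINITION(void, Scale);  // declares: void Scale(T*, Real, int)
    #undef Scale_ARGS

    // A definition matching the generated declaration, plus a caller:
    template <typename T>
    void Scale(T *data, typename RealTraits<T>::value factor, int n) {
      for (int i = 0; i < n; ++i) data[i] *= factor;
    }

    int main() {
      double x[2] = {1.5, 2.5};
      Scale(x, 2.0, 2);  // x becomes {3.0, 5.0}
    }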
diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc
index e3f63234f..9191a0ff8 100644
--- a/jaxlib/gpu/solver_kernels_ffi.cc
+++ b/jaxlib/gpu/solver_kernels_ffi.cc
@@ -33,6 +33,14 @@ limitations under the License.
 #include "jaxlib/gpu/vendor.h"
 #include "xla/ffi/api/ffi.h"
 
+#if JAX_GPU_64_BIT
+#include <cstddef>
+#endif
+
+#ifdef JAX_GPU_CUDA
+#include <limits>
+#endif
+
 #define JAX_FFI_RETURN_IF_GPU_ERROR(...) \
   FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__))
 
@@ -56,26 +64,32 @@ inline absl::StatusOr<T*> AllocateWorkspace(ffi::ScratchAllocator& scratch,
   return static_cast<T*>(maybe_workspace.value());
 }
 
-#define SOLVER_DISPATCH_IMPL(impl, ...)         \
-  if (dataType == ffi::F32) {                   \
-    return impl<float>(__VA_ARGS__);            \
-  } else if (dataType == ffi::F64) {            \
-    return impl<double>(__VA_ARGS__);           \
-  } else if (dataType == ffi::C64) {            \
-    return impl<gpuComplex>(__VA_ARGS__);       \
-  } else if (dataType == ffi::C128) {           \
-    return impl<gpuDoubleComplex>(__VA_ARGS__); \
+#define SOLVER_DISPATCH_IMPL(impl, ...)           \
+  switch (dataType) {                             \
+    case ffi::F32:                                \
+      return impl<float>(__VA_ARGS__);            \
+    case ffi::F64:                                \
+      return impl<double>(__VA_ARGS__);           \
+    case ffi::C64:                                \
+      return impl<gpuComplex>(__VA_ARGS__);       \
+    case ffi::C128:                               \
+      return impl<gpuDoubleComplex>(__VA_ARGS__); \
+    default:                                      \
+      break;                                      \
   }
 
-#define SOLVER_BLAS_DISPATCH_IMPL(impl, ...)        \
-  if (dataType == ffi::F32) {                       \
-    return impl<float>(__VA_ARGS__);                \
-  } else if (dataType == ffi::F64) {                \
-    return impl<double>(__VA_ARGS__);               \
-  } else if (dataType == ffi::C64) {                \
-    return impl<gpublasComplex>(__VA_ARGS__);       \
-  } else if (dataType == ffi::C128) {               \
-    return impl<gpublasDoubleComplex>(__VA_ARGS__); \
+#define SOLVER_BLAS_DISPATCH_IMPL(impl, ...)          \
+  switch (dataType) {                                 \
+    case ffi::F32:                                    \
+      return impl<float>(__VA_ARGS__);                \
+    case ffi::F64:                                    \
+      return impl<double>(__VA_ARGS__);               \
+    case ffi::C64:                                    \
+      return impl<gpublasComplex>(__VA_ARGS__);       \
+    case ffi::C128:                                   \
+      return impl<gpublasDoubleComplex>(__VA_ARGS__); \
+    default:                                          \
+      break;                                          \
   }
 
 // LU decomposition: getrf
@@ -445,8 +459,8 @@ ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream,
 }
 
 ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
-                         SyevdAlgorithm algorithm, bool lower, ffi::AnyBuffer a,
-                         ffi::Result<ffi::AnyBuffer> out,
+                         SyevdAlgorithm algorithm, bool lower,
+                         ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
                          ffi::Result<ffi::AnyBuffer> w,
                          ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
   auto dataType = a.element_type();
@@ -561,6 +575,345 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(SyrkFfi, SyrkDispatch,
                                   .Ret<ffi::AnyBuffer>()  // c_out
 );
 
+// Singular Value Decomposition: gesvd
+
+#if JAX_GPU_64_BIT
+
+ffi::Error Gesvd64Impl(int64_t batch, int64_t m, int64_t n, gpuStream_t stream,
+                       ffi::ScratchAllocator& scratch, bool full_matrices,
+                       bool compute_uv, ffi::AnyBuffer a,
+                       ffi::Result<ffi::AnyBuffer> out,
+                       ffi::Result<ffi::AnyBuffer> s,
+                       ffi::Result<ffi::AnyBuffer> u,
+                       ffi::Result<ffi::AnyBuffer> vt,
+                       ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
+  FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
+  signed char job = compute_uv ? (full_matrices ? 'A' : 'S') : 'N';
+
+  auto dataType = a.element_type();
+  gpuDataType aType, sType;
+  switch (dataType) {
+    case ffi::F32:
+      aType = GPU_R_32F;
+      sType = GPU_R_32F;
+      break;
+    case ffi::F64:
+      aType = GPU_R_64F;
+      sType = GPU_R_64F;
+      break;
+    case ffi::C64:
+      aType = GPU_C_32F;
+      sType = GPU_R_32F;
+      break;
+    case ffi::C128:
+      aType = GPU_C_64F;
+      sType = GPU_R_64F;
+      break;
+    default:
+      return ffi::Error::InvalidArgument(absl::StrFormat(
+          "Unsupported dtype %s in gesvd", absl::FormatStreamed(dataType)));
+  }
+
+  gpusolverDnParams_t params;
+  JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateParams(&params));
+  std::unique_ptr<gpusolverDnParams, void (*)(gpusolverDnParams_t)>
+      params_cleanup(
+          params, [](gpusolverDnParams_t p) { gpusolverDnDestroyParams(p); });
+
+  size_t workspaceInBytesOnDevice, workspaceInBytesOnHost;
+  JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXgesvd_bufferSize(
+      handle.get(), params, job, job, m, n, aType, /*a=*/nullptr, m, sType,
+      /*s=*/nullptr, aType, /*u=*/nullptr, m, aType, /*vt=*/nullptr, n, aType,
+      &workspaceInBytesOnDevice, &workspaceInBytesOnHost));
+
+  auto maybe_workspace = scratch.Allocate(workspaceInBytesOnDevice);
+  if (!maybe_workspace.has_value()) {
+    return ffi::Error(ffi::ErrorCode::kResourceExhausted,
+                      "Unable to allocate device workspace for gesvd");
+  }
+  auto workspaceOnDevice = maybe_workspace.value();
+  auto workspaceOnHost =
+      std::unique_ptr<char[]>(new char[workspaceInBytesOnHost]);
+
+  const char* a_data = static_cast<const char*>(a.untyped_data());
+  char* out_data = static_cast<char*>(out->untyped_data());
+  char* s_data = static_cast<char*>(s->untyped_data());
+  char* u_data = static_cast<char*>(u->untyped_data());
+  char* vt_data = static_cast<char*>(vt->untyped_data());
+  int* info_data = info->typed_data();
+  if (a_data != out_data) {
+    JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
+        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
+  }
+
+  size_t out_step = m * n * ffi::ByteWidth(dataType);
+  size_t s_step = n * ffi::ByteWidth(ffi::ToReal(dataType));
+  size_t u_step = 0;
+  size_t vt_step = 0;
+  if (compute_uv) {
+    u_step = m * (full_matrices ? m : n) * ffi::ByteWidth(dataType);
+    vt_step = n * n * ffi::ByteWidth(dataType);
+  }
+  for (auto i = 0; i < batch; ++i) {
+    JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXgesvd(
+        handle.get(), params, job, job, m, n, aType, out_data, m, sType,
+        s_data, aType, u_data, m, aType, vt_data, n, aType, workspaceOnDevice,
+        workspaceInBytesOnDevice, workspaceOnHost.get(),
+        workspaceInBytesOnHost, info_data));
+    out_data += out_step;
+    s_data += s_step;
+    u_data += u_step;
+    vt_data += vt_step;
+    ++info_data;
+  }
+
+  return ffi::Error::Success();
+}
+
+#else
+
+template <typename T>
+ffi::Error GesvdImpl(int64_t batch, int64_t rows, int64_t cols,
+                     gpuStream_t stream, ffi::ScratchAllocator& scratch,
+                     bool full_matrices, bool compute_uv, ffi::AnyBuffer a,
+                     ffi::Result<ffi::AnyBuffer> out,
+                     ffi::Result<ffi::AnyBuffer> s,
+                     ffi::Result<ffi::AnyBuffer> u,
+                     ffi::Result<ffi::AnyBuffer> vt,
+                     ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
+  FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
+  FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
+  FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
+  signed char job = compute_uv ? (full_matrices ? 'A' : 'S') : 'N';
+
+  FFI_ASSIGN_OR_RETURN(int lwork,
+                       solver::GesvdBufferSize<T>(handle.get(), job, m, n));
+  FFI_ASSIGN_OR_RETURN(auto workspace,
+                       AllocateWorkspace<T>(scratch, lwork, "gesvd"));
+  auto a_data = static_cast<T*>(a.untyped_data());
+  auto out_data = static_cast<T*>(out->untyped_data());
+  auto s_data = static_cast<typename solver::RealType<T>::value*>(
+      s->untyped_data());
+  auto u_data = compute_uv ? static_cast<T*>(u->untyped_data()) : nullptr;
+  auto vt_data = compute_uv ? static_cast<T*>(vt->untyped_data()) : nullptr;
+  auto info_data = info->typed_data();
+  if (a_data != out_data) {
+    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
+        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
+  }
+
+  int out_step = m * n;
+  int u_step = compute_uv ? m * (full_matrices ? m : n) : 0;
+  int vt_step = compute_uv ? n * n : 0;
+  for (auto i = 0; i < batch; ++i) {
+    FFI_RETURN_IF_ERROR_STATUS(
+        solver::Gesvd<T>(handle.get(), job, m, n, out_data, s_data, u_data,
+                         vt_data, workspace, lwork, info_data));
+    out_data += out_step;
+    s_data += n;  // n is never greater than m because of the logic in dispatch.
+    u_data += u_step;
+    vt_data += vt_step;
+    ++info_data;
+  }
+  return ffi::Error::Success();
+}
+
+#endif  // JAX_GPU_64_BIT
+
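Both implementations above advance raw output pointers by fixed per-matrix strides as they loop over the batch. A standalone sketch of that stride arithmetic, with hypothetical sizes (host-side, but device pointers advance the same way):

    #include <cstdint>
    #include <cstdio>

    int main() {
      const int64_t batch = 3, m = 4, n = 2;  // three tall 4x2 matrices
      const bool full_matrices = false, compute_uv = true;
      const int64_t out_step = m * n;                                       // 8
      const int64_t u_step = compute_uv ? m * (full_matrices ? m : n) : 0;  // 8
      const int64_t vt_step = compute_uv ? n * n : 0;                       // 4
      for (int64_t i = 0; i < batch; ++i) {
        std::printf("matrix %lld: out %+lld, u %+lld, vt %+lld elements\n",
                    (long long)i, (long long)(i * out_step),
                    (long long)(i * u_step), (long long)(i * vt_step));
      }
    }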
+ffi::Error GesvdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
+                         bool full_matrices, bool compute_uv, bool transposed,
+                         ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
+                         ffi::Result<ffi::AnyBuffer> s,
+                         ffi::Result<ffi::AnyBuffer> u,
+                         ffi::Result<ffi::AnyBuffer> vt,
+                         ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
+  auto dataType = a.element_type();
+  if (out->element_type() != dataType ||
+      s->element_type() != ffi::ToReal(dataType) ||
+      u->element_type() != dataType || vt->element_type() != dataType) {
+    return ffi::Error::InvalidArgument(
+        "The inputs and outputs to gesvd must have the same element type");
+  }
+  FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
+                       SplitBatch2D(a.dimensions()));
+  int64_t m = transposed ? cols : rows;
+  int64_t n = transposed ? rows : cols;
+  if (n > m) {
+    return ffi::Error::InvalidArgument(
+        "The GPU implementation of gesvd requires that the input matrix be "
+        "m x n with m >= n");
+  }
+  FFI_RETURN_IF_ERROR(
+      CheckShape(out->dimensions(), {batch, rows, cols}, "out", "gesvd"));
+  FFI_RETURN_IF_ERROR(CheckShape(s->dimensions(), {batch, n}, "s", "gesvd"));
+  if (compute_uv) {
+    if (full_matrices) {
+      FFI_RETURN_IF_ERROR(
+          CheckShape(u->dimensions(), {batch, m, m}, "u", "gesvd"));
+    } else {
+      if (transposed) {
+        FFI_RETURN_IF_ERROR(
+            CheckShape(u->dimensions(), {batch, n, m}, "u", "gesvd"));
+      } else {
+        FFI_RETURN_IF_ERROR(
+            CheckShape(u->dimensions(), {batch, m, n}, "u", "gesvd"));
+      }
+    }
+    FFI_RETURN_IF_ERROR(
+        CheckShape(vt->dimensions(), {batch, n, n}, "vt", "gesvd"));
+  }
+  FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "gesvd"));
+
+#if JAX_GPU_64_BIT
+  return Gesvd64Impl(batch, m, n, stream, scratch, full_matrices, compute_uv,
+                     a, out, s, u, vt, info);
+#else
+  SOLVER_DISPATCH_IMPL(GesvdImpl, batch, m, n, stream, scratch, full_matrices,
+                       compute_uv, a, out, s, u, vt, info);
+  return ffi::Error::InvalidArgument(absl::StrFormat(
+      "Unsupported dtype %s in gesvd", absl::FormatStreamed(dataType)));
+#endif
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(GesvdFfi, GesvdDispatch,
+                              ffi::Ffi::Bind()
+                                  .Ctx<ffi::PlatformStream<gpuStream_t>>()
+                                  .Ctx<ffi::ScratchAllocator>()
+                                  .Attr<bool>("full_matrices")
+                                  .Attr<bool>("compute_uv")
+                                  .Attr<bool>("transposed")
+                                  .Arg<ffi::AnyBuffer>()  // a
+                                  .Ret<ffi::AnyBuffer>()  // out
+                                  .Ret<ffi::AnyBuffer>()  // s
+                                  .Ret<ffi::AnyBuffer>()  // u
+                                  .Ret<ffi::AnyBuffer>()  // vt
+                                  .Ret<ffi::Buffer<ffi::DataType::S32>>()  // info
+);
+
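GesvdDispatch never runs gesvd on a wide matrix: when rows < cols the caller is expected to pass the transpose and set transposed=true, which is also why the expected shape of u flips to {batch, n, m} in that branch. A small standalone restatement of the shape logic, with hypothetical values:

    #include <cassert>
    #include <cstdint>

    int main() {
      const int64_t rows = 2, cols = 5;  // the buffer holds A^T for a wide A
      const bool transposed = true;
      const int64_t m = transposed ? cols : rows;  // 5
      const int64_t n = transposed ? rows : cols;  // 2
      assert(m >= n);  // satisfies the gesvd m >= n requirement
      return 0;
    }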
+#ifdef JAX_GPU_CUDA
+
+template <typename T>
+ffi::Error GesvdjImpl(int64_t batch, int64_t rows, int64_t cols,
+                      gpuStream_t stream, ffi::ScratchAllocator& scratch,
+                      bool full_matrices, bool compute_uv, ffi::AnyBuffer a,
+                      ffi::Result<ffi::AnyBuffer> out,
+                      ffi::Result<ffi::AnyBuffer> s,
+                      ffi::Result<ffi::AnyBuffer> u,
+                      ffi::Result<ffi::AnyBuffer> v,
+                      ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
+  FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
+  FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
+  FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
+
+  gpusolverEigMode_t job =
+      compute_uv ? GPUSOLVER_EIG_MODE_VECTOR : GPUSOLVER_EIG_MODE_NOVECTOR;
+  int econ = full_matrices ? 0 : 1;
+
+  gpuGesvdjInfo_t params;
+  JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateGesvdjInfo(&params));
+  std::unique_ptr<gpuGesvdjInfo, void (*)(gpuGesvdjInfo_t)> params_cleanup(
+      params, [](gpuGesvdjInfo_t p) { gpusolverDnDestroyGesvdjInfo(p); });
+
+  auto a_data = static_cast<T*>(a.untyped_data());
+  auto out_data = static_cast<T*>(out->untyped_data());
+  auto s_data = static_cast<typename solver::RealType<T>::value*>(
+      s->untyped_data());
+  auto u_data = static_cast<T*>(u->untyped_data());
+  auto v_data = static_cast<T*>(v->untyped_data());
+  auto info_data = info->typed_data();
+  if (a_data != out_data) {
+    JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
+        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
+  }
+
+  if (batch <= 1 || batch > std::numeric_limits<int>::max() || m > 32 ||
+      n > 32 || econ) {
+    FFI_ASSIGN_OR_RETURN(int lwork,
+                         solver::GesvdjBufferSize<T>(handle.get(), job, econ,
+                                                     m, n, params));
+    FFI_ASSIGN_OR_RETURN(auto workspace,
+                         AllocateWorkspace<T>(scratch, lwork, "gesvdj"));
+    int k = std::min(m, n);
+    int out_step = m * n;
+    int u_step = m * (full_matrices ? m : k);
+    int v_step = n * (full_matrices ? n : k);
+    for (auto i = 0; i < batch; ++i) {
+      FFI_RETURN_IF_ERROR_STATUS(solver::Gesvdj<T>(
+          handle.get(), job, econ, m, n, out_data, s_data, u_data, v_data,
+          workspace, lwork, info_data, params));
+      out_data += out_step;
+      s_data += k;
+      u_data += u_step;
+      v_data += v_step;
+      ++info_data;
+    }
+  } else {
+    FFI_ASSIGN_OR_RETURN(int lwork, solver::GesvdjBatchedBufferSize<T>(
+                                        handle.get(), job, m, n, params,
+                                        static_cast<int>(batch)));
+    FFI_ASSIGN_OR_RETURN(
+        auto workspace,
+        AllocateWorkspace<T>(scratch, lwork, "gesvdj_batched"));
+    FFI_RETURN_IF_ERROR_STATUS(solver::GesvdjBatched<T>(
+        handle.get(), job, m, n, out_data, s_data, u_data, v_data, workspace,
+        lwork, info_data, params, static_cast<int>(batch)));
+  }
+  return ffi::Error::Success();
+}
+
+ffi::Error GesvdjDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
+                          bool full_matrices, bool compute_uv,
+                          ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
+                          ffi::Result<ffi::AnyBuffer> s,
+                          ffi::Result<ffi::AnyBuffer> u,
+                          ffi::Result<ffi::AnyBuffer> v,
+                          ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
+  auto dataType = a.element_type();
+  if (out->element_type() != dataType ||
+      s->element_type() != ffi::ToReal(dataType) ||
+      u->element_type() != dataType || v->element_type() != dataType) {
+    return ffi::Error::InvalidArgument(
+        "The inputs and outputs to gesvdj must have the same element type");
+  }
+  FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
+                       SplitBatch2D(a.dimensions()));
+  int64_t size = std::min(rows, cols);
+  FFI_RETURN_IF_ERROR(
+      CheckShape(out->dimensions(), {batch, rows, cols}, "out", "gesvdj"));
+  FFI_RETURN_IF_ERROR(
+      CheckShape(s->dimensions(), {batch, size}, "s", "gesvdj"));
+  // U and V must always be allocated even if compute_uv is false.
+  if (full_matrices) {
+    FFI_RETURN_IF_ERROR(
+        CheckShape(u->dimensions(), {batch, rows, rows}, "u", "gesvdj"));
+    FFI_RETURN_IF_ERROR(
+        CheckShape(v->dimensions(), {batch, cols, cols}, "v", "gesvdj"));
+  } else {
+    FFI_RETURN_IF_ERROR(
+        CheckShape(u->dimensions(), {batch, rows, size}, "u", "gesvdj"));
+    FFI_RETURN_IF_ERROR(
+        CheckShape(v->dimensions(), {batch, cols, size}, "v", "gesvdj"));
+  }
+  FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "gesvdj"));
+
+  SOLVER_DISPATCH_IMPL(GesvdjImpl, batch, rows, cols, stream, scratch,
+                       full_matrices, compute_uv, a, out, s, u, v, info);
+  return ffi::Error::InvalidArgument(absl::StrFormat(
+      "Unsupported dtype %s in gesvdj", absl::FormatStreamed(dataType)));
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(GesvdjFfi, GesvdjDispatch,
+                              ffi::Ffi::Bind()
+                                  .Ctx<ffi::PlatformStream<gpuStream_t>>()
+                                  .Ctx<ffi::ScratchAllocator>()
+                                  .Attr<bool>("full_matrices")
+                                  .Attr<bool>("compute_uv")
+                                  .Arg<ffi::AnyBuffer>()  // a
+                                  .Ret<ffi::AnyBuffer>()  // out
+                                  .Ret<ffi::AnyBuffer>()  // s
+                                  .Ret<ffi::AnyBuffer>()  // u
+                                  .Ret<ffi::AnyBuffer>()  // v
+                                  .Ret<ffi::Buffer<ffi::DataType::S32>>()  // info
+);
+
+#endif  // JAX_GPU_CUDA
+
 #undef SOLVER_DISPATCH_IMPL
 #undef SOLVER_BLAS_DISPATCH_IMPL
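The branch in GesvdjImpl falls back to a per-matrix loop of gesvdj calls unless the batched kernel applies: cuSOLVER's gesvdjBatched only handles matrices up to 32x32, full (non-economy) outputs, and a batch count that fits in an int. The same predicate, restated as a standalone helper with a hypothetical name:

    #include <cstdint>
    #include <limits>

    // True when cusolverDn<T>gesvdjBatched can be used instead of looping
    // over single-matrix gesvdj calls (the negation of the if above).
    bool UseBatchedGesvdj(int64_t batch, int64_t m, int64_t n, bool econ) {
      return batch > 1 && batch <= std::numeric_limits<int>::max() &&
             m <= 32 && n <= 32 && !econ;
    }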
diff --git a/jaxlib/gpu/solver_kernels_ffi.h b/jaxlib/gpu/solver_kernels_ffi.h
index 3bebe40be..022564eb1 100644
--- a/jaxlib/gpu/solver_kernels_ffi.h
+++ b/jaxlib/gpu/solver_kernels_ffi.h
@@ -35,6 +35,11 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GeqrfFfi);
 XLA_FFI_DECLARE_HANDLER_SYMBOL(OrgqrFfi);
 XLA_FFI_DECLARE_HANDLER_SYMBOL(SyevdFfi);
 XLA_FFI_DECLARE_HANDLER_SYMBOL(SyrkFfi);
+XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdFfi);
+
+#ifdef JAX_GPU_CUDA
+XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdjFfi);
+#endif  // JAX_GPU_CUDA
 
 }  // namespace JAX_GPU_NAMESPACE
 }  // namespace jax
diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h
index bc61d5818..fa247b08b 100644
--- a/jaxlib/gpu/vendor.h
+++ b/jaxlib/gpu/vendor.h
@@ -78,6 +78,8 @@ typedef cusolverStatus_t gpusolverStatus_t;
 typedef cusolverEigMode_t gpusolverEigMode_t;
 typedef syevjInfo gpuSyevjInfo;
 typedef syevjInfo_t gpuSyevjInfo_t;
+typedef gesvdjInfo gpuGesvdjInfo;
+typedef gesvdjInfo_t gpuGesvdjInfo_t;
 typedef cusparseIndexType_t gpusparseIndexType_t;
 typedef cusparseHandle_t gpusparseHandle_t;
 typedef cusparseOperation_t gpusparseOperation_t;
@@ -120,6 +122,8 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
 #define gpusolverDnSetStream cusolverDnSetStream
 #define gpusolverDnCreateSyevjInfo cusolverDnCreateSyevjInfo
 #define gpusolverDnDestroySyevjInfo cusolverDnDestroySyevjInfo
+#define gpusolverDnCreateGesvdjInfo cusolverDnCreateGesvdjInfo
+#define gpusolverDnDestroyGesvdjInfo cusolverDnDestroyGesvdjInfo
 #define gpusolverDnSgeqrf cusolverDnSgeqrf
 #define gpusolverDnDgeqrf cusolverDnDgeqrf
 #define gpusolverDnCgeqrf cusolverDnCgeqrf
@@ -184,6 +188,22 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
   cusolverDnCgesvd_bufferSize(h, m, n, lwork)
 #define gpusolverDnZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
   cusolverDnZgesvd_bufferSize(h, m, n, lwork)
+#define gpusolverDnSgesvdj cusolverDnSgesvdj
+#define gpusolverDnDgesvdj cusolverDnDgesvdj
+#define gpusolverDnCgesvdj cusolverDnCgesvdj
+#define gpusolverDnZgesvdj cusolverDnZgesvdj
+#define gpusolverDnSgesvdj_bufferSize cusolverDnSgesvdj_bufferSize
+#define gpusolverDnDgesvdj_bufferSize cusolverDnDgesvdj_bufferSize
+#define gpusolverDnCgesvdj_bufferSize cusolverDnCgesvdj_bufferSize
+#define gpusolverDnZgesvdj_bufferSize cusolverDnZgesvdj_bufferSize
+#define gpusolverDnSgesvdjBatched cusolverDnSgesvdjBatched
+#define gpusolverDnDgesvdjBatched cusolverDnDgesvdjBatched
+#define gpusolverDnCgesvdjBatched cusolverDnCgesvdjBatched
+#define gpusolverDnZgesvdjBatched cusolverDnZgesvdjBatched
+#define gpusolverDnSgesvdjBatched_bufferSize cusolverDnSgesvdjBatched_bufferSize
+#define gpusolverDnDgesvdjBatched_bufferSize cusolverDnDgesvdjBatched_bufferSize
+#define gpusolverDnCgesvdjBatched_bufferSize cusolverDnCgesvdjBatched_bufferSize
+#define gpusolverDnZgesvdjBatched_bufferSize cusolverDnZgesvdjBatched_bufferSize
 #define gpusolverDnSsytrd_bufferSize cusolverDnSsytrd_bufferSize
 #define gpusolverDnDsytrd_bufferSize cusolverDnDsytrd_bufferSize
 #define gpusolverDnChetrd_bufferSize cusolverDnChetrd_bufferSize
@@ -196,6 +216,7 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
 #define GPUSOLVER_FILL_MODE_LOWER CUBLAS_FILL_MODE_LOWER
 #define GPUSOLVER_FILL_MODE_UPPER CUBLAS_FILL_MODE_UPPER
 #define GPUSOLVER_EIG_MODE_VECTOR CUSOLVER_EIG_MODE_VECTOR
+#define GPUSOLVER_EIG_MODE_NOVECTOR CUSOLVER_EIG_MODE_NOVECTOR
 #define GPUSOLVER_STATUS_SUCCESS CUSOLVER_STATUS_SUCCESS
 
 #define GPUBLAS_OP_N CUBLAS_OP_N
@@ -311,6 +332,22 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
 #define gpuGetDeviceProperties cudaGetDeviceProperties
 #define gpuLaunchCooperativeKernel cudaLaunchCooperativeKernel
 
+#define JAX_GPU_64_BIT 1
+
+#define GPU_R_32F CUDA_R_32F
+#define GPU_R_64F CUDA_R_64F
+#define GPU_C_32F CUDA_C_32F
+#define GPU_C_64F CUDA_C_64F
+
+typedef cudaDataType gpuDataType;
+typedef cusolverDnParams gpusolverDnParams;
+typedef cusolverDnParams_t gpusolverDnParams_t;
+#define gpusolverDnCreateParams cusolverDnCreateParams
+#define gpusolverDnDestroyParams cusolverDnDestroyParams
+
+#define gpusolverDnXgesvd_bufferSize cusolverDnXgesvd_bufferSize
+#define gpusolverDnXgesvd cusolverDnXgesvd
+
 namespace jax::JAX_GPU_NAMESPACE {
 namespace {
 constexpr uint32_t kNumThreadsPerWarp = 32;
@@ -331,6 +368,7 @@ constexpr uint32_t kNumThreadsPerWarp = 32;
 #define JAX_GPU_PREFIX "hip"
 
 #define JAX_GPU_HAVE_SPARSE 1
+#define JAX_GPU_64_BIT 0
 #define JAX_GPU_HAVE_FP8 0
 
 typedef hipFloatComplex gpuComplex;
@@ -472,6 +510,7 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
 #define GPUSOLVER_FILL_MODE_LOWER HIPSOLVER_FILL_MODE_LOWER
 #define GPUSOLVER_FILL_MODE_UPPER HIPSOLVER_FILL_MODE_UPPER
 #define GPUSOLVER_EIG_MODE_VECTOR HIPSOLVER_EIG_MODE_VECTOR
+#define GPUSOLVER_EIG_MODE_NOVECTOR HIPSOLVER_EIG_MODE_NOVECTOR
 #define GPUSOLVER_STATUS_SUCCESS HIPSOLVER_STATUS_SUCCESS
 
 #define GPUBLAS_OP_N HIPBLAS_OP_N
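vendor.h gives every backend-specific symbol a gpu-prefixed alias so the kernels compile unchanged for CUDA and ROCm; the new JAX_GPU_64_BIT flag (1 in the CUDA section, 0 in the ROCm section) is what routes gesvd to the 64-bit cusolverDnXgesvd path in solver_kernels_ffi.cc. A minimal sketch of the same aliasing idea, with invented names standing in for the real macros:

    #include <cstdio>

    #define MYLIB_IS_CUDA 1  // assumption: set by the build, as JAX_GPU_CUDA is

    #if MYLIB_IS_CUDA
    #define gpuBackendName "cuda"
    #define MYLIB_64_BIT 1
    #else
    #define gpuBackendName "rocm"
    #define MYLIB_64_BIT 0
    #endif

    int main() {
    #if MYLIB_64_BIT
      std::printf("%s: using the 64-bit (X) solver API\n", gpuBackendName);
    #else
      std::printf("%s: using the legacy solver API\n", gpuBackendName);
    #endif
    }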