mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Refactor gpusolver kernel definitions into separate build target.
There is a lot of boilerplate required for each new custom call to cuSolver / cuBLAS, and having both the FFI logic and the framework wrappers in the same library was getting unwieldy. This change adds a new "interface" target which just includes the shims to wrap cuSolver/BLAS functions, and then these are used from `solver_kernels_ffi` where the FFI logic lives. PiperOrigin-RevId: 673832309
This commit is contained in:
parent
2067da818e
commit
a3bf75e442
@ -227,6 +227,22 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
# Shim library that wraps the cuSolver / cuBLAS entry points declared in
# solver_interface.h. It holds no FFI logic; it is listed as a dependency of
# :cusolver_kernels_ffi.
cc_library(
|
||||
name = "cusolver_interface",
|
||||
srcs = ["//jaxlib/gpu:solver_interface.cc"],
|
||||
hdrs = ["//jaxlib/gpu:solver_interface.h"],
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cusolver",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cusolver_kernels_ffi",
|
||||
srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"],
|
||||
@ -237,6 +253,7 @@ cc_library(
|
||||
":cuda_make_batch_pointers",
|
||||
":cuda_solver_handle_pool",
|
||||
":cuda_vendor",
|
||||
":cusolver_interface",
|
||||
"//jaxlib:ffi_helpers",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
|
@ -53,6 +53,8 @@ exports_files(srcs = [
|
||||
"solver.cc",
|
||||
"solver_handle_pool.cc",
|
||||
"solver_handle_pool.h",
|
||||
"solver_interface.cc",
|
||||
"solver_interface.h",
|
||||
"solver_kernels.cc",
|
||||
"solver_kernels.h",
|
||||
"solver_kernels_ffi.cc",
|
||||
|
237
jaxlib/gpu/solver_interface.cc
Normal file
237
jaxlib/gpu/solver_interface.cc
Normal file
@ -0,0 +1,237 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/gpu/solver_interface.h"
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
namespace solver {
|
||||
|
||||
// LU decomposition: getrf
|
||||
|
||||
// Generates the explicit specializations GetrfBufferSize<Type> and
// Getrf<Type>, which forward to the vendor routines Name##_bufferSize and
// Name respectively.
//   * GetrfBufferSize queries the workspace size with a null matrix pointer
//     and returns it, or the error status if the query fails.
//   * Getrf runs the factorization in place on `a`.
// Both calls pass `m` as the leading dimension of the matrix. Vendor return
// codes are wrapped with JAX_AS_STATUS.
#define JAX_GPU_DEFINE_GETRF(Type, Name) \
|
||||
template <> \
|
||||
absl::StatusOr<int> GetrfBufferSize<Type>(gpusolverDnHandle_t handle, int m, \
|
||||
int n) { \
|
||||
int lwork; \
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
|
||||
Name##_bufferSize(handle, m, n, /*A=*/nullptr, m, &lwork))); \
|
||||
return lwork; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
absl::Status Getrf<Type>(gpusolverDnHandle_t handle, int m, int n, Type *a, \
|
||||
Type *workspace, int lwork, int *ipiv, int *info) { \
|
||||
return JAX_AS_STATUS( \
|
||||
Name(handle, m, n, a, m, workspace, lwork, ipiv, info)); \
|
||||
}
|
||||
|
||||
JAX_GPU_DEFINE_GETRF(float, gpusolverDnSgetrf);
|
||||
JAX_GPU_DEFINE_GETRF(double, gpusolverDnDgetrf);
|
||||
JAX_GPU_DEFINE_GETRF(gpuComplex, gpusolverDnCgetrf);
|
||||
JAX_GPU_DEFINE_GETRF(gpuDoubleComplex, gpusolverDnZgetrf);
|
||||
#undef JAX_GPU_DEFINE_GETRF
|
||||
|
||||
// Generates the explicit specialization GetrfBatched<Type>, which forwards
// all of its arguments unchanged to the BLAS routine `Name` and wraps the
// return code with JAX_AS_STATUS. `a` is an array of per-matrix pointers and
// the leading dimension `lda` is supplied by the caller. Note that the
// complex instantiations use the gpublas complex type names.
#define JAX_GPU_DEFINE_GETRF_BATCHED(Type, Name) \
|
||||
template <> \
|
||||
absl::Status GetrfBatched<Type>(gpublasHandle_t handle, int n, Type **a, \
|
||||
int lda, int *ipiv, int *info, int batch) { \
|
||||
return JAX_AS_STATUS(Name(handle, n, a, lda, ipiv, info, batch)); \
|
||||
}
|
||||
|
||||
JAX_GPU_DEFINE_GETRF_BATCHED(float, gpublasSgetrfBatched);
|
||||
JAX_GPU_DEFINE_GETRF_BATCHED(double, gpublasDgetrfBatched);
|
||||
JAX_GPU_DEFINE_GETRF_BATCHED(gpublasComplex, gpublasCgetrfBatched);
|
||||
JAX_GPU_DEFINE_GETRF_BATCHED(gpublasDoubleComplex, gpublasZgetrfBatched);
|
||||
#undef JAX_GPU_DEFINE_GETRF_BATCHED
|
||||
|
||||
// QR decomposition: geqrf
|
||||
|
||||
// Generates the explicit specializations GeqrfBufferSize<Type> and
// Geqrf<Type>, which forward to the vendor routines Name##_bufferSize and
// Name respectively.
//   * GeqrfBufferSize queries the workspace size with a null matrix pointer
//     and returns it, or the error status if the query fails.
//   * Geqrf runs the factorization in place on `a`, writing to `tau`.
// Both calls pass `m` as the leading dimension of the matrix. Vendor return
// codes are wrapped with JAX_AS_STATUS.
#define JAX_GPU_DEFINE_GEQRF(Type, Name) \
|
||||
template <> \
|
||||
absl::StatusOr<int> GeqrfBufferSize<Type>(gpusolverDnHandle_t handle, int m, \
|
||||
int n) { \
|
||||
int lwork; \
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
|
||||
Name##_bufferSize(handle, m, n, /*A=*/nullptr, m, &lwork))); \
|
||||
return lwork; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
absl::Status Geqrf<Type>(gpusolverDnHandle_t handle, int m, int n, Type *a, \
|
||||
Type *tau, Type *workspace, int lwork, int *info) { \
|
||||
return JAX_AS_STATUS( \
|
||||
Name(handle, m, n, a, m, tau, workspace, lwork, info)); \
|
||||
}
|
||||
|
||||
JAX_GPU_DEFINE_GEQRF(float, gpusolverDnSgeqrf);
|
||||
JAX_GPU_DEFINE_GEQRF(double, gpusolverDnDgeqrf);
|
||||
JAX_GPU_DEFINE_GEQRF(gpuComplex, gpusolverDnCgeqrf);
|
||||
JAX_GPU_DEFINE_GEQRF(gpuDoubleComplex, gpusolverDnZgeqrf);
|
||||
#undef JAX_GPU_DEFINE_GEQRF
|
||||
|
||||
// Generates the explicit specialization GeqrfBatched<Type>, which forwards to
// the BLAS routine `Name` and wraps the return code with JAX_AS_STATUS. `a`
// and `tau` are arrays of per-matrix pointers; `m` is passed as the leading
// dimension of each matrix. Note that the complex instantiations use the
// gpublas complex type names.
#define JAX_GPU_DEFINE_GEQRF_BATCHED(Type, Name) \
|
||||
template <> \
|
||||
absl::Status GeqrfBatched<Type>(gpublasHandle_t handle, int m, int n, \
|
||||
Type **a, Type **tau, int *info, \
|
||||
int batch) { \
|
||||
return JAX_AS_STATUS(Name(handle, m, n, a, m, tau, info, batch)); \
|
||||
}
|
||||
|
||||
JAX_GPU_DEFINE_GEQRF_BATCHED(float, gpublasSgeqrfBatched);
|
||||
JAX_GPU_DEFINE_GEQRF_BATCHED(double, gpublasDgeqrfBatched);
|
||||
JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasComplex, gpublasCgeqrfBatched);
|
||||
JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasDoubleComplex, gpublasZgeqrfBatched);
|
||||
#undef JAX_GPU_DEFINE_GEQRF_BATCHED
|
||||
|
||||
// Householder transformations: orgqr
|
||||
|
||||
// Generates the explicit specializations OrgqrBufferSize<Type> and
// Orgqr<Type>, which forward to the vendor routines Name##_bufferSize and
// Name respectively.
//   * OrgqrBufferSize queries the workspace size with null `A` and `tau`
//     pointers and returns it, or the error status if the query fails.
//   * Orgqr runs the routine in place on `a` using the reflectors in `tau`.
// Both calls pass `m` as the leading dimension of the matrix. The real types
// map to the *orgqr routines and the complex types to the *ungqr routines.
#define JAX_GPU_DEFINE_ORGQR(Type, Name) \
|
||||
template <> \
|
||||
absl::StatusOr<int> OrgqrBufferSize<Type>(gpusolverDnHandle_t handle, int m, \
|
||||
int n, int k) { \
|
||||
int lwork; \
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(Name##_bufferSize( \
|
||||
handle, m, n, k, /*A=*/nullptr, /*lda=*/m, /*tau=*/nullptr, &lwork))); \
|
||||
return lwork; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
absl::Status Orgqr<Type>(gpusolverDnHandle_t handle, int m, int n, int k, \
|
||||
Type *a, Type *tau, Type *workspace, int lwork, \
|
||||
int *info) { \
|
||||
return JAX_AS_STATUS( \
|
||||
Name(handle, m, n, k, a, m, tau, workspace, lwork, info)); \
|
||||
}
|
||||
|
||||
JAX_GPU_DEFINE_ORGQR(float, gpusolverDnSorgqr);
|
||||
JAX_GPU_DEFINE_ORGQR(double, gpusolverDnDorgqr);
|
||||
JAX_GPU_DEFINE_ORGQR(gpuComplex, gpusolverDnCungqr);
|
||||
JAX_GPU_DEFINE_ORGQR(gpuDoubleComplex, gpusolverDnZungqr);
|
||||
#undef JAX_GPU_DEFINE_ORGQR
|
||||
|
||||
// Symmetric (Hermitian) eigendecomposition:
|
||||
// * Jacobi algorithm: syevj/heevj (batches of matrices up to 32)
|
||||
// * QR algorithm: syevd/heevd
|
||||
|
||||
// Generates the explicit specializations SyevjBufferSize<Type> and
// Syevj<Type>, which forward to the vendor routines Name##_bufferSize and
// Name respectively.
//   * SyevjBufferSize queries the workspace size with null `A` and `w`
//     pointers and returns it, or the error status if the query fails.
//   * Syevj runs the solver in place on `a`; the output `w` has the real
//     type RealType<Type>::value even when `Type` is complex.
// Both calls pass `n` as the leading dimension of the matrix. The real types
// map to the *syevj routines and the complex types to the *heevj routines.
#define JAX_GPU_DEFINE_SYEVJ(Type, Name) \
|
||||
template <> \
|
||||
absl::StatusOr<int> SyevjBufferSize<Type>( \
|
||||
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
|
||||
gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params) { \
|
||||
int lwork; \
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
|
||||
Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \
|
||||
/*w=*/nullptr, &lwork, params))); \
|
||||
return lwork; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
absl::Status Syevj<Type>( \
|
||||
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
|
||||
gpusolverFillMode_t uplo, int n, Type *a, RealType<Type>::value *w, \
|
||||
Type *workspace, int lwork, int *info, gpuSyevjInfo_t params) { \
|
||||
return JAX_AS_STATUS( \
|
||||
Name(handle, jobz, uplo, n, a, n, w, workspace, lwork, info, params)); \
|
||||
}
|
||||
|
||||
JAX_GPU_DEFINE_SYEVJ(float, gpusolverDnSsyevj);
|
||||
JAX_GPU_DEFINE_SYEVJ(double, gpusolverDnDsyevj);
|
||||
JAX_GPU_DEFINE_SYEVJ(gpuComplex, gpusolverDnCheevj);
|
||||
JAX_GPU_DEFINE_SYEVJ(gpuDoubleComplex, gpusolverDnZheevj);
|
||||
#undef JAX_GPU_DEFINE_SYEVJ
|
||||
|
||||
// Generates the explicit specializations SyevjBatchedBufferSize<Type> and
// SyevjBatched<Type>, which forward to the vendor routines Name##_bufferSize
// and Name respectively. These mirror the non-batched Syevj wrappers with an
// additional trailing `batch` count.
//   * SyevjBatchedBufferSize queries the workspace size with null `A` and
//     `w` pointers and returns it, or the error status if the query fails.
//   * SyevjBatched runs the solver in place on `a`; the output `w` has the
//     real type RealType<Type>::value even when `Type` is complex.
// Both calls pass `n` as the leading dimension of each matrix.
#define JAX_GPU_DEFINE_SYEVJ_BATCHED(Type, Name) \
|
||||
template <> \
|
||||
absl::StatusOr<int> SyevjBatchedBufferSize<Type>( \
|
||||
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
|
||||
gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params, int batch) { \
|
||||
int lwork; \
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
|
||||
Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \
|
||||
/*w=*/nullptr, &lwork, params, batch))); \
|
||||
return lwork; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
absl::Status SyevjBatched<Type>( \
|
||||
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
|
||||
gpusolverFillMode_t uplo, int n, Type *a, RealType<Type>::value *w, \
|
||||
Type *workspace, int lwork, int *info, gpuSyevjInfo_t params, \
|
||||
int batch) { \
|
||||
return JAX_AS_STATUS(Name(handle, jobz, uplo, n, a, n, w, workspace, \
|
||||
lwork, info, params, batch)); \
|
||||
}
|
||||
|
||||
JAX_GPU_DEFINE_SYEVJ_BATCHED(float, gpusolverDnSsyevjBatched);
|
||||
JAX_GPU_DEFINE_SYEVJ_BATCHED(double, gpusolverDnDsyevjBatched);
|
||||
JAX_GPU_DEFINE_SYEVJ_BATCHED(gpuComplex, gpusolverDnCheevjBatched);
|
||||
JAX_GPU_DEFINE_SYEVJ_BATCHED(gpuDoubleComplex, gpusolverDnZheevjBatched);
|
||||
#undef JAX_GPU_DEFINE_SYEVJ_BATCHED
|
||||
|
||||
// Generates the explicit specializations SyevdBufferSize<Type> and
// Syevd<Type>, which forward to the vendor routines Name##_bufferSize and
// Name respectively.
//   * SyevdBufferSize queries the workspace size with null `A` and `w`
//     pointers and returns it, or the error status if the query fails.
//   * Syevd runs the solver in place on `a`; the output `w` has the real
//     type RealType<Type>::value even when `Type` is complex.
// Both calls pass `n` as the leading dimension of the matrix. The real types
// map to the *syevd routines and the complex types to the *heevd routines.
#define JAX_GPU_DEFINE_SYEVD(Type, Name) \
|
||||
template <> \
|
||||
absl::StatusOr<int> SyevdBufferSize<Type>(gpusolverDnHandle_t handle, \
|
||||
gpusolverEigMode_t jobz, \
|
||||
gpusolverFillMode_t uplo, int n) { \
|
||||
int lwork; \
|
||||
JAX_RETURN_IF_ERROR( \
|
||||
JAX_AS_STATUS(Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, \
|
||||
/*lda=*/n, /*w=*/nullptr, &lwork))); \
|
||||
return lwork; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
absl::Status Syevd<Type>(gpusolverDnHandle_t handle, \
|
||||
gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \
|
||||
int n, Type *a, RealType<Type>::value *w, \
|
||||
Type *workspace, int lwork, int *info) { \
|
||||
return JAX_AS_STATUS( \
|
||||
Name(handle, jobz, uplo, n, a, n, w, workspace, lwork, info)); \
|
||||
}
|
||||
|
||||
JAX_GPU_DEFINE_SYEVD(float, gpusolverDnSsyevd);
|
||||
JAX_GPU_DEFINE_SYEVD(double, gpusolverDnDsyevd);
|
||||
JAX_GPU_DEFINE_SYEVD(gpuComplex, gpusolverDnCheevd);
|
||||
JAX_GPU_DEFINE_SYEVD(gpuDoubleComplex, gpusolverDnZheevd);
|
||||
#undef JAX_GPU_DEFINE_SYEVD
|
||||
|
||||
// Symmetric rank-k update: syrk
|
||||
|
||||
// Generates the explicit specialization Syrk<Type>, which forwards to the
// BLAS routine `Name` and wraps the return code with JAX_AS_STATUS.
// The leading dimension of `a` is derived from `trans`: `n` when
// trans == GPUBLAS_OP_N and `k` otherwise. The leading dimension passed for
// `c` is always `n`. The complex instantiations use the gpublas complex type
// names and the *syrk (not *herk) routines.
#define JAX_GPU_DEFINE_SYRK(Type, Name) \
|
||||
template <> \
|
||||
absl::Status Syrk<Type>(gpublasHandle_t handle, gpublasFillMode_t uplo, \
|
||||
gpublasOperation_t trans, int n, int k, \
|
||||
const Type *alpha, const Type *a, const Type *beta, \
|
||||
Type *c) { \
|
||||
int lda = trans == GPUBLAS_OP_N ? n : k; \
|
||||
return JAX_AS_STATUS( \
|
||||
Name(handle, uplo, trans, n, k, alpha, a, lda, beta, c, n)); \
|
||||
}
|
||||
|
||||
JAX_GPU_DEFINE_SYRK(float, gpublasSsyrk);
|
||||
JAX_GPU_DEFINE_SYRK(double, gpublasDsyrk);
|
||||
JAX_GPU_DEFINE_SYRK(gpublasComplex, gpublasCsyrk);
|
||||
JAX_GPU_DEFINE_SYRK(gpublasDoubleComplex, gpublasZsyrk);
|
||||
#undef JAX_GPU_DEFINE_SYRK
|
||||
|
||||
} // namespace solver
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
174
jaxlib/gpu/solver_interface.h
Normal file
174
jaxlib/gpu/solver_interface.h
Normal file
@ -0,0 +1,174 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file defines a standard interface to the GPU linear algebra libraries.
|
||||
|
||||
#ifndef JAXLIB_GPU_SOLVER_INTERFACE_H_
|
||||
#define JAXLIB_GPU_SOLVER_INTERFACE_H_
|
||||
|
||||
#include <typeinfo>

#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
namespace solver {
|
||||
|
||||
// Type trait mapping a scalar type to its real counterpart, exposed as the
// member alias `value`. The primary template maps T to itself; the
// specializations map gpuComplex to float and gpuDoubleComplex to double.
template <typename T>
|
||||
struct RealType {
|
||||
using value = T;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct RealType<gpuComplex> {
|
||||
using value = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct RealType<gpuDoubleComplex> {
|
||||
using value = double;
|
||||
};
|
||||
|
||||
// Declares the function template `FunctionName` for one solver routine.
//
// The parameter list comes from a companion macro that must be defined before
// this one is invoked: JAX_GPU_SOLVER_<FunctionName>_ARGS(Type, Real), where
// `Type` is the element type and `Real` is its real counterpart.
//
// The expansion contains:
//   * a primary template, defined inline, that returns an Unimplemented
//     status naming the function and typeid(T).name(); and
//   * declarations (no definitions) of the explicit specializations for
//     float, double, gpuComplex and gpuDoubleComplex.
//
// The expansion deliberately ends without a semicolon so that the invocation
// site supplies it. Because the primary template uses `typeid`, <typeinfo>
// must be included before this macro is expanded.
#define JAX_GPU_SOLVER_EXPAND_DEFINITION(ReturnType, FunctionName) \
|
||||
template <typename T> \
|
||||
ReturnType FunctionName( \
|
||||
JAX_GPU_SOLVER_##FunctionName##_ARGS(T, typename RealType<T>::value)) { \
|
||||
return absl::UnimplementedError(absl::StrFormat( \
|
||||
#FunctionName " not implemented for type %s", typeid(T).name())); \
|
||||
} \
|
||||
template <> \
|
||||
ReturnType FunctionName<float>( \
|
||||
JAX_GPU_SOLVER_##FunctionName##_ARGS(float, float)); \
|
||||
template <> \
|
||||
ReturnType FunctionName<double>( \
|
||||
JAX_GPU_SOLVER_##FunctionName##_ARGS(double, double)); \
|
||||
template <> \
|
||||
ReturnType FunctionName<gpuComplex>( \
|
||||
JAX_GPU_SOLVER_##FunctionName##_ARGS(gpuComplex, float)); \
|
||||
template <> \
|
||||
ReturnType FunctionName<gpuDoubleComplex>( \
|
||||
JAX_GPU_SOLVER_##FunctionName##_ARGS(gpuDoubleComplex, double))
|
||||
|
||||
// LU decomposition: getrf
|
||||
|
||||
// Declares GetrfBufferSize<T>(handle, m, n) -> absl::StatusOr<int>.
// The parameter list does not depend on the element type, so the macro
// arguments are unused.
#define JAX_GPU_SOLVER_GetrfBufferSize_ARGS(Type, ...) \
|
||||
gpusolverDnHandle_t handle, int m, int n
|
||||
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GetrfBufferSize);
|
||||
#undef JAX_GPU_SOLVER_GetrfBufferSize_ARGS
|
||||
|
||||
// Declares Getrf<T>(handle, m, n, a, workspace, lwork, ipiv, info)
// -> absl::Status. `a` and `workspace` use the element type; the real-type
// macro argument is unused.
#define JAX_GPU_SOLVER_Getrf_ARGS(Type, ...) \
|
||||
gpusolverDnHandle_t handle, int m, int n, Type *a, Type *workspace, \
|
||||
int lwork, int *ipiv, int *info
|
||||
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Getrf);
|
||||
#undef JAX_GPU_SOLVER_Getrf_ARGS
|
||||
|
||||
// Declares GetrfBatched<T>(handle, n, a, lda, ipiv, info, batch)
// -> absl::Status. Takes a BLAS handle rather than a solver handle; `a` is an
// array of per-matrix pointers.
#define JAX_GPU_SOLVER_GetrfBatched_ARGS(Type, ...) \
|
||||
gpublasHandle_t handle, int n, Type **a, int lda, int *ipiv, int *info, \
|
||||
int batch
|
||||
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GetrfBatched);
|
||||
#undef JAX_GPU_SOLVER_GetrfBatched_ARGS
|
||||
|
||||
// QR decomposition: geqrf
|
||||
|
||||
// Declares GeqrfBufferSize<T>(handle, m, n) -> absl::StatusOr<int>.
// The parameter list does not depend on the element type, so the macro
// arguments are unused.
#define JAX_GPU_SOLVER_GeqrfBufferSize_ARGS(Type, ...) \
|
||||
gpusolverDnHandle_t handle, int m, int n
|
||||
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GeqrfBufferSize);
|
||||
#undef JAX_GPU_SOLVER_GeqrfBufferSize_ARGS
|
||||
|
||||
// Declares Geqrf<T>(handle, m, n, a, tau, workspace, lwork, info)
// -> absl::Status. `a`, `tau` and `workspace` use the element type; the
// real-type macro argument is unused.
#define JAX_GPU_SOLVER_Geqrf_ARGS(Type, ...) \
|
||||
gpusolverDnHandle_t handle, int m, int n, Type *a, Type *tau, \
|
||||
Type *workspace, int lwork, int *info
|
||||
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Geqrf);
|
||||
#undef JAX_GPU_SOLVER_Geqrf_ARGS
|
||||
|
||||
// Declares GeqrfBatched<T>(handle, m, n, a, tau, info, batch)
// -> absl::Status. Takes a BLAS handle rather than a solver handle; `a` and
// `tau` are arrays of per-matrix pointers.
#define JAX_GPU_SOLVER_GeqrfBatched_ARGS(Type, ...) \
|
||||
gpublasHandle_t handle, int m, int n, Type **a, Type **tau, int *info, \
|
||||
int batch
|
||||
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GeqrfBatched);
|
||||
#undef JAX_GPU_SOLVER_GeqrfBatched_ARGS
|
||||
|
||||
// Householder transformations: orgqr
|
||||
|
||||
// Declares OrgqrBufferSize<T>(handle, m, n, k) -> absl::StatusOr<int>.
// The parameter list does not depend on the element type, so the macro
// arguments are unused.
#define JAX_GPU_SOLVER_OrgqrBufferSize_ARGS(Type, ...) \
|
||||
gpusolverDnHandle_t handle, int m, int n, int k
|
||||
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, OrgqrBufferSize);
|
||||
#undef JAX_GPU_SOLVER_OrgqrBufferSize_ARGS
|
||||
|
||||
// Declares Orgqr<T>(handle, m, n, k, a, tau, workspace, lwork, info)
// -> absl::Status. `a`, `tau` and `workspace` use the element type; the
// real-type macro argument is unused.
#define JAX_GPU_SOLVER_Orgqr_ARGS(Type, ...) \
|
||||
gpusolverDnHandle_t handle, int m, int n, int k, Type *a, Type *tau, \
|
||||
Type *workspace, int lwork, int *info
|
||||
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Orgqr);
|
||||
#undef JAX_GPU_SOLVER_Orgqr_ARGS
|
||||
|
||||
// Symmetric (Hermitian) eigendecomposition:
|
||||
// * Jacobi algorithm: syevj/heevj (batches of matrices up to 32)
|
||||
// * QR algorithm: syevd/heevd
|
||||
|
||||
// Declares SyevjBufferSize<T>(handle, jobz, uplo, n, params)
// -> absl::StatusOr<int>. The parameter list does not depend on the element
// type, so the macro arguments are unused.
#define JAX_GPU_SOLVER_SyevjBufferSize_ARGS(Type, ...) \
|
||||
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
|
||||
gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params
|
||||
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, SyevjBufferSize);
|
||||
#undef JAX_GPU_SOLVER_SyevjBufferSize_ARGS
|
||||
|
||||
// Declares Syevj<T>(handle, jobz, uplo, n, a, w, workspace, lwork, info,
// params) -> absl::Status. `a` and `workspace` use the element type, while
// `w` uses the corresponding real type (`Real`).
#define JAX_GPU_SOLVER_Syevj_ARGS(Type, Real) \
|
||||
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
|
||||
gpusolverFillMode_t uplo, int n, Type *a, Real *w, Type *workspace, \
|
||||
int lwork, int *info, gpuSyevjInfo_t params
|
||||
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syevj);
|
||||
#undef JAX_GPU_SOLVER_Syevj_ARGS
|
||||
|
||||
// Declares SyevjBatchedBufferSize<T>(handle, jobz, uplo, n, params, batch)
// -> absl::StatusOr<int>. The parameter list does not depend on the element
// type, so the macro arguments are unused.
#define JAX_GPU_SOLVER_SyevjBatchedBufferSize_ARGS(Type, ...) \
|
||||
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
|
||||
gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params, int batch
|
||||
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, SyevjBatchedBufferSize);
|
||||
#undef JAX_GPU_SOLVER_SyevjBatchedBufferSize_ARGS
|
||||
|
||||
// Declares SyevjBatched<T>(handle, jobz, uplo, n, a, w, workspace, lwork,
// info, params, batch) -> absl::Status. `a` and `workspace` use the element
// type, while `w` uses the corresponding real type (`Real`).
#define JAX_GPU_SOLVER_SyevjBatched_ARGS(Type, Real) \
|
||||
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
|
||||
gpusolverFillMode_t uplo, int n, Type *a, Real *w, Type *workspace, \
|
||||
int lwork, int *info, gpuSyevjInfo_t params, int batch
|
||||
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, SyevjBatched);
|
||||
#undef JAX_GPU_SOLVER_SyevjBatched_ARGS
|
||||
|
||||
// Declares SyevdBufferSize<T>(handle, jobz, uplo, n)
// -> absl::StatusOr<int>. The parameter list does not depend on the element
// type, so the macro arguments are unused.
#define JAX_GPU_SOLVER_SyevdBufferSize_ARGS(Type, ...) \
|
||||
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
|
||||
gpusolverFillMode_t uplo, int n
|
||||
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, SyevdBufferSize);
|
||||
#undef JAX_GPU_SOLVER_SyevdBufferSize_ARGS
|
||||
|
||||
// Declares Syevd<T>(handle, jobz, uplo, n, a, w, workspace, lwork, info)
// -> absl::Status. `a` and `workspace` use the element type, while `w` uses
// the corresponding real type (`Real`).
#define JAX_GPU_SOLVER_Syevd_ARGS(Type, Real) \
|
||||
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
|
||||
gpusolverFillMode_t uplo, int n, Type *a, Real *w, Type *workspace, \
|
||||
int lwork, int *info
|
||||
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syevd);
|
||||
#undef JAX_GPU_SOLVER_Syevd_ARGS
|
||||
|
||||
// Symmetric rank-k update: syrk
|
||||
|
||||
// Declares Syrk<T>(handle, uplo, trans, n, k, alpha, a, beta, c)
// -> absl::Status. Takes a BLAS handle rather than a solver handle. No
// leading dimensions appear in the interface; `alpha`, `a` and `beta` are
// read-only and `c` is the only mutable operand.
#define JAX_GPU_SOLVER_Syrk_ARGS(Type, ...) \
|
||||
gpublasHandle_t handle, gpublasFillMode_t uplo, gpublasOperation_t trans, \
|
||||
int n, int k, const Type *alpha, const Type *a, const Type *beta, \
|
||||
Type *c
|
||||
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syrk);
|
||||
#undef JAX_GPU_SOLVER_Syrk_ARGS
|
||||
|
||||
#undef JAX_GPU_SOLVER_EXPAND_DEFINITION
|
||||
|
||||
} // namespace solver
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_GPU_SOLVER_INTERFACE_H_
|
@ -29,9 +29,13 @@ limitations under the License.
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/make_batch_pointers.h"
|
||||
#include "jaxlib/gpu/solver_handle_pool.h"
|
||||
#include "jaxlib/gpu/solver_interface.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "xla/ffi/api/ffi.h"
|
||||
|
||||
// Shorthand for FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(expr)): wraps a raw
// GPU runtime/library call with JAX_AS_STATUS and hands the result to
// FFI_RETURN_IF_ERROR_STATUS. Both helper macros are defined elsewhere;
// presumably this returns early from the enclosing FFI handler on failure.
// Variadic so that call expressions containing commas pass through intact.
#define JAX_FFI_RETURN_IF_GPU_ERROR(...) \
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__))
|
||||
|
||||
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::JAX_GPU_NAMESPACE::SyevdAlgorithm);
|
||||
|
||||
namespace jax {
|
||||
@ -39,7 +43,6 @@ namespace JAX_GPU_NAMESPACE {
|
||||
|
||||
namespace ffi = ::xla::ffi;
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
inline absl::StatusOr<T*> AllocateWorkspace(ffi::ScratchAllocator& scratch,
|
||||
int64_t size,
|
||||
@ -53,22 +56,6 @@ inline absl::StatusOr<T*> AllocateWorkspace(ffi::ScratchAllocator& scratch,
|
||||
return static_cast<T*>(maybe_workspace.value());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct RealType {
|
||||
using Type = T;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct RealType<gpuComplex> {
|
||||
using Type = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct RealType<gpuDoubleComplex> {
|
||||
using Type = double;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
#define SOLVER_DISPATCH_IMPL(impl, ...) \
|
||||
if (dataType == ffi::F32) { \
|
||||
return impl<float>(__VA_ARGS__); \
|
||||
@ -93,33 +80,6 @@ struct RealType<gpuDoubleComplex> {
|
||||
|
||||
// LU decomposition: getrf
|
||||
|
||||
namespace {
|
||||
#define GETRF_KERNEL_IMPL(type, name) \
|
||||
template <> \
|
||||
struct GetrfKernel<type> { \
|
||||
static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, int m, \
|
||||
int n) { \
|
||||
int lwork; \
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
|
||||
name##_bufferSize(handle, m, n, /*A=*/nullptr, /*lda=*/m, &lwork))); \
|
||||
return lwork; \
|
||||
} \
|
||||
static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, type* a, \
|
||||
type* workspace, int lwork, int* ipiv, \
|
||||
int* info) { \
|
||||
return JAX_AS_STATUS( \
|
||||
name(handle, m, n, a, m, workspace, lwork, ipiv, info)); \
|
||||
} \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct GetrfKernel;
|
||||
GETRF_KERNEL_IMPL(float, gpusolverDnSgetrf);
|
||||
GETRF_KERNEL_IMPL(double, gpusolverDnDgetrf);
|
||||
GETRF_KERNEL_IMPL(gpuComplex, gpusolverDnCgetrf);
|
||||
GETRF_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZgetrf);
|
||||
#undef GETRF_KERNEL_IMPL
|
||||
|
||||
template <typename T>
|
||||
ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols,
|
||||
gpuStream_t stream, ffi::ScratchAllocator& scratch,
|
||||
@ -131,7 +91,7 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols,
|
||||
|
||||
FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
|
||||
FFI_ASSIGN_OR_RETURN(int lwork,
|
||||
GetrfKernel<T>::BufferSize(handle.get(), m, n));
|
||||
solver::GetrfBufferSize<T>(handle.get(), m, n));
|
||||
FFI_ASSIGN_OR_RETURN(auto workspace,
|
||||
AllocateWorkspace<T>(scratch, lwork, "getrf"));
|
||||
|
||||
@ -140,13 +100,13 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols,
|
||||
auto ipiv_data = ipiv->typed_data();
|
||||
auto info_data = info->typed_data();
|
||||
if (a_data != out_data) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
|
||||
int ipiv_step = std::min(m, n);
|
||||
for (auto i = 0; i < batch; ++i) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(GetrfKernel<T>::Run(
|
||||
FFI_RETURN_IF_ERROR_STATUS(solver::Getrf<T>(
|
||||
handle.get(), m, n, out_data, workspace, lwork, ipiv_data, info_data));
|
||||
out_data += m * n;
|
||||
ipiv_data += ipiv_step;
|
||||
@ -155,23 +115,6 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols,
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
#define GETRF_BATCHED_KERNEL_IMPL(type, name) \
|
||||
template <> \
|
||||
struct GetrfBatchedKernel<type> { \
|
||||
static absl::Status Run(gpublasHandle_t handle, int n, type** a, int lda, \
|
||||
int* ipiv, int* info, int batch) { \
|
||||
return JAX_AS_STATUS(name(handle, n, a, lda, ipiv, info, batch)); \
|
||||
} \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct GetrfBatchedKernel;
|
||||
GETRF_BATCHED_KERNEL_IMPL(float, gpublasSgetrfBatched);
|
||||
GETRF_BATCHED_KERNEL_IMPL(double, gpublasDgetrfBatched);
|
||||
GETRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgetrfBatched);
|
||||
GETRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgetrfBatched);
|
||||
#undef GETRF_BATCHED_KERNEL_IMPL
|
||||
|
||||
template <typename T>
|
||||
ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream,
|
||||
ffi::ScratchAllocator& scratch, ffi::AnyBuffer a,
|
||||
@ -188,15 +131,15 @@ ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream,
|
||||
auto ipiv_data = ipiv->typed_data();
|
||||
auto info_data = info->typed_data();
|
||||
if (a_data != out_data) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
|
||||
MakeBatchPointersAsync(stream, out_data, batch_ptrs, batch,
|
||||
sizeof(T) * n * n);
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError());
|
||||
|
||||
FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel<T>::Run(
|
||||
FFI_RETURN_IF_ERROR_STATUS(solver::GetrfBatched(
|
||||
handle.get(), n, batch_ptrs, n, ipiv_data, info_data, batch));
|
||||
|
||||
return ffi::Error::Success();
|
||||
@ -228,7 +171,6 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
|
||||
return ffi::Error::InvalidArgument(absl::StrFormat(
|
||||
"Unsupported dtype %s in getrf", absl::FormatStreamed(dataType)));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL(GetrfFfi, GetrfDispatch,
|
||||
ffi::Ffi::Bind()
|
||||
@ -242,33 +184,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GetrfFfi, GetrfDispatch,
|
||||
|
||||
// QR decomposition: geqrf
|
||||
|
||||
namespace {
|
||||
#define GEQRF_KERNEL_IMPL(type, name) \
|
||||
template <> \
|
||||
struct GeqrfKernel<type> { \
|
||||
static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, int m, \
|
||||
int n) { \
|
||||
int lwork; \
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
|
||||
name##_bufferSize(handle, m, n, /*A=*/nullptr, /*lda=*/m, &lwork))); \
|
||||
return lwork; \
|
||||
} \
|
||||
static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, type* a, \
|
||||
type* tau, type* workspace, int lwork, \
|
||||
int* info) { \
|
||||
return JAX_AS_STATUS( \
|
||||
name(handle, m, n, a, m, tau, workspace, lwork, info)); \
|
||||
} \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct GeqrfKernel;
|
||||
GEQRF_KERNEL_IMPL(float, gpusolverDnSgeqrf);
|
||||
GEQRF_KERNEL_IMPL(double, gpusolverDnDgeqrf);
|
||||
GEQRF_KERNEL_IMPL(gpuComplex, gpusolverDnCgeqrf);
|
||||
GEQRF_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZgeqrf);
|
||||
#undef GEQRF_KERNEL_IMPL
|
||||
|
||||
template <typename T>
|
||||
ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols,
|
||||
gpuStream_t stream, ffi::ScratchAllocator& scratch,
|
||||
@ -279,7 +194,7 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols,
|
||||
|
||||
FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
|
||||
FFI_ASSIGN_OR_RETURN(int lwork,
|
||||
GeqrfKernel<T>::BufferSize(handle.get(), m, n));
|
||||
solver::GeqrfBufferSize<T>(handle.get(), m, n));
|
||||
|
||||
FFI_ASSIGN_OR_RETURN(auto workspace,
|
||||
AllocateWorkspace<T>(scratch, lwork, "geqrf"));
|
||||
@ -292,14 +207,14 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols,
|
||||
auto out_data = static_cast<T*>(out->untyped_data());
|
||||
auto tau_data = static_cast<T*>(tau->untyped_data());
|
||||
if (a_data != out_data) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
|
||||
int out_step = m * n;
|
||||
int tau_step = std::min(m, n);
|
||||
for (auto i = 0; i < batch; ++i) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(GeqrfKernel<T>::Run(
|
||||
FFI_RETURN_IF_ERROR_STATUS(solver::Geqrf<T>(
|
||||
handle.get(), m, n, out_data, tau_data, workspace, lwork, info));
|
||||
out_data += out_step;
|
||||
tau_data += tau_step;
|
||||
@ -307,23 +222,6 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols,
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
#define GEQRF_BATCHED_KERNEL_IMPL(type, name) \
|
||||
template <> \
|
||||
struct GeqrfBatchedKernel<type> { \
|
||||
static absl::Status Run(gpublasHandle_t handle, int m, int n, type** a, \
|
||||
type** tau, int* info, int batch) { \
|
||||
return JAX_AS_STATUS(name(handle, m, n, a, m, tau, info, batch)); \
|
||||
} \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct GeqrfBatchedKernel;
|
||||
GEQRF_BATCHED_KERNEL_IMPL(float, gpublasSgeqrfBatched);
|
||||
GEQRF_BATCHED_KERNEL_IMPL(double, gpublasDgeqrfBatched);
|
||||
GEQRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgeqrfBatched);
|
||||
GEQRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgeqrfBatched);
|
||||
#undef GEQRF_BATCHED_KERNEL_IMPL
|
||||
|
||||
template <typename T>
|
||||
ffi::Error GeqrfBatchedImpl(int64_t batch, int64_t rows, int64_t cols,
|
||||
gpuStream_t stream, ffi::ScratchAllocator& scratch,
|
||||
@ -341,21 +239,21 @@ ffi::Error GeqrfBatchedImpl(int64_t batch, int64_t rows, int64_t cols,
|
||||
auto out_data = out->untyped_data();
|
||||
auto tau_data = tau->untyped_data();
|
||||
if (a_data != out_data) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
|
||||
MakeBatchPointersAsync(stream, out_data, out_batch_ptrs, batch,
|
||||
sizeof(T) * m * n);
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError());
|
||||
MakeBatchPointersAsync(stream, tau_data, tau_batch_ptrs, batch,
|
||||
sizeof(T) * std::min(m, n));
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError());
|
||||
|
||||
// We ignore the output value of `info` because it is only used for shape
|
||||
// checking.
|
||||
int info;
|
||||
FFI_RETURN_IF_ERROR_STATUS(GeqrfBatchedKernel<T>::Run(
|
||||
FFI_RETURN_IF_ERROR_STATUS(solver::GeqrfBatched<T>(
|
||||
handle.get(), m, n, out_batch_ptrs, tau_batch_ptrs, &info, batch));
|
||||
|
||||
return ffi::Error::Success();
|
||||
@ -385,7 +283,6 @@ ffi::Error GeqrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
|
||||
return ffi::Error::InvalidArgument(absl::StrFormat(
|
||||
"Unsupported dtype %s in geqrf", absl::FormatStreamed(dataType)));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL(GeqrfFfi, GeqrfDispatch,
|
||||
ffi::Ffi::Bind()
|
||||
@ -398,34 +295,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GeqrfFfi, GeqrfDispatch,
|
||||
|
||||
// Householder transformations: orgqr
|
||||
|
||||
namespace {
|
||||
#define ORGQR_KERNEL_IMPL(type, name) \
|
||||
template <> \
|
||||
struct OrgqrKernel<type> { \
|
||||
static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, int m, \
|
||||
int n, int k) { \
|
||||
int lwork; \
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
|
||||
name##_bufferSize(handle, m, n, k, /*A=*/nullptr, /*lda=*/m, \
|
||||
/*tau=*/nullptr, &lwork))); \
|
||||
return lwork; \
|
||||
} \
|
||||
static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, int k, \
|
||||
type* a, type* tau, type* workspace, int lwork, \
|
||||
int* info) { \
|
||||
return JAX_AS_STATUS( \
|
||||
name(handle, m, n, k, a, m, tau, workspace, lwork, info)); \
|
||||
} \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct OrgqrKernel;
|
||||
ORGQR_KERNEL_IMPL(float, gpusolverDnSorgqr);
|
||||
ORGQR_KERNEL_IMPL(double, gpusolverDnDorgqr);
|
||||
ORGQR_KERNEL_IMPL(gpuComplex, gpusolverDnCungqr);
|
||||
ORGQR_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZungqr);
|
||||
#undef ORGQR_KERNEL_IMPL
|
||||
|
||||
template <typename T>
|
||||
ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size,
|
||||
gpuStream_t stream, ffi::ScratchAllocator& scratch,
|
||||
@ -437,7 +306,7 @@ ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size,
|
||||
|
||||
FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
|
||||
FFI_ASSIGN_OR_RETURN(int lwork,
|
||||
OrgqrKernel<T>::BufferSize(handle.get(), m, n, k));
|
||||
solver::OrgqrBufferSize<T>(handle.get(), m, n, k));
|
||||
|
||||
FFI_ASSIGN_OR_RETURN(auto workspace,
|
||||
AllocateWorkspace<T>(scratch, lwork, "orgqr"));
|
||||
@ -450,13 +319,13 @@ ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size,
|
||||
auto tau_data = static_cast<T*>(tau.untyped_data());
|
||||
auto out_data = static_cast<T*>(out->untyped_data());
|
||||
if (a_data != out_data) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
|
||||
int out_step = m * n;
|
||||
for (auto i = 0; i < batch; ++i) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(OrgqrKernel<T>::Run(
|
||||
FFI_RETURN_IF_ERROR_STATUS(solver::Orgqr<T>(
|
||||
handle.get(), m, n, k, out_data, tau_data, workspace, lwork, info));
|
||||
out_data += out_step;
|
||||
tau_data += k;
|
||||
@ -492,7 +361,6 @@ ffi::Error OrgqrDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
|
||||
return ffi::Error::InvalidArgument(absl::StrFormat(
|
||||
"Unsupported dtype %s in orgqr", absl::FormatStreamed(dataType)));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch,
|
||||
ffi::Ffi::Bind()
|
||||
@ -510,98 +378,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch,
|
||||
// dispatches dynamically to both syevd and syevj depending on the problem
|
||||
// size and the algorithm selected by the user via the `algorithm` attribute.
|
||||
|
||||
namespace {
|
||||
#define SYEVJ_KERNEL_IMPL(type, name) \
|
||||
template <> \
|
||||
struct SyevjKernel<type> { \
|
||||
static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, \
|
||||
gpusolverEigMode_t jobz, \
|
||||
gpusolverFillMode_t uplo, int n, \
|
||||
gpuSyevjInfo_t params) { \
|
||||
int lwork; \
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
|
||||
name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \
|
||||
/*w=*/nullptr, &lwork, params))); \
|
||||
return lwork; \
|
||||
} \
|
||||
static absl::Status Run(gpusolverDnHandle_t handle, \
|
||||
gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \
|
||||
int n, type* a, RealType<type>::Type* w, \
|
||||
type* workspace, int lwork, int* info, \
|
||||
gpuSyevjInfo_t params) { \
|
||||
return JAX_AS_STATUS(name(handle, jobz, uplo, n, a, n, w, workspace, \
|
||||
lwork, info, params)); \
|
||||
} \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct SyevjKernel;
|
||||
SYEVJ_KERNEL_IMPL(float, gpusolverDnSsyevj);
|
||||
SYEVJ_KERNEL_IMPL(double, gpusolverDnDsyevj);
|
||||
SYEVJ_KERNEL_IMPL(gpuComplex, gpusolverDnCheevj);
|
||||
SYEVJ_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZheevj);
|
||||
#undef SYEVJ_KERNEL_IMPL
|
||||
|
||||
#define SYEVJ_BATCHED_KERNEL_IMPL(type, name) \
|
||||
template <> \
|
||||
struct SyevjBatchedKernel<type> { \
|
||||
static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, \
|
||||
gpusolverEigMode_t jobz, \
|
||||
gpusolverFillMode_t uplo, int n, \
|
||||
gpuSyevjInfo_t params, int batch) { \
|
||||
int lwork; \
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
|
||||
name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \
|
||||
/*w=*/nullptr, &lwork, params, batch))); \
|
||||
return lwork; \
|
||||
} \
|
||||
static absl::Status Run(gpusolverDnHandle_t handle, \
|
||||
gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \
|
||||
int n, type* a, RealType<type>::Type* w, \
|
||||
type* workspace, int lwork, int* info, \
|
||||
gpuSyevjInfo_t params, int batch) { \
|
||||
return JAX_AS_STATUS(name(handle, jobz, uplo, n, a, n, w, workspace, \
|
||||
lwork, info, params, batch)); \
|
||||
} \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct SyevjBatchedKernel;
|
||||
SYEVJ_BATCHED_KERNEL_IMPL(float, gpusolverDnSsyevjBatched);
|
||||
SYEVJ_BATCHED_KERNEL_IMPL(double, gpusolverDnDsyevjBatched);
|
||||
SYEVJ_BATCHED_KERNEL_IMPL(gpuComplex, gpusolverDnCheevjBatched);
|
||||
SYEVJ_BATCHED_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZheevjBatched);
|
||||
#undef SYEVJ_BATCHED_KERNEL_IMPL
|
||||
|
||||
#define SYEVD_KERNEL_IMPL(type, name) \
|
||||
template <> \
|
||||
struct SyevdKernel<type> { \
|
||||
static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, \
|
||||
gpusolverEigMode_t jobz, \
|
||||
gpusolverFillMode_t uplo, int n) { \
|
||||
int lwork; \
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
|
||||
name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \
|
||||
/*w=*/nullptr, &lwork))); \
|
||||
return lwork; \
|
||||
} \
|
||||
static absl::Status Run(gpusolverDnHandle_t handle, \
|
||||
gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \
|
||||
int n, type* a, RealType<type>::Type* w, \
|
||||
type* workspace, int lwork, int* info) { \
|
||||
return JAX_AS_STATUS( \
|
||||
name(handle, jobz, uplo, n, a, n, w, workspace, lwork, info)); \
|
||||
} \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct SyevdKernel;
|
||||
SYEVD_KERNEL_IMPL(float, gpusolverDnSsyevd);
|
||||
SYEVD_KERNEL_IMPL(double, gpusolverDnDsyevd);
|
||||
SYEVD_KERNEL_IMPL(gpuComplex, gpusolverDnCheevd);
|
||||
SYEVD_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZheevd);
|
||||
#undef SYEVD_KERNEL_IMPL
|
||||
|
||||
template <typename T>
|
||||
ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream,
|
||||
ffi::ScratchAllocator& scratch, SyevdAlgorithm algorithm,
|
||||
@ -618,49 +394,48 @@ ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream,
|
||||
|
||||
auto a_data = static_cast<T*>(a.untyped_data());
|
||||
auto out_data = static_cast<T*>(out->untyped_data());
|
||||
auto w_data = static_cast<RealType<T>::Type*>(w->untyped_data());
|
||||
auto w_data = static_cast<solver::RealType<T>::value*>(w->untyped_data());
|
||||
auto info_data = info->typed_data();
|
||||
if (a_data != out_data) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
if (algorithm == SyevdAlgorithm::kJacobi ||
|
||||
(algorithm == SyevdAlgorithm::kDefault && size <= 32)) {
|
||||
gpuSyevjInfo_t params;
|
||||
FFI_RETURN_IF_ERROR_STATUS(
|
||||
JAX_AS_STATUS(gpusolverDnCreateSyevjInfo(&params)));
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateSyevjInfo(&params));
|
||||
std::unique_ptr<gpuSyevjInfo, void (*)(gpuSyevjInfo_t)> params_cleanup(
|
||||
params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); });
|
||||
|
||||
if (batch == 1) {
|
||||
FFI_ASSIGN_OR_RETURN(int lwork, SyevjKernel<T>::BufferSize(
|
||||
FFI_ASSIGN_OR_RETURN(int lwork, solver::SyevjBufferSize<T>(
|
||||
handle.get(), jobz, uplo, n, params));
|
||||
FFI_ASSIGN_OR_RETURN(auto workspace,
|
||||
AllocateWorkspace<T>(scratch, lwork, "syevj"));
|
||||
FFI_RETURN_IF_ERROR_STATUS(
|
||||
SyevjKernel<T>::Run(handle.get(), jobz, uplo, n, out_data, w_data,
|
||||
workspace, lwork, info_data, params));
|
||||
FFI_RETURN_IF_ERROR_STATUS(solver::Syevj<T>(handle.get(), jobz, uplo, n,
|
||||
out_data, w_data, workspace,
|
||||
lwork, info_data, params));
|
||||
} else {
|
||||
FFI_ASSIGN_OR_RETURN(
|
||||
int lwork, SyevjBatchedKernel<T>::BufferSize(handle.get(), jobz, uplo,
|
||||
int lwork, solver::SyevjBatchedBufferSize<T>(handle.get(), jobz, uplo,
|
||||
n, params, batch));
|
||||
FFI_ASSIGN_OR_RETURN(
|
||||
auto workspace,
|
||||
AllocateWorkspace<T>(scratch, lwork, "syevj_batched"));
|
||||
FFI_RETURN_IF_ERROR_STATUS(SyevjBatchedKernel<T>::Run(
|
||||
handle.get(), jobz, uplo, n, out_data, w_data, workspace, lwork,
|
||||
info_data, params, batch));
|
||||
FFI_RETURN_IF_ERROR_STATUS(
|
||||
solver::SyevjBatched<T>(handle.get(), jobz, uplo, n, out_data, w_data,
|
||||
workspace, lwork, info_data, params, batch));
|
||||
}
|
||||
} else {
|
||||
FFI_ASSIGN_OR_RETURN(
|
||||
int lwork, SyevdKernel<T>::BufferSize(handle.get(), jobz, uplo, n));
|
||||
int lwork, solver::SyevdBufferSize<T>(handle.get(), jobz, uplo, n));
|
||||
FFI_ASSIGN_OR_RETURN(auto workspace,
|
||||
AllocateWorkspace<T>(scratch, lwork, "syevd"));
|
||||
int out_step = n * n;
|
||||
for (auto i = 0; i < batch; ++i) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(
|
||||
SyevdKernel<T>::Run(handle.get(), jobz, uplo, n, out_data, w_data,
|
||||
workspace, lwork, info_data));
|
||||
FFI_RETURN_IF_ERROR_STATUS(solver::Syevd<T>(handle.get(), jobz, uplo, n,
|
||||
out_data, w_data, workspace,
|
||||
lwork, info_data));
|
||||
out_data += out_step;
|
||||
w_data += n;
|
||||
++info_data;
|
||||
@ -695,7 +470,6 @@ ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
|
||||
return ffi::Error::InvalidArgument(absl::StrFormat(
|
||||
"Unsupported dtype %s in syevd", absl::FormatStreamed(dataType)));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL(SyevdFfi, SyevdDispatch,
|
||||
ffi::Ffi::Bind()
|
||||
@ -709,110 +483,83 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(SyevdFfi, SyevdDispatch,
|
||||
.Ret<ffi::Buffer<ffi::S32>>() // info
|
||||
);
|
||||
|
||||
#define SYRK_KERNEL_IMPL(type, fn) \
|
||||
template <> \
|
||||
struct SyrkKernel<type> { \
|
||||
static absl::Status Run(gpublasHandle_t handle, std::int64_t n, \
|
||||
std::int64_t k, bool transpose, \
|
||||
const type* alpha, const type* beta, \
|
||||
const type* a_matrix, type* c_matrix) { \
|
||||
gpublasOperation_t op = transpose ? GPUBLAS_OP_N : GPUBLAS_OP_T; \
|
||||
gpublasFillMode_t uplo = GPUSOLVER_FILL_MODE_UPPER; \
|
||||
int lda = transpose ? n : k; \
|
||||
return JAX_AS_STATUS(fn(handle, uplo, op, n, k, \
|
||||
alpha, a_matrix, lda, beta, \
|
||||
c_matrix, n)); \
|
||||
} \
|
||||
}
|
||||
// Symmetric rank-k update: syrk
|
||||
|
||||
template <typename T>
|
||||
struct SyrkKernel;
|
||||
|
||||
SYRK_KERNEL_IMPL(float, gpublasSsyrk);
|
||||
SYRK_KERNEL_IMPL(double, gpublasDsyrk);
|
||||
SYRK_KERNEL_IMPL(gpublasComplex, gpublasCsyrk);
|
||||
SYRK_KERNEL_IMPL(gpublasDoubleComplex, gpublasZsyrk);
|
||||
#undef SYRK_KERNEL_IMPL
|
||||
|
||||
template <typename T>
|
||||
ffi::Error SyrkImpl(gpuStream_t stream,
|
||||
ffi::AnyBuffer a_matrix,
|
||||
ffi::AnyBuffer c_matrix,
|
||||
bool transpose,
|
||||
ffi::AnyBuffer alpha,
|
||||
ffi::AnyBuffer beta,
|
||||
ffi::Result<ffi::AnyBuffer> c_matrix_out) {
|
||||
ffi::Error SyrkImpl(gpuStream_t stream, bool transpose, ffi::AnyBuffer a,
|
||||
ffi::AnyBuffer c_in, ffi::AnyBuffer alpha,
|
||||
ffi::AnyBuffer beta, ffi::Result<ffi::AnyBuffer> c_out) {
|
||||
FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
|
||||
SplitBatch2D(a_matrix.dimensions()));
|
||||
FFI_ASSIGN_OR_RETURN((auto [batch_c, rows_c, cols_c]),
|
||||
SplitBatch2D(c_matrix.dimensions()));
|
||||
FFI_ASSIGN_OR_RETURN((auto [batch_out, rows_out, cols_out]),
|
||||
SplitBatch2D(c_matrix_out->dimensions()));
|
||||
if (batch != batch_c || batch != batch_out) {
|
||||
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
|
||||
"a_matrix, c_matrix and c_matrix_out must have the same "
|
||||
"batch size.");
|
||||
SplitBatch2D(a.dimensions()));
|
||||
if (alpha.element_count() != 1 || beta.element_count() != 1) {
|
||||
return ffi::Error::InvalidArgument(
|
||||
"The alpha and beta inputs to syrk must be scalars");
|
||||
}
|
||||
int n = transpose ? cols : rows;
|
||||
int k = transpose ? rows : cols;
|
||||
|
||||
auto size = transpose ? cols : rows;
|
||||
FFI_RETURN_IF_ERROR(
|
||||
CheckShape(c_matrix_out->dimensions().last(2), {n, n}, "out", "Syrk"));
|
||||
CheckShape(c_in.dimensions(), {batch, size, size}, "c_in", "syrk"));
|
||||
FFI_RETURN_IF_ERROR(
|
||||
CheckShape(c_matrix.dimensions().last(2), {n, n}, "C", "Syrk"));
|
||||
CheckShape(c_out->dimensions(), {batch, size, size}, "c_out", "syrk"));
|
||||
|
||||
const T* a_data = static_cast<const T*>(a_matrix.untyped_data());
|
||||
T* c_data = static_cast<T*>(c_matrix.untyped_data());
|
||||
T* c_out_data = static_cast<T*>(c_matrix_out->untyped_data());
|
||||
FFI_ASSIGN_OR_RETURN(auto n,
|
||||
MaybeCastNoOverflow<int>(transpose ? cols : rows));
|
||||
FFI_ASSIGN_OR_RETURN(auto k,
|
||||
MaybeCastNoOverflow<int>(transpose ? rows : cols));
|
||||
gpublasFillMode_t uplo = GPUSOLVER_FILL_MODE_UPPER;
|
||||
gpublasOperation_t trans = transpose ? GPUBLAS_OP_N : GPUBLAS_OP_T;
|
||||
|
||||
const T* a_data = static_cast<const T*>(a.untyped_data());
|
||||
T* c_data = static_cast<T*>(c_in.untyped_data());
|
||||
T* c_out_data = static_cast<T*>(c_out->untyped_data());
|
||||
|
||||
// with alpha or beta provided as device_pointers, cublas<T>syrk will SIGSEGV
|
||||
T host_alpha;
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
|
||||
&host_alpha, alpha.untyped_data(), sizeof(T), gpuMemcpyDeviceToHost,
|
||||
stream)));
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(&host_alpha, alpha.untyped_data(),
|
||||
sizeof(T), gpuMemcpyDeviceToHost,
|
||||
stream));
|
||||
|
||||
T host_beta;
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
|
||||
&host_beta, beta.untyped_data(), sizeof(T), gpuMemcpyDeviceToHost,
|
||||
stream)));
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(&host_beta, beta.untyped_data(),
|
||||
sizeof(T), gpuMemcpyDeviceToHost,
|
||||
stream));
|
||||
|
||||
if (c_data != c_out_data) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
|
||||
c_out_data, c_data, c_matrix.size_bytes(), gpuMemcpyDeviceToDevice,
|
||||
stream)));
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(
|
||||
gpuMemcpyAsync(c_out_data, c_data, c_in.size_bytes(),
|
||||
gpuMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream));
|
||||
for (int i = 0; i < batch; ++i) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(SyrkKernel<T>::Run(
|
||||
handle.get(), n, k, transpose, &host_alpha, &host_beta,
|
||||
a_data + i * k * n, c_out_data + i * n * n));
|
||||
FFI_RETURN_IF_ERROR_STATUS(solver::Syrk<T>(handle.get(), uplo, trans, n, k,
|
||||
&host_alpha, a_data, &host_beta,
|
||||
c_out_data));
|
||||
a_data += k * n;
|
||||
c_out_data += n * n;
|
||||
}
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
ffi::Error SyrkDispatch(
|
||||
gpuStream_t stream,
|
||||
ffi::AnyBuffer a_matrix,
|
||||
ffi::AnyBuffer c_matrix,
|
||||
bool transpose,
|
||||
ffi::AnyBuffer alpha,
|
||||
ffi::AnyBuffer beta,
|
||||
ffi::Result<ffi::AnyBuffer> c_matrix_out) {
|
||||
auto dataType = a_matrix.element_type();
|
||||
SOLVER_BLAS_DISPATCH_IMPL(SyrkImpl, stream, a_matrix, c_matrix, transpose,
|
||||
alpha, beta, c_matrix_out);
|
||||
return ffi::Error::InvalidArgument("Unsupported element type for Syrk");
|
||||
ffi::Error SyrkDispatch(gpuStream_t stream, bool transpose, ffi::AnyBuffer a,
|
||||
ffi::AnyBuffer c_in, ffi::AnyBuffer alpha,
|
||||
ffi::AnyBuffer beta,
|
||||
ffi::Result<ffi::AnyBuffer> c_out) {
|
||||
auto dataType = a.element_type();
|
||||
SOLVER_BLAS_DISPATCH_IMPL(SyrkImpl, stream, transpose, a, c_in, alpha, beta,
|
||||
c_out);
|
||||
return ffi::Error::InvalidArgument(absl::StrFormat(
|
||||
"Unsupported dtype %s in syrk", absl::FormatStreamed(dataType)));
|
||||
}
|
||||
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL(SyrkFfi, SyrkDispatch,
|
||||
ffi::Ffi::Bind()
|
||||
.Ctx<ffi::PlatformStream<gpuStream_t>>()
|
||||
.Arg<ffi::AnyBuffer>() // a_matrix
|
||||
.Arg<ffi::AnyBuffer>() // c_matrix
|
||||
.Attr<bool>("transpose") // transpose
|
||||
.Arg<ffi::AnyBuffer>() // alpha
|
||||
.Arg<ffi::AnyBuffer>() // beta
|
||||
.Ret<ffi::AnyBuffer>()); // c_matrix_out
|
||||
.Arg<ffi::AnyBuffer>() // a
|
||||
.Arg<ffi::AnyBuffer>() // c_in
|
||||
.Arg<ffi::AnyBuffer>() // alpha
|
||||
.Arg<ffi::AnyBuffer>() // beta
|
||||
.Ret<ffi::AnyBuffer>() // c_out
|
||||
);
|
||||
|
||||
#undef SOLVER_DISPATCH_IMPL
|
||||
#undef SOLVER_BLAS_DISPATCH_IMPL
|
||||
|
@ -168,6 +168,21 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hipsolver_interface",
|
||||
srcs = ["//jaxlib/gpu:solver_interface.cc"],
|
||||
hdrs = ["//jaxlib/gpu:solver_interface.h"],
|
||||
deps = [
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_vendor",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@local_config_rocm//rocm:hipblas",
|
||||
"@local_config_rocm//rocm:hipsolver",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hipsolver_kernels_ffi",
|
||||
srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"],
|
||||
@ -178,6 +193,7 @@ cc_library(
|
||||
":hip_make_batch_pointers",
|
||||
":hip_solver_handle_pool",
|
||||
":hip_vendor",
|
||||
":hipsolver_interface",
|
||||
"//jaxlib:ffi_helpers",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
|
Loading…
x
Reference in New Issue
Block a user