rocm_jax/jaxlib/gpu/solver_interface.cc
Dan Foreman-Mackey a3bf75e442 Refactor gpusolver kernel definitions into separate build target.
There is a lot of boilerplate required for each new custom call to cuSolver / cuBLAS, and having both the FFI logic and the framework wrappers in the same library was getting unwieldy. This change adds a new "interface" target which just includes the shims to wrap cuSolver/BLAS functions, and then these are used from `solver_kernels_ffi` where the FFI logic lives.

PiperOrigin-RevId: 673832309
2024-09-12 07:11:36 -07:00

238 lines
13 KiB
C++

/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/gpu/solver_interface.h"

#include <algorithm>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/vendor.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace solver {
// LU decomposition: getrf
// Defines GetrfBufferSize<Type> and Getrf<Type> (LU factorization with partial
// pivoting) by forwarding to the vendor routine `Name` and its
// `Name##_bufferSize` workspace query. `a` is assumed to be a densely packed
// m-by-n column-major matrix, i.e. its leading dimension is m.
// The vendor API rejects lda < max(1, m), so lda is clamped to at least 1: this
// keeps empty (m == 0) matrices valid and leaves every non-empty call unchanged.
#define JAX_GPU_DEFINE_GETRF(Type, Name)                                       \
  template <>                                                                  \
  absl::StatusOr<int> GetrfBufferSize<Type>(gpusolverDnHandle_t handle, int m, \
                                            int n) {                           \
    const int lda = std::max(1, m);                                            \
    int lwork = 0;                                                             \
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(                                         \
        Name##_bufferSize(handle, m, n, /*A=*/nullptr, lda, &lwork)));         \
    return lwork;                                                              \
  }                                                                            \
                                                                               \
  template <>                                                                  \
  absl::Status Getrf<Type>(gpusolverDnHandle_t handle, int m, int n, Type *a,  \
                           Type *workspace, int lwork, int *ipiv, int *info) { \
    const int lda = std::max(1, m);                                            \
    return JAX_AS_STATUS(                                                      \
        Name(handle, m, n, a, lda, workspace, lwork, ipiv, info));             \
  }
JAX_GPU_DEFINE_GETRF(float, gpusolverDnSgetrf);
JAX_GPU_DEFINE_GETRF(double, gpusolverDnDgetrf);
JAX_GPU_DEFINE_GETRF(gpuComplex, gpusolverDnCgetrf);
JAX_GPU_DEFINE_GETRF(gpuDoubleComplex, gpusolverDnZgetrf);
#undef JAX_GPU_DEFINE_GETRF
// Instantiates GetrfBatched<Scalar>, a batched LU factorization that is a thin
// pass-through to the BLAS routine `BlasFn`. The leading dimension `lda` is
// supplied by the caller and forwarded verbatim. The pivot (`ipiv`), status
// (`info`) and batch-size arguments are also passed through untouched.
#define JAX_GPU_DEFINE_GETRF_BATCHED(Scalar, BlasFn)                           \
  template <>                                                                  \
  absl::Status GetrfBatched<Scalar>(                                           \
      gpublasHandle_t handle, int n, Scalar **a, int lda, int *ipiv,           \
      int *info, int batch) {                                                  \
    return JAX_AS_STATUS(BlasFn(handle, n, a, lda, ipiv, info, batch));        \
  }
JAX_GPU_DEFINE_GETRF_BATCHED(float, gpublasSgetrfBatched);
JAX_GPU_DEFINE_GETRF_BATCHED(double, gpublasDgetrfBatched);
JAX_GPU_DEFINE_GETRF_BATCHED(gpublasComplex, gpublasCgetrfBatched);
JAX_GPU_DEFINE_GETRF_BATCHED(gpublasDoubleComplex, gpublasZgetrfBatched);
#undef JAX_GPU_DEFINE_GETRF_BATCHED
// QR decomposition: geqrf
// Defines GeqrfBufferSize<Type> and Geqrf<Type> (Householder QR factorization)
// by forwarding to the vendor routine `Name` and its `Name##_bufferSize`
// workspace query. `a` is assumed to be a densely packed m-by-n column-major
// matrix, so its leading dimension is m. `tau` is the reflector-scalar output
// of geqrf.
// The vendor API rejects lda < max(1, m), so lda is clamped to at least 1: this
// keeps empty (m == 0) matrices valid and leaves every non-empty call unchanged.
#define JAX_GPU_DEFINE_GEQRF(Type, Name)                                       \
  template <>                                                                  \
  absl::StatusOr<int> GeqrfBufferSize<Type>(gpusolverDnHandle_t handle, int m, \
                                            int n) {                           \
    const int lda = std::max(1, m);                                            \
    int lwork = 0;                                                             \
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(                                         \
        Name##_bufferSize(handle, m, n, /*A=*/nullptr, lda, &lwork)));         \
    return lwork;                                                              \
  }                                                                            \
                                                                               \
  template <>                                                                  \
  absl::Status Geqrf<Type>(gpusolverDnHandle_t handle, int m, int n, Type *a,  \
                           Type *tau, Type *workspace, int lwork, int *info) { \
    const int lda = std::max(1, m);                                            \
    return JAX_AS_STATUS(                                                      \
        Name(handle, m, n, a, lda, tau, workspace, lwork, info));              \
  }
JAX_GPU_DEFINE_GEQRF(float, gpusolverDnSgeqrf);
JAX_GPU_DEFINE_GEQRF(double, gpusolverDnDgeqrf);
JAX_GPU_DEFINE_GEQRF(gpuComplex, gpusolverDnCgeqrf);
JAX_GPU_DEFINE_GEQRF(gpuDoubleComplex, gpusolverDnZgeqrf);
#undef JAX_GPU_DEFINE_GEQRF
// Defines GeqrfBatched<Type> by forwarding to the batched BLAS QR routine
// `Name`. Each of the `batch` matrices pointed to by `a` is assumed to be a
// densely packed m-by-n column-major matrix, so the shared leading dimension
// is m. `tau` and `info` are passed through to the vendor routine unchanged.
// The vendor API rejects lda < max(1, m), so lda is clamped to at least 1: this
// keeps empty (m == 0) matrices valid and leaves every non-empty call unchanged.
#define JAX_GPU_DEFINE_GEQRF_BATCHED(Type, Name)                               \
  template <>                                                                  \
  absl::Status GeqrfBatched<Type>(gpublasHandle_t handle, int m, int n,        \
                                  Type **a, Type **tau, int *info,             \
                                  int batch) {                                 \
    const int lda = std::max(1, m);                                            \
    return JAX_AS_STATUS(Name(handle, m, n, a, lda, tau, info, batch));        \
  }
JAX_GPU_DEFINE_GEQRF_BATCHED(float, gpublasSgeqrfBatched);
JAX_GPU_DEFINE_GEQRF_BATCHED(double, gpublasDgeqrfBatched);
JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasComplex, gpublasCgeqrfBatched);
JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasDoubleComplex, gpublasZgeqrfBatched);
#undef JAX_GPU_DEFINE_GEQRF_BATCHED
// Householder transformations: orgqr
// Defines OrgqrBufferSize<Type> and Orgqr<Type>. They forward to the vendor
// routine `Name` (orgqr for real types, ungqr for complex ones, per the
// instantiations below) and its `Name##_bufferSize` workspace query.
// `k` and `tau` are passed through unchanged. `a` is assumed to be a densely
// packed m-by-n column-major matrix, so its leading dimension is m.
// The vendor API rejects lda < max(1, m), so lda is clamped to at least 1: this
// keeps empty (m == 0) matrices valid and leaves every non-empty call unchanged.
#define JAX_GPU_DEFINE_ORGQR(Type, Name)                                       \
  template <>                                                                  \
  absl::StatusOr<int> OrgqrBufferSize<Type>(gpusolverDnHandle_t handle, int m, \
                                            int n, int k) {                    \
    const int lda = std::max(1, m);                                            \
    int lwork = 0;                                                             \
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(Name##_bufferSize(                       \
        handle, m, n, k, /*A=*/nullptr, lda, /*tau=*/nullptr, &lwork)));       \
    return lwork;                                                              \
  }                                                                            \
                                                                               \
  template <>                                                                  \
  absl::Status Orgqr<Type>(gpusolverDnHandle_t handle, int m, int n, int k,    \
                           Type *a, Type *tau, Type *workspace, int lwork,     \
                           int *info) {                                        \
    const int lda = std::max(1, m);                                            \
    return JAX_AS_STATUS(                                                      \
        Name(handle, m, n, k, a, lda, tau, workspace, lwork, info));           \
  }
JAX_GPU_DEFINE_ORGQR(float, gpusolverDnSorgqr);
JAX_GPU_DEFINE_ORGQR(double, gpusolverDnDorgqr);
JAX_GPU_DEFINE_ORGQR(gpuComplex, gpusolverDnCungqr);
JAX_GPU_DEFINE_ORGQR(gpuDoubleComplex, gpusolverDnZungqr);
#undef JAX_GPU_DEFINE_ORGQR
// Symmetric (Hermitian) eigendecomposition:
// * Jacobi algorithm: syevj/heevj (the batched variants only support matrices
//   of size up to 32x32)
// * QR algorithm: syevd/heevd
// Defines SyevjBufferSize<Type> and Syevj<Type> (Jacobi eigensolver for
// symmetric/Hermitian matrices) by forwarding to the vendor routine `Name` and
// its `Name##_bufferSize` workspace query. `params` is the vendor syevj info
// object and is passed through unchanged. `a` is assumed to be a densely packed
// n-by-n column-major matrix (leading dimension n). `w` is a real-typed
// eigenvalue output.
// The vendor API rejects lda < max(1, n), so lda is clamped to at least 1: this
// keeps empty (n == 0) matrices valid and leaves every non-empty call unchanged.
#define JAX_GPU_DEFINE_SYEVJ(Type, Name)                                       \
  template <>                                                                  \
  absl::StatusOr<int> SyevjBufferSize<Type>(                                   \
      gpusolverDnHandle_t handle, gpusolverEigMode_t jobz,                     \
      gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params) {                \
    const int lda = std::max(1, n);                                            \
    int lwork = 0;                                                             \
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(                                         \
        Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, lda,           \
                          /*w=*/nullptr, &lwork, params)));                    \
    return lwork;                                                              \
  }                                                                            \
                                                                               \
  template <>                                                                  \
  absl::Status Syevj<Type>(                                                    \
      gpusolverDnHandle_t handle, gpusolverEigMode_t jobz,                     \
      gpusolverFillMode_t uplo, int n, Type *a, RealType<Type>::value *w,      \
      Type *workspace, int lwork, int *info, gpuSyevjInfo_t params) {          \
    const int lda = std::max(1, n);                                            \
    return JAX_AS_STATUS(Name(handle, jobz, uplo, n, a, lda, w, workspace,     \
                              lwork, info, params));                           \
  }
JAX_GPU_DEFINE_SYEVJ(float, gpusolverDnSsyevj);
JAX_GPU_DEFINE_SYEVJ(double, gpusolverDnDsyevj);
JAX_GPU_DEFINE_SYEVJ(gpuComplex, gpusolverDnCheevj);
JAX_GPU_DEFINE_SYEVJ(gpuDoubleComplex, gpusolverDnZheevj);
#undef JAX_GPU_DEFINE_SYEVJ
// Defines SyevjBatchedBufferSize<Type> and SyevjBatched<Type> (batched Jacobi
// eigensolver for symmetric/Hermitian matrices) by forwarding to the vendor
// routine `Name` and its `Name##_bufferSize` workspace query.
// `params` (the vendor syevj info object) and `batch` are passed through
// unchanged. `a` is assumed to hold `batch` densely packed n-by-n column-major
// matrices with leading dimension n — NOTE(review): confirm the packing with
// the caller.
// The vendor API rejects lda < max(1, n), so lda is clamped to at least 1: this
// keeps empty (n == 0) matrices valid and leaves every non-empty call unchanged.
#define JAX_GPU_DEFINE_SYEVJ_BATCHED(Type, Name)                               \
  template <>                                                                  \
  absl::StatusOr<int> SyevjBatchedBufferSize<Type>(                            \
      gpusolverDnHandle_t handle, gpusolverEigMode_t jobz,                     \
      gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params, int batch) {     \
    const int lda = std::max(1, n);                                            \
    int lwork = 0;                                                             \
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(                                         \
        Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, lda,           \
                          /*w=*/nullptr, &lwork, params, batch)));             \
    return lwork;                                                              \
  }                                                                            \
                                                                               \
  template <>                                                                  \
  absl::Status SyevjBatched<Type>(                                             \
      gpusolverDnHandle_t handle, gpusolverEigMode_t jobz,                     \
      gpusolverFillMode_t uplo, int n, Type *a, RealType<Type>::value *w,      \
      Type *workspace, int lwork, int *info, gpuSyevjInfo_t params,            \
      int batch) {                                                             \
    const int lda = std::max(1, n);                                            \
    return JAX_AS_STATUS(Name(handle, jobz, uplo, n, a, lda, w, workspace,     \
                              lwork, info, params, batch));                    \
  }
JAX_GPU_DEFINE_SYEVJ_BATCHED(float, gpusolverDnSsyevjBatched);
JAX_GPU_DEFINE_SYEVJ_BATCHED(double, gpusolverDnDsyevjBatched);
JAX_GPU_DEFINE_SYEVJ_BATCHED(gpuComplex, gpusolverDnCheevjBatched);
JAX_GPU_DEFINE_SYEVJ_BATCHED(gpuDoubleComplex, gpusolverDnZheevjBatched);
#undef JAX_GPU_DEFINE_SYEVJ_BATCHED
// Defines SyevdBufferSize<Type> and Syevd<Type> (symmetric/Hermitian
// eigendecomposition) by forwarding to the vendor routine `Name` and its
// `Name##_bufferSize` workspace query. `a` is assumed to be a densely packed
// n-by-n column-major matrix (leading dimension n). `w` is a real-typed
// eigenvalue output.
// The vendor API rejects lda < max(1, n), so lda is clamped to at least 1: this
// keeps empty (n == 0) matrices valid and leaves every non-empty call unchanged.
#define JAX_GPU_DEFINE_SYEVD(Type, Name)                                       \
  template <>                                                                  \
  absl::StatusOr<int> SyevdBufferSize<Type>(gpusolverDnHandle_t handle,        \
                                            gpusolverEigMode_t jobz,           \
                                            gpusolverFillMode_t uplo, int n) { \
    const int lda = std::max(1, n);                                            \
    int lwork = 0;                                                             \
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(Name##_bufferSize(                       \
        handle, jobz, uplo, n, /*A=*/nullptr, lda, /*w=*/nullptr, &lwork)));   \
    return lwork;                                                              \
  }                                                                            \
                                                                               \
  template <>                                                                  \
  absl::Status Syevd<Type>(gpusolverDnHandle_t handle,                         \
                           gpusolverEigMode_t jobz, gpusolverFillMode_t uplo,  \
                           int n, Type *a, RealType<Type>::value *w,           \
                           Type *workspace, int lwork, int *info) {            \
    const int lda = std::max(1, n);                                            \
    return JAX_AS_STATUS(                                                      \
        Name(handle, jobz, uplo, n, a, lda, w, workspace, lwork, info));       \
  }
JAX_GPU_DEFINE_SYEVD(float, gpusolverDnSsyevd);
JAX_GPU_DEFINE_SYEVD(double, gpusolverDnDsyevd);
JAX_GPU_DEFINE_SYEVD(gpuComplex, gpusolverDnCheevd);
JAX_GPU_DEFINE_SYEVD(gpuDoubleComplex, gpusolverDnZheevd);
#undef JAX_GPU_DEFINE_SYEVD
// Symmetric rank-k update: syrk
// Defines Syrk<Type> (symmetric rank-k update of the `uplo` triangle of the
// n-by-n matrix `c`) by forwarding to the BLAS routine `Name`. `a` is n-by-k
// when `trans` is GPUBLAS_OP_N and k-by-n otherwise. Both operands are assumed
// to be densely packed column-major, so lda is the row count of `a` and ldc
// is n.
// The vendor API requires lda and ldc to be at least 1. Both are clamped so
// that empty operands (n == 0 or k == 0) stay valid; non-empty calls are
// unchanged.
#define JAX_GPU_DEFINE_SYRK(Type, Name)                                        \
  template <>                                                                  \
  absl::Status Syrk<Type>(gpublasHandle_t handle, gpublasFillMode_t uplo,      \
                          gpublasOperation_t trans, int n, int k,              \
                          const Type *alpha, const Type *a, const Type *beta,  \
                          Type *c) {                                           \
    const int lda = std::max(1, trans == GPUBLAS_OP_N ? n : k);                \
    const int ldc = std::max(1, n);                                            \
    return JAX_AS_STATUS(                                                      \
        Name(handle, uplo, trans, n, k, alpha, a, lda, beta, c, ldc));         \
  }
JAX_GPU_DEFINE_SYRK(float, gpublasSsyrk);
JAX_GPU_DEFINE_SYRK(double, gpublasDsyrk);
JAX_GPU_DEFINE_SYRK(gpublasComplex, gpublasCsyrk);
JAX_GPU_DEFINE_SYRK(gpublasDoubleComplex, gpublasZsyrk);
#undef JAX_GPU_DEFINE_SYRK
} // namespace solver
} // namespace JAX_GPU_NAMESPACE
} // namespace jax