mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Port GPU kernel for Householder transformation to FFI.
PiperOrigin-RevId: 666305682
This commit is contained in:
parent
0b4f64e002
commit
b56ed8eedd
@ -51,6 +51,8 @@ XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA",
|
||||
GeqrfFfi);
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA");
|
||||
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA",
|
||||
OrgqrFfi);
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevd", Syevd, "CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevj", Syevj, "CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_sytrd", Sytrd, "CUDA");
|
||||
|
@ -477,6 +477,7 @@ nb::dict Registrations() {
|
||||
|
||||
dict[JAX_GPU_PREFIX "solver_getrf_ffi"] = EncapsulateFfiHandler(GetrfFfi);
|
||||
dict[JAX_GPU_PREFIX "solver_geqrf_ffi"] = EncapsulateFfiHandler(GeqrfFfi);
|
||||
dict[JAX_GPU_PREFIX "solver_orgqr_ffi"] = EncapsulateFfiHandler(OrgqrFfi);
|
||||
|
||||
return dict;
|
||||
}
|
||||
|
@ -51,13 +51,13 @@ inline absl::StatusOr<T*> AllocateWorkspace(ffi::ScratchAllocator& scratch,
|
||||
} // namespace
|
||||
|
||||
#define SOLVER_DISPATCH_IMPL(impl, ...) \
|
||||
if (dataType == ffi::DataType::F32) { \
|
||||
if (dataType == ffi::F32) { \
|
||||
return impl<float>(__VA_ARGS__); \
|
||||
} else if (dataType == ffi::DataType::F64) { \
|
||||
} else if (dataType == ffi::F64) { \
|
||||
return impl<double>(__VA_ARGS__); \
|
||||
} else if (dataType == ffi::DataType::C64) { \
|
||||
} else if (dataType == ffi::C64) { \
|
||||
return impl<gpuComplex>(__VA_ARGS__); \
|
||||
} else if (dataType == ffi::DataType::C128) { \
|
||||
} else if (dataType == ffi::C128) { \
|
||||
return impl<gpuDoubleComplex>(__VA_ARGS__); \
|
||||
}
|
||||
|
||||
@ -94,8 +94,8 @@ template <typename T>
|
||||
ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols,
|
||||
gpuStream_t stream, ffi::ScratchAllocator& scratch,
|
||||
ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
|
||||
ffi::Result<ffi::Buffer<ffi::DataType::S32>> ipiv,
|
||||
ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
|
||||
ffi::Result<ffi::Buffer<ffi::S32>> ipiv,
|
||||
ffi::Result<ffi::Buffer<ffi::S32>> info) {
|
||||
FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
|
||||
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
|
||||
|
||||
@ -110,13 +110,12 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols,
|
||||
auto ipiv_data = ipiv->typed_data();
|
||||
auto info_data = info->typed_data();
|
||||
if (a_data != out_data) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(
|
||||
gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * rows * cols,
|
||||
gpuMemcpyDeviceToDevice, stream)));
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
int ipiv_step = std::min(m, n);
|
||||
for (int i = 0; i < batch; ++i) {
|
||||
for (auto i = 0; i < batch; ++i) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(GetrfKernel<T>::Run(
|
||||
handle.get(), m, n, out_data, workspace, lwork, ipiv_data, info_data));
|
||||
out_data += m * n;
|
||||
@ -147,8 +146,8 @@ template <typename T>
|
||||
ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream,
|
||||
ffi::ScratchAllocator& scratch, ffi::AnyBuffer a,
|
||||
ffi::Result<ffi::AnyBuffer> out,
|
||||
ffi::Result<ffi::Buffer<ffi::DataType::S32>> ipiv,
|
||||
ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
|
||||
ffi::Result<ffi::Buffer<ffi::S32>> ipiv,
|
||||
ffi::Result<ffi::Buffer<ffi::S32>> info) {
|
||||
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
|
||||
FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream));
|
||||
FFI_ASSIGN_OR_RETURN(auto batch_ptrs,
|
||||
@ -159,9 +158,8 @@ ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream,
|
||||
auto ipiv_data = ipiv->typed_data();
|
||||
auto info_data = info->typed_data();
|
||||
if (a_data != out_data) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(
|
||||
gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * cols * cols,
|
||||
gpuMemcpyDeviceToDevice, stream)));
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
MakeBatchPointersAsync(stream, out_data, batch_ptrs, batch,
|
||||
@ -176,8 +174,8 @@ ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream,
|
||||
|
||||
ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
|
||||
ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
|
||||
ffi::Result<ffi::Buffer<ffi::DataType::S32>> ipiv,
|
||||
ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
|
||||
ffi::Result<ffi::Buffer<ffi::S32>> ipiv,
|
||||
ffi::Result<ffi::Buffer<ffi::S32>> info) {
|
||||
auto dataType = a.element_type();
|
||||
if (dataType != out->element_type()) {
|
||||
return ffi::Error::InvalidArgument(
|
||||
@ -201,15 +199,14 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL(
|
||||
GetrfFfi, GetrfDispatch,
|
||||
ffi::Ffi::Bind()
|
||||
.Ctx<ffi::PlatformStream<gpuStream_t>>()
|
||||
.Ctx<ffi::ScratchAllocator>()
|
||||
.Arg<ffi::AnyBuffer>() // a
|
||||
.Ret<ffi::AnyBuffer>() // out
|
||||
.Ret<ffi::Buffer<ffi::DataType::S32>>() // ipiv
|
||||
.Ret<ffi::Buffer<ffi::DataType::S32>>() // info
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL(GetrfFfi, GetrfDispatch,
|
||||
ffi::Ffi::Bind()
|
||||
.Ctx<ffi::PlatformStream<gpuStream_t>>()
|
||||
.Ctx<ffi::ScratchAllocator>()
|
||||
.Arg<ffi::AnyBuffer>() // a
|
||||
.Ret<ffi::AnyBuffer>() // out
|
||||
.Ret<ffi::Buffer<ffi::S32>>() // ipiv
|
||||
.Ret<ffi::Buffer<ffi::S32>>() // info
|
||||
);
|
||||
|
||||
// QR decomposition: geqrf
|
||||
@ -264,14 +261,13 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols,
|
||||
auto out_data = static_cast<T*>(out->untyped_data());
|
||||
auto tau_data = static_cast<T*>(tau->untyped_data());
|
||||
if (a_data != out_data) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(
|
||||
gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * rows * cols,
|
||||
gpuMemcpyDeviceToDevice, stream)));
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
int out_step = m * n;
|
||||
int tau_step = std::min(m, n);
|
||||
for (int i = 0; i < batch; ++i) {
|
||||
for (auto i = 0; i < batch; ++i) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(GeqrfKernel<T>::Run(
|
||||
handle.get(), m, n, out_data, tau_data, workspace, lwork, info));
|
||||
out_data += out_step;
|
||||
@ -284,8 +280,8 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols,
|
||||
template <> \
|
||||
struct GeqrfBatchedKernel<type> { \
|
||||
static absl::Status Run(gpublasHandle_t handle, int m, int n, type** a, \
|
||||
type** tau, int* info, int batch) { \
|
||||
return JAX_AS_STATUS(name(handle, m, n, a, m, tau, info, batch)); \
|
||||
type** tau, int* info, int batch) { \
|
||||
return JAX_AS_STATUS(name(handle, m, n, a, m, tau, info, batch)); \
|
||||
} \
|
||||
}
|
||||
|
||||
@ -314,9 +310,8 @@ ffi::Error GeqrfBatchedImpl(int64_t batch, int64_t rows, int64_t cols,
|
||||
auto out_data = out->untyped_data();
|
||||
auto tau_data = tau->untyped_data();
|
||||
if (a_data != out_data) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(
|
||||
gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * rows * cols,
|
||||
gpuMemcpyDeviceToDevice, stream)));
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
MakeBatchPointersAsync(stream, out_data, out_batch_ptrs, batch,
|
||||
@ -369,6 +364,112 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GeqrfFfi, GeqrfDispatch,
|
||||
.Ret<ffi::AnyBuffer>() // tau
|
||||
);
|
||||
|
||||
// Householder transformations: orgqr
|
||||
|
||||
namespace {
|
||||
#define ORGQR_KERNEL_IMPL(type, name) \
|
||||
template <> \
|
||||
struct OrgqrKernel<type> { \
|
||||
static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, int m, \
|
||||
int n, int k) { \
|
||||
int lwork; \
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
|
||||
name##_bufferSize(handle, m, n, k, /*A=*/nullptr, /*lda=*/m, \
|
||||
/*tau=*/nullptr, &lwork))); \
|
||||
return lwork; \
|
||||
} \
|
||||
static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, int k, \
|
||||
type* a, type* tau, type* workspace, int lwork, \
|
||||
int* info) { \
|
||||
return JAX_AS_STATUS( \
|
||||
name(handle, m, n, k, a, m, tau, workspace, lwork, info)); \
|
||||
} \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct OrgqrKernel;
|
||||
ORGQR_KERNEL_IMPL(float, gpusolverDnSorgqr);
|
||||
ORGQR_KERNEL_IMPL(double, gpusolverDnDorgqr);
|
||||
ORGQR_KERNEL_IMPL(gpuComplex, gpusolverDnCungqr);
|
||||
ORGQR_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZungqr);
|
||||
#undef ORGQR_KERNEL_IMPL
|
||||
|
||||
template <typename T>
|
||||
ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size,
|
||||
gpuStream_t stream, ffi::ScratchAllocator& scratch,
|
||||
ffi::AnyBuffer a, ffi::AnyBuffer tau,
|
||||
ffi::Result<ffi::AnyBuffer> out) {
|
||||
FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
|
||||
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
|
||||
FFI_ASSIGN_OR_RETURN(auto k, MaybeCastNoOverflow<int>(size));
|
||||
|
||||
FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
|
||||
FFI_ASSIGN_OR_RETURN(int lwork,
|
||||
OrgqrKernel<T>::BufferSize(handle.get(), m, n, k));
|
||||
|
||||
FFI_ASSIGN_OR_RETURN(auto workspace,
|
||||
AllocateWorkspace<T>(scratch, lwork, "orgqr"));
|
||||
// Note: We ignore the returned value of info because it is only used for
|
||||
// shape checking (which we already do ourselves), but it is expected to be
|
||||
// in device memory, so we need to allocate it.
|
||||
FFI_ASSIGN_OR_RETURN(auto info, AllocateWorkspace<int>(scratch, 1, "orgqr"));
|
||||
|
||||
auto a_data = static_cast<T*>(a.untyped_data());
|
||||
auto tau_data = static_cast<T*>(tau.untyped_data());
|
||||
auto out_data = static_cast<T*>(out->untyped_data());
|
||||
if (a_data != out_data) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
|
||||
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
int out_step = m * n;
|
||||
for (auto i = 0; i < batch; ++i) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(OrgqrKernel<T>::Run(
|
||||
handle.get(), m, n, k, out_data, tau_data, workspace, lwork, info));
|
||||
out_data += out_step;
|
||||
tau_data += k;
|
||||
}
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
// Validates the operands of an orgqr call and forwards to OrgqrImpl<T> for
// the element type of `a`.
//
// Checks, in order: matching element types of a/tau/out, matching batch
// dimensions of a and tau, tau's trailing dimension not exceeding the column
// count of a, and the shape of `out`. Each failure is reported as an
// InvalidArgument error (or whatever SplitBatch2D / SplitBatch1D / CheckShape
// return).
ffi::Error OrgqrDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
                         ffi::AnyBuffer a, ffi::AnyBuffer tau,
                         ffi::Result<ffi::AnyBuffer> out) {
  // SOLVER_DISPATCH_IMPL looks up a local with exactly this name.
  auto dataType = a.element_type();
  const bool types_agree =
      dataType == tau.element_type() && dataType == out->element_type();
  if (!types_agree) {
    return ffi::Error::InvalidArgument(
        "The inputs and outputs to orgqr must have the same element type");
  }

  FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
                       SplitBatch2D(a.dimensions()));
  FFI_ASSIGN_OR_RETURN((auto [tau_batch_count, tau_len]),
                       SplitBatch1D(tau.dimensions()));

  if (batch != tau_batch_count) {
    return ffi::Error::InvalidArgument(
        "The batch dimensions of the inputs to orgqr must match");
  }
  if (cols < tau_len) {
    return ffi::Error::InvalidArgument(
        "The trailing dimension of the tau input to orgqr must be less than or "
        "equal to the number of columns of the input matrix");
  }

  FFI_RETURN_IF_ERROR(
      CheckShape(out->dimensions(), {batch, rows, cols}, "out", "orgqr"));

  // Returns from inside the macro when the element type is supported.
  SOLVER_DISPATCH_IMPL(OrgqrImpl, batch, rows, cols, tau_len, stream, scratch,
                       a, tau, out);
  return ffi::Error::InvalidArgument("Unsupported element type for orgqr");
}
|
||||
} // namespace
|
||||
|
||||
// Defines the OrgqrFfi handler symbol and binds it to OrgqrDispatch. The
// binding supplies, in order: the platform stream, a scratch allocator, the
// two input buffers and the single result buffer.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    OrgqrFfi, OrgqrDispatch,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<gpuStream_t>>()  // stream
        .Ctx<ffi::ScratchAllocator>()             // scratch
        .Arg<ffi::AnyBuffer>()                    // a
        .Arg<ffi::AnyBuffer>()                    // tau
        .Ret<ffi::AnyBuffer>()                    // out
);
|
||||
|
||||
#undef SOLVER_DISPATCH_IMPL
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
|
@ -24,6 +24,7 @@ namespace JAX_GPU_NAMESPACE {
|
||||
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(GetrfFfi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(GeqrfFfi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(OrgqrFfi);
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
||||
|
Loading…
x
Reference in New Issue
Block a user