Move logic about when to dispatch to batched LU decomposition algorithm on GPU into the kernel.

This simplifies the lowering logic, and means that we don't get hit with a performance penalty when exporting with shape polymorphism.

PiperOrigin-RevId: 662945116
This commit is contained in:
Dan Foreman-Mackey 2024-08-14 09:20:07 -07:00 committed by jax authors
parent bab70dda97
commit ad1bd38790
9 changed files with 104 additions and 233 deletions

View File

@ -114,22 +114,6 @@ cc_library(
],
)
cc_library(
name = "cublas_kernels_ffi",
srcs = ["//jaxlib/gpu:blas_kernels_ffi.cc"],
hdrs = ["//jaxlib/gpu:blas_kernels_ffi.h"],
deps = [
":cuda_blas_handle_pool",
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:ffi_helpers",
"@xla//xla/ffi/api:ffi",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
"@com_google_absl//absl/status",
],
)
pybind_extension(
name = "_blas",
srcs = ["//jaxlib/gpu:blas.cc"],
@ -148,7 +132,6 @@ pybind_extension(
module_name = "_blas",
deps = [
":cublas_kernels",
":cublas_kernels_ffi",
":cuda_vendor",
"//jaxlib:kernel_nanobind_helpers",
"@xla//xla/tsl/cuda:cublas",
@ -238,11 +221,13 @@ cc_library(
srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"],
hdrs = ["//jaxlib/gpu:solver_kernels_ffi.h"],
deps = [
":cuda_blas_handle_pool",
":cuda_gpu_kernel_helpers",
":cuda_solver_handle_pool",
":cuda_vendor",
"//jaxlib:ffi_helpers",
"@xla//xla/ffi/api:ffi",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cusolver",
"@com_google_absl//absl/status",
@ -274,6 +259,7 @@ pybind_extension(
":cusolver_kernels",
":cusolver_kernels_ffi",
"//jaxlib:kernel_nanobind_helpers",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cusolver",
"@xla//xla/tsl/python/lib/core:numpy",
@ -466,7 +452,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":cublas_kernels",
":cublas_kernels_ffi",
":cuda_linalg_kernels",
":cuda_prng_kernels",
":cuda_vendor",

View File

@ -29,8 +29,6 @@ exports_files(srcs = [
"blas_handle_pool.h",
"blas_kernels.cc",
"blas_kernels.h",
"blas_kernels_ffi.cc",
"blas_kernels_ffi.h",
"gpu_kernel_helpers.cc",
"gpu_kernel_helpers.h",
"gpu_kernels.cc",

View File

@ -22,7 +22,6 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "jaxlib/gpu/blas_kernels.h"
#include "jaxlib/gpu/blas_kernels_ffi.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_nanobind_helpers.h"
#include "xla/tsl/python/lib/core/numpy.h"
@ -70,9 +69,6 @@ nb::dict Registrations() {
nb::dict dict;
dict[JAX_GPU_PREFIX "blas_getrf_batched"] = EncapsulateFunction(GetrfBatched);
dict[JAX_GPU_PREFIX "blas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched);
dict[JAX_GPU_PREFIX "blas_getrf_batched_ffi"] =
EncapsulateFfiHandler(GetrfBatchedFfi);
return dict;
}

View File

@ -1,133 +0,0 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/gpu/blas_kernels_ffi.h"
#include "absl/status/status.h"
#include "jaxlib/ffi_helpers.h"
#include "jaxlib/gpu/blas_handle_pool.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/vendor.h"
#include "xla/ffi/api/ffi.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace ffi = ::xla::ffi;
namespace {
#define GETRF_BATCHED_KERNEL_IMPL(type, name) \
template <> \
struct GetrfBatchedKernel<type> { \
static absl::Status Run(gpublasHandle_t handle, int n, type** a, int lda, \
int* ipiv, int* info, int batch) { \
return JAX_AS_STATUS(name(handle, n, a, lda, ipiv, info, batch)); \
} \
}
template <typename T>
struct GetrfBatchedKernel;
GETRF_BATCHED_KERNEL_IMPL(float, gpublasSgetrfBatched);
GETRF_BATCHED_KERNEL_IMPL(double, gpublasDgetrfBatched);
GETRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgetrfBatched);
GETRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgetrfBatched);
#undef GETRF_BATCHED_KERNEL_IMPL
template <typename T>
ffi::Error GetrfBatchedImpl(gpuStream_t stream, ffi::ScratchAllocator& scratch,
ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
ffi::Result<ffi::Buffer<ffi::DataType::S32>> ipiv,
ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
FFI_RETURN_IF_ERROR(CheckMatrixDimensions(a.dimensions()));
auto [batch, rows, cols] = SplitBatch2D(a.dimensions());
if (rows != cols) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"getrf_batched only supports square matrices");
}
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream));
auto maybe_workspace = scratch.Allocate(sizeof(void*) * batch);
if (!maybe_workspace.has_value()) {
return ffi::Error(ffi::ErrorCode::kUnknown,
"Unable to allocate workspace for batched getrf");
}
auto workspace = maybe_workspace.value();
auto a_data = a.untyped_data();
auto out_data = out->untyped_data();
auto ipiv_data = ipiv->typed_data();
auto info_data = info->typed_data();
if (a_data != out_data) {
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(
gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * cols * cols,
gpuMemcpyDeviceToDevice, stream)));
}
FFI_ASSIGN_OR_RETURN(
auto a_ptrs_host,
MakeBatchPointers(stream, out_data, workspace, batch, sizeof(T) * n * n));
// TODO(phawkins, danfm): ideally we would not need to synchronize here, but
// to avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
auto batch_ptrs = static_cast<T**>(workspace);
FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel<T>::Run(
handle.get(), n, batch_ptrs, n, ipiv_data, info_data, batch));
return ffi::Error::Success();
}
ffi::Error GetrfBatchedDispatch(
gpuStream_t stream, ffi::ScratchAllocator scratch, ffi::AnyBuffer a,
ffi::Result<ffi::AnyBuffer> out,
ffi::Result<ffi::Buffer<ffi::DataType::S32>> ipiv,
ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
auto dataType = a.element_type();
if (dataType != out->element_type()) {
return ffi::Error(
ffi::ErrorCode::kInvalidArgument,
"Input and output to getrf_batched must have the same element type");
}
if (dataType == ffi::DataType::F32) {
return GetrfBatchedImpl<float>(stream, scratch, a, out, ipiv, info);
} else if (dataType == ffi::DataType::F64) {
return GetrfBatchedImpl<double>(stream, scratch, a, out, ipiv, info);
} else if (dataType == ffi::DataType::C64) {
return GetrfBatchedImpl<gpublasComplex>(stream, scratch, a, out, ipiv,
info);
} else if (dataType == ffi::DataType::C128) {
return GetrfBatchedImpl<gpublasDoubleComplex>(stream, scratch, a, out, ipiv,
info);
}
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"Unsupported element type for getrf");
}
} // namespace
XLA_FFI_DEFINE_HANDLER_SYMBOL(
GetrfBatchedFfi, GetrfBatchedDispatch,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStream_t>>()
.Ctx<ffi::ScratchAllocator>()
.Arg<ffi::AnyBuffer>() // a
.Ret<ffi::AnyBuffer>() // out
.Ret<ffi::Buffer<ffi::DataType::S32>>() // ipiv
.Ret<ffi::Buffer<ffi::DataType::S32>>() // info
);
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

View File

@ -1,30 +0,0 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_GPU_BLAS_KERNELS_FFI_H_
#define JAXLIB_GPU_BLAS_KERNELS_FFI_H_
#include "jaxlib/gpu/vendor.h"
#include "xla/ffi/api/ffi.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
XLA_FFI_DECLARE_HANDLER_SYMBOL(GetrfBatchedFfi);
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
#endif // JAXLIB_GPU_BLAS_KERNELS_FFI_H_

View File

@ -17,7 +17,6 @@ limitations under the License.
// JAX-generated HLO code from outside of JAX.
#include "jaxlib/gpu/blas_kernels.h"
#include "jaxlib/gpu/blas_kernels_ffi.h"
#include "jaxlib/gpu/linalg_kernels.h"
#include "jaxlib/gpu/prng_kernels.h"
#include "jaxlib/gpu/rnn_kernels.h"
@ -36,8 +35,6 @@ namespace {
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched,
"CUDA");
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cublas_getrf_batched_ffi", "CUDA",
GetrfBatchedFfi);
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_geqrf_batched", GeqrfBatched,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn", RNNForward, "CUDA");

View File

@ -16,10 +16,12 @@ limitations under the License.
#include "jaxlib/gpu/solver_kernels_ffi.h"
#include <algorithm>
#include <cstdint>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "jaxlib/ffi_helpers.h"
#include "jaxlib/gpu/blas_handle_pool.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/solver_handle_pool.h"
#include "jaxlib/gpu/vendor.h"
@ -58,12 +60,11 @@ GETRF_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZgetrf);
#undef GETRF_KERNEL_IMPL
template <typename T>
ffi::Error GetrfImpl(gpuStream_t stream, ffi::ScratchAllocator& scratch,
ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols,
gpuStream_t stream, ffi::ScratchAllocator& scratch,
ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
ffi::Result<ffi::Buffer<ffi::DataType::S32>> ipiv,
ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
FFI_RETURN_IF_ERROR(CheckMatrixDimensions(a.dimensions()));
auto [batch, rows, cols] = SplitBatch2D(a.dimensions());
FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
@ -98,6 +99,64 @@ ffi::Error GetrfImpl(gpuStream_t stream, ffi::ScratchAllocator& scratch,
return ffi::Error::Success();
}
#define GETRF_BATCHED_KERNEL_IMPL(type, name) \
template <> \
struct GetrfBatchedKernel<type> { \
static absl::Status Run(gpublasHandle_t handle, int n, type** a, int lda, \
int* ipiv, int* info, int batch) { \
return JAX_AS_STATUS(name(handle, n, a, lda, ipiv, info, batch)); \
} \
}
template <typename T>
struct GetrfBatchedKernel;
GETRF_BATCHED_KERNEL_IMPL(float, gpublasSgetrfBatched);
GETRF_BATCHED_KERNEL_IMPL(double, gpublasDgetrfBatched);
GETRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgetrfBatched);
GETRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgetrfBatched);
#undef GETRF_BATCHED_KERNEL_IMPL
template <typename T>
ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream,
ffi::ScratchAllocator& scratch, ffi::AnyBuffer a,
ffi::Result<ffi::AnyBuffer> out,
ffi::Result<ffi::Buffer<ffi::DataType::S32>> ipiv,
ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream));
auto maybe_workspace = scratch.Allocate(sizeof(void*) * batch);
if (!maybe_workspace.has_value()) {
return ffi::Error(ffi::ErrorCode::kUnknown,
"Unable to allocate workspace for batched getrf");
}
auto workspace = maybe_workspace.value();
auto a_data = a.untyped_data();
auto out_data = out->untyped_data();
auto ipiv_data = ipiv->typed_data();
auto info_data = info->typed_data();
if (a_data != out_data) {
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(
gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * cols * cols,
gpuMemcpyDeviceToDevice, stream)));
}
FFI_ASSIGN_OR_RETURN(
auto a_ptrs_host,
MakeBatchPointers(stream, out_data, workspace, batch, sizeof(T) * n * n));
// TODO(phawkins, danfm): ideally we would not need to synchronize here, but
// to avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
auto batch_ptrs = static_cast<T**>(workspace);
FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel<T>::Run(
handle.get(), n, batch_ptrs, n, ipiv_data, info_data, batch));
return ffi::Error::Success();
}
ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
ffi::Result<ffi::Buffer<ffi::DataType::S32>> ipiv,
@ -108,14 +167,36 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
ffi::ErrorCode::kInvalidArgument,
"The input and output to getrf must have the same element type");
}
FFI_RETURN_IF_ERROR(CheckMatrixDimensions(a.dimensions()));
auto [batch, rows, cols] = SplitBatch2D(a.dimensions());
if (batch > 1 && rows == cols && rows / batch <= 128) {
if (dataType == ffi::DataType::F32) {
return GetrfImpl<float>(stream, scratch, a, out, ipiv, info);
return GetrfBatchedImpl<float>(batch, cols, stream, scratch, a, out, ipiv,
info);
} else if (dataType == ffi::DataType::F64) {
return GetrfImpl<double>(stream, scratch, a, out, ipiv, info);
return GetrfBatchedImpl<double>(batch, cols, stream, scratch, a, out,
ipiv, info);
} else if (dataType == ffi::DataType::C64) {
return GetrfImpl<gpuComplex>(stream, scratch, a, out, ipiv, info);
return GetrfBatchedImpl<gpublasComplex>(batch, cols, stream, scratch, a,
out, ipiv, info);
} else if (dataType == ffi::DataType::C128) {
return GetrfImpl<gpuDoubleComplex>(stream, scratch, a, out, ipiv, info);
return GetrfBatchedImpl<gpublasDoubleComplex>(
batch, cols, stream, scratch, a, out, ipiv, info);
}
} else {
if (dataType == ffi::DataType::F32) {
return GetrfImpl<float>(batch, rows, cols, stream, scratch, a, out, ipiv,
info);
} else if (dataType == ffi::DataType::F64) {
return GetrfImpl<double>(batch, rows, cols, stream, scratch, a, out, ipiv,
info);
} else if (dataType == ffi::DataType::C64) {
return GetrfImpl<gpuComplex>(batch, rows, cols, stream, scratch, a, out,
ipiv, info);
} else if (dataType == ffi::DataType::C128) {
return GetrfImpl<gpuDoubleComplex>(batch, rows, cols, stream, scratch, a,
out, ipiv, info);
}
}
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"Unsupported element type for getrf");

View File

@ -43,10 +43,7 @@ except ImportError:
if _cublas:
for _name, _value in _cublas.registrations().items():
# TODO(danfm): Clean up after all legacy custom calls are ported.
api_version = 1 if _name.endswith("_ffi") else 0
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
api_version=api_version)
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
try:
@ -78,10 +75,7 @@ except ImportError:
if _hipblas:
for _name, _value in _hipblas.registrations().items():
# TODO(danfm): Clean up after all legacy custom calls are ported.
api_version = 1 if _name.endswith("_ffi") else 0
xla_client.register_custom_call_target(_name, _value, platform="ROCM",
api_version=api_version)
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
try:
@ -115,15 +109,14 @@ def _getrf_hlo(platform, gpu_blas, gpu_solver, ctx, dtype, a):
num_bd = len(batch_dims)
i32_type = ir.IntegerType.get_signless(32)
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
batch = math.prod(batch_dims)
use_batched = batch > 1 and m == n and m // batch <= 128
# TODO(b/357034884): Remove after 3 week forward compatibility window.
if ctx.is_forward_compat():
if not gpu_blas:
raise GpuLibNotLinkedError()
if use_batched:
batch = math.prod(batch_dims)
if batch > 1 and m == n and m // batch <= 128:
lwork, opaque = gpu_blas.build_getrf_batched_descriptor(
np.dtype(dtype), batch, m)
workspace = ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8))
@ -154,9 +147,8 @@ def _getrf_hlo(platform, gpu_blas, gpu_solver, ctx, dtype, a):
operand_output_aliases={0: 0}).results
return out[:3]
target = "blas_getrf_batched_ffi" if use_batched else "solver_getrf_ffi"
return custom_call(
f"{platform}{target}",
f"{platform}solver_getrf_ffi",
result_types=[
a.type,
ir.RankedTensorType.get(batch_dims + (min(m, n),), i32_type),

View File

@ -98,23 +98,6 @@ cc_library(
],
)
cc_library(
name = "hipblas_kernels_ffi",
srcs = ["//jaxlib/gpu:blas_kernels_ffi.cc"],
hdrs = ["//jaxlib/gpu:blas_kernels_ffi.h"],
deps = [
":hip_blas_handle_pool",
":hip_gpu_kernel_helpers",
":hip_vendor",
"//jaxlib:ffi_helpers",
"@com_google_absl//absl/status",
"@local_config_rocm//rocm:hipblas",
"@local_config_rocm//rocm:rocm_headers",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
],
)
pybind_extension(
name = "_blas",
srcs = ["//jaxlib/gpu:blas.cc"],
@ -127,7 +110,6 @@ pybind_extension(
deps = [
":hip_vendor",
":hipblas_kernels",
":hipblas_kernels_ffi",
"//jaxlib:kernel_nanobind_helpers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings:str_format",
@ -176,12 +158,14 @@ cc_library(
srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"],
hdrs = ["//jaxlib/gpu:solver_kernels_ffi.h"],
deps = [
":hip_blas_handle_pool",
":hip_gpu_kernel_helpers",
":hip_solver_handle_pool",
":hip_vendor",
"//jaxlib:ffi_helpers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@local_config_rocm//rocm:hipblas",
"@local_config_rocm//rocm:hipsolver",
"@local_config_rocm//rocm:rocm_headers",
"@xla//xla/ffi/api:ffi",
@ -199,14 +183,15 @@ pybind_extension(
features = ["-use_header_modules"],
module_name = "_solver",
deps = [
":hip_solver_handle_pool",
":hip_gpu_kernel_helpers",
":hip_solver_handle_pool",
":hip_vendor",
":hipsolver_kernels",
":hipsolver_kernels_ffi",
"//jaxlib:kernel_nanobind_helpers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings:str_format",
"@local_config_rocm//rocm:hipblas",
"@local_config_rocm//rocm:hipsolver",
"@local_config_rocm//rocm:rocm_headers",
"@nanobind",