Remove synchronization from GPU LU decomposition kernel by adding an async batch pointers builder.

In the batched LU decomposition in cuBLAS, the output buffer is required to be a pointer of pointers to the appropriate batch matrices. Previously this reshaping was done on the host and then copied to the device, requiring a synchronization, but it seems straightforward to instead implement a tiny CUDA kernel to do this work. This definitely isn't a bottleneck or a high priority change, but this seemed like a reasonable time to fix a longstanding TODO.

PiperOrigin-RevId: 663686539
This commit is contained in:
Dan Foreman-Mackey 2024-08-16 04:36:30 -07:00 committed by jax authors
parent acacf8884e
commit b6306e3953
9 changed files with 115 additions and 50 deletions

View File

@ -72,6 +72,16 @@ cc_library(
],
)
cuda_library(
name = "cuda_make_batch_pointers",
srcs = ["//jaxlib/gpu:make_batch_pointers.cu.cc"],
hdrs = ["//jaxlib/gpu:make_batch_pointers.h"],
deps = [
":cuda_vendor",
"@local_config_cuda//cuda:cuda_headers",
],
)
cc_library(
name = "cuda_blas_handle_pool",
srcs = ["//jaxlib/gpu:blas_handle_pool.cc"],
@ -95,6 +105,7 @@ cc_library(
deps = [
":cuda_blas_handle_pool",
":cuda_gpu_kernel_helpers",
":cuda_make_batch_pointers",
":cuda_vendor",
"//jaxlib:kernel_helpers",
"@xla//xla/service:custom_call_status",
@ -223,6 +234,7 @@ cc_library(
deps = [
":cuda_blas_handle_pool",
":cuda_gpu_kernel_helpers",
":cuda_make_batch_pointers",
":cuda_solver_handle_pool",
":cuda_vendor",
"//jaxlib:ffi_helpers",

View File

@ -36,6 +36,8 @@ exports_files(srcs = [
"linalg_kernels.cc",
"linalg_kernels.cu.cc",
"linalg_kernels.h",
"make_batch_pointers.cu.cc",
"make_batch_pointers.h",
"prng.cc",
"prng_kernels.cc",
"prng_kernels.cu.cc",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "jaxlib/gpu/blas_handle_pool.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/make_batch_pointers.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_helpers.h"
#include "xla/service/custom_call_status.h"
@ -69,13 +70,9 @@ static absl::Status GetrfBatched_(gpuStream_t stream, void** buffers,
int* ipiv = static_cast<int*>(buffers[2]);
int* info = static_cast<int*>(buffers[3]);
auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[4], d.batch,
SizeOfBlasType(d.type) * d.n * d.n);
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
MakeBatchPointersAsync(stream, buffers[1], buffers[4], d.batch,
SizeOfBlasType(d.type) * d.n * d.n);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));
switch (d.type) {
case BlasType::F32: {
float** batch_ptrs = static_cast<float**>(buffers[4]);
@ -132,17 +129,12 @@ static absl::Status GeqrfBatched_(gpuStream_t stream, void** buffers,
}
std::vector<int> info(d.batch);
auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,
SizeOfBlasType(d.type) * d.m * d.n);
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
auto tau_ptrs_host =
MakeBatchPointers(stream, buffers[2], buffers[4], d.batch,
SizeOfBlasType(d.type) * std::min(d.m, d.n));
JAX_RETURN_IF_ERROR(tau_ptrs_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
MakeBatchPointersAsync(stream, buffers[1], buffers[3], d.batch,
SizeOfBlasType(d.type) * d.m * d.n);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));
MakeBatchPointersAsync(stream, buffers[2], buffers[4], d.batch,
SizeOfBlasType(d.type) * std::min(d.m, d.n));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));
switch (d.type) {
case BlasType::F32: {
float** a_batch_ptrs = static_cast<float**>(buffers[3]);

View File

@ -313,20 +313,5 @@ absl::Status AsStatus(cufftResult error, const char* file, std::int64_t line,
}
#endif
absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(
gpuStream_t stream, void* buffer, void* dev_ptrs, int batch,
int batch_elem_size) {
char* ptr = static_cast<char*>(buffer);
auto host_ptrs = absl::make_unique<void*[]>(batch);
for (int i = 0; i < batch; ++i) {
host_ptrs[i] = ptr;
ptr += batch_elem_size;
}
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
gpuMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch,
gpuMemcpyHostToDevice, stream)));
return std::move(host_ptrs);
}
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

View File

@ -67,16 +67,6 @@ absl::Status AsStatus(cufftResult error, const char* file, std::int64_t line,
const char* expr);
#endif
// Builds an array of pointers to each array in a batch, in device memory.
// Caution: the return value must be kept alive (e.g., via a stream
// synchronization) until the copy enqueued by MakeBatchPointers on `stream`
// completes.
absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(gpuStream_t stream,
void* buffer,
void* dev_ptrs,
int batch,
int batch_elem_size);
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

View File

@ -0,0 +1,46 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/gpu/make_batch_pointers.h"
#include <algorithm>
#include "jaxlib/gpu/vendor.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace {
__global__ void MakeBatchPointersAsyncKernel(char* buffer_in, void** buffer_out,
int batch, int batch_elem_size) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < batch;
idx += blockDim.x * gridDim.x) {
buffer_out[idx] = buffer_in + idx * batch_elem_size;
}
}
} // namespace
void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in,
void* buffer_out, int batch, int batch_elem_size) {
const int block_dim = 128;
const std::size_t grid_dim =
std::min<std::size_t>(1024, (batch + block_dim - 1) / block_dim);
MakeBatchPointersAsyncKernel<<<grid_dim, block_dim, 0, stream>>>(
static_cast<char*>(buffer_in), static_cast<void**>(buffer_out), batch,
batch_elem_size);
}
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

View File

@ -0,0 +1,30 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_GPU_MAKE_BATCH_POINTERS_H_
#define JAXLIB_GPU_MAKE_BATCH_POINTERS_H_
#include "jaxlib/gpu/vendor.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in,
void* buffer_out, int batch, int batch_elem_size);
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
#endif // JAXLIB_GPU_MAKE_BATCH_POINTERS_H_

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "jaxlib/ffi_helpers.h"
#include "jaxlib/gpu/blas_handle_pool.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/make_batch_pointers.h"
#include "jaxlib/gpu/solver_handle_pool.h"
#include "jaxlib/gpu/vendor.h"
#include "xla/ffi/api/ffi.h"
@ -142,13 +143,8 @@ ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream,
gpuMemcpyDeviceToDevice, stream)));
}
FFI_ASSIGN_OR_RETURN(
auto a_ptrs_host,
MakeBatchPointers(stream, out_data, workspace, batch, sizeof(T) * n * n));
// TODO(phawkins, danfm): ideally we would not need to synchronize here, but
// to avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
MakeBatchPointersAsync(stream, out_data, workspace, batch, sizeof(T) * n * n);
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
auto batch_ptrs = static_cast<T**>(workspace);
FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel<T>::Run(

View File

@ -58,6 +58,16 @@ cc_library(
]),
)
rocm_library(
name = "hip_make_batch_pointers",
srcs = ["//third_party/py/jax/jaxlib/gpu:make_batch_pointers.cu.cc"],
hdrs = ["//third_party/py/jax/jaxlib/gpu:make_batch_pointers.h"],
deps = [
":hip_vendor",
"@local_config_rocm//rocm:rocm_headers",
],
)
cc_library(
name = "hip_blas_handle_pool",
srcs = ["//jaxlib/gpu:blas_handle_pool.cc"],
@ -80,6 +90,7 @@ cc_library(
deps = [
":hip_blas_handle_pool",
":hip_gpu_kernel_helpers",
":hip_make_batch_pointers",
":hip_vendor",
"//jaxlib:kernel_helpers",
"@com_google_absl//absl/algorithm:container",
@ -160,6 +171,7 @@ cc_library(
deps = [
":hip_blas_handle_pool",
":hip_gpu_kernel_helpers",
":hip_make_batch_pointers",
":hip_solver_handle_pool",
":hip_vendor",
"//jaxlib:ffi_helpers",