Remove synchronization from GPU LU decomposition kernel by adding an async batch pointers builder.
In the batched LU decomposition in cuBLAS, the output buffer is required to be a pointer of pointers to the appropriate batch matrices. Previously this array of pointers was built on the host and then copied to the device, requiring a synchronization, but it seems straightforward to instead implement a tiny CUDA kernel to do this work. This definitely isn't a bottleneck or a high priority change, but this seemed like a reasonable time to fix a longstanding TODO.

PiperOrigin-RevId: 663686539
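For background (this sketch is not part of the change): cuBLAS's batched LU routine takes an array of per-matrix device pointers, and that array must itself live in device memory, which is why a pointer table has to be materialized at all. A minimal, hypothetical call site, assuming a contiguous device buffer `d_matrices` of `batch` column-major n x n float matrices and illustrative names for the other buffers:

    // d_ptrs is a device array of `batch` float* entries; conceptually
    //   d_ptrs[i] = d_matrices + i * n * n;
    // Before this change the entries were written on the host and copied over
    // with gpuMemcpyAsync (forcing a gpuStreamSynchronize so the host staging
    // buffer stayed alive); after it, a small kernel writes them on the device.
    cublasSetStream(handle, stream);
    cublasSgetrfBatched(handle, n,
                        d_ptrs,     // float* const Aarray[], in device memory
                        /*lda=*/n,
                        d_pivots,   // batch * n pivot indices
                        d_info,     // one LU status per matrix
                        batch);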
commit b6306e3953 (parent acacf8884e)
@@ -72,6 +72,16 @@ cc_library(
     ],
 )
 
+cuda_library(
+    name = "cuda_make_batch_pointers",
+    srcs = ["//jaxlib/gpu:make_batch_pointers.cu.cc"],
+    hdrs = ["//jaxlib/gpu:make_batch_pointers.h"],
+    deps = [
+        ":cuda_vendor",
+        "@local_config_cuda//cuda:cuda_headers",
+    ],
+)
+
 cc_library(
     name = "cuda_blas_handle_pool",
     srcs = ["//jaxlib/gpu:blas_handle_pool.cc"],
@@ -95,6 +105,7 @@ cc_library(
     deps = [
         ":cuda_blas_handle_pool",
         ":cuda_gpu_kernel_helpers",
+        ":cuda_make_batch_pointers",
         ":cuda_vendor",
         "//jaxlib:kernel_helpers",
         "@xla//xla/service:custom_call_status",
@@ -223,6 +234,7 @@ cc_library(
     deps = [
         ":cuda_blas_handle_pool",
         ":cuda_gpu_kernel_helpers",
+        ":cuda_make_batch_pointers",
         ":cuda_solver_handle_pool",
         ":cuda_vendor",
         "//jaxlib:ffi_helpers",
@@ -36,6 +36,8 @@ exports_files(srcs = [
     "linalg_kernels.cc",
     "linalg_kernels.cu.cc",
     "linalg_kernels.h",
+    "make_batch_pointers.cu.cc",
+    "make_batch_pointers.h",
     "prng.cc",
     "prng_kernels.cc",
     "prng_kernels.cu.cc",
@@ -26,6 +26,7 @@ limitations under the License.
 #include "absl/strings/str_format.h"
 #include "jaxlib/gpu/blas_handle_pool.h"
 #include "jaxlib/gpu/gpu_kernel_helpers.h"
+#include "jaxlib/gpu/make_batch_pointers.h"
 #include "jaxlib/gpu/vendor.h"
 #include "jaxlib/kernel_helpers.h"
 #include "xla/service/custom_call_status.h"
@@ -69,13 +70,9 @@ static absl::Status GetrfBatched_(gpuStream_t stream, void** buffers,
 
   int* ipiv = static_cast<int*>(buffers[2]);
   int* info = static_cast<int*>(buffers[3]);
-  auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[4], d.batch,
-                                       SizeOfBlasType(d.type) * d.n * d.n);
-  JAX_RETURN_IF_ERROR(a_ptrs_host.status());
-  // TODO(phawkins): ideally we would not need to synchronize here, but to
-  // avoid it we need a way to keep the host-side buffer alive until the copy
-  // completes.
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
+  MakeBatchPointersAsync(stream, buffers[1], buffers[4], d.batch,
+                         SizeOfBlasType(d.type) * d.n * d.n);
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));
   switch (d.type) {
     case BlasType::F32: {
       float** batch_ptrs = static_cast<float**>(buffers[4]);
@@ -132,17 +129,12 @@ static absl::Status GeqrfBatched_(gpuStream_t stream, void** buffers,
   }
 
   std::vector<int> info(d.batch);
-  auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,
-                                       SizeOfBlasType(d.type) * d.m * d.n);
-  JAX_RETURN_IF_ERROR(a_ptrs_host.status());
-  auto tau_ptrs_host =
-      MakeBatchPointers(stream, buffers[2], buffers[4], d.batch,
-                        SizeOfBlasType(d.type) * std::min(d.m, d.n));
-  JAX_RETURN_IF_ERROR(tau_ptrs_host.status());
-  // TODO(phawkins): ideally we would not need to synchronize here, but to
-  // avoid it we need a way to keep the host-side buffer alive until the copy
-  // completes.
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
+  MakeBatchPointersAsync(stream, buffers[1], buffers[3], d.batch,
+                         SizeOfBlasType(d.type) * d.m * d.n);
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));
+  MakeBatchPointersAsync(stream, buffers[2], buffers[4], d.batch,
+                         SizeOfBlasType(d.type) * std::min(d.m, d.n));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));
   switch (d.type) {
     case BlasType::F32: {
       float** a_batch_ptrs = static_cast<float**>(buffers[3]);
@@ -313,20 +313,5 @@ absl::Status AsStatus(cufftResult error, const char* file, std::int64_t line,
 }
 #endif
 
-absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(
-    gpuStream_t stream, void* buffer, void* dev_ptrs, int batch,
-    int batch_elem_size) {
-  char* ptr = static_cast<char*>(buffer);
-  auto host_ptrs = absl::make_unique<void*[]>(batch);
-  for (int i = 0; i < batch; ++i) {
-    host_ptrs[i] = ptr;
-    ptr += batch_elem_size;
-  }
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-      gpuMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch,
-                     gpuMemcpyHostToDevice, stream)));
-  return std::move(host_ptrs);
-}
-
 }  // namespace JAX_GPU_NAMESPACE
 }  // namespace jax
@@ -67,16 +67,6 @@ absl::Status AsStatus(cufftResult error, const char* file, std::int64_t line,
                       const char* expr);
 #endif
 
-// Builds an array of pointers to each array in a batch, in device memory.
-// Caution: the return value must be kept alive (e.g., via a stream
-// synchronization) until the copy enqueued by MakeBatchPointers on `stream`
-// completes.
-absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(gpuStream_t stream,
-                                                           void* buffer,
-                                                           void* dev_ptrs,
-                                                           int batch,
-                                                           int batch_elem_size);
-
 }  // namespace JAX_GPU_NAMESPACE
 }  // namespace jax
 
jaxlib/gpu/make_batch_pointers.cu.cc (new file, 46 lines)
@@ -0,0 +1,46 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/gpu/make_batch_pointers.h"
+
+#include <algorithm>
+
+#include "jaxlib/gpu/vendor.h"
+
+namespace jax {
+namespace JAX_GPU_NAMESPACE {
+
+namespace {
+__global__ void MakeBatchPointersAsyncKernel(char* buffer_in, void** buffer_out,
+                                             int batch, int batch_elem_size) {
+  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < batch;
+       idx += blockDim.x * gridDim.x) {
+    buffer_out[idx] = buffer_in + idx * batch_elem_size;
+  }
+}
+}  // namespace
+
+void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in,
+                            void* buffer_out, int batch, int batch_elem_size) {
+  const int block_dim = 128;
+  const std::size_t grid_dim =
+      std::min<std::size_t>(1024, (batch + block_dim - 1) / block_dim);
+  MakeBatchPointersAsyncKernel<<<grid_dim, block_dim, 0, stream>>>(
+      static_cast<char*>(buffer_in), static_cast<void**>(buffer_out), batch,
+      batch_elem_size);
+}
+
+}  // namespace JAX_GPU_NAMESPACE
+}  // namespace jax
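A note on the launch configuration above (illustrative arithmetic, not from the commit): the kernel uses a grid-stride loop, so the grid is capped at 1024 blocks of 128 threads. For a batch of 200000 matrices, (200000 + 127) / 128 = 1563 blocks would be requested, the cap brings this down to 1024 blocks (131072 threads), and each thread then writes at most two pointer-table entries as it strides through the batch.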
jaxlib/gpu/make_batch_pointers.h (new file, 30 lines)
@@ -0,0 +1,30 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_GPU_MAKE_BATCH_POINTERS_H_
+#define JAXLIB_GPU_MAKE_BATCH_POINTERS_H_
+
+#include "jaxlib/gpu/vendor.h"
+
+namespace jax {
+namespace JAX_GPU_NAMESPACE {
+
+void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in,
+                            void* buffer_out, int batch, int batch_elem_size);
+
+}  // namespace JAX_GPU_NAMESPACE
+}  // namespace jax
+
+#endif  // JAXLIB_GPU_MAKE_BATCH_POINTERS_H_
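A hedged usage sketch of the new helper (buffer names here are hypothetical; the real call sites are in the blas and solver kernel hunks of this diff): the pointer build is just another kernel launch on the stream, so callers follow it with a gpuGetLastError() check instead of a stream synchronization.

    // Build the pointer table on `stream`: entry i points sizeof(float)*n*n
    // bytes into d_matrices. No host staging buffer, no gpuStreamSynchronize.
    MakeBatchPointersAsync(stream, d_matrices, d_ptrs, batch,
                           sizeof(float) * n * n);
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));  // check the launch
    // d_ptrs can now be handed to the batched cuBLAS/hipBLAS routine enqueued
    // on the same stream.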
@@ -23,6 +23,7 @@ limitations under the License.
 #include "jaxlib/ffi_helpers.h"
 #include "jaxlib/gpu/blas_handle_pool.h"
 #include "jaxlib/gpu/gpu_kernel_helpers.h"
+#include "jaxlib/gpu/make_batch_pointers.h"
 #include "jaxlib/gpu/solver_handle_pool.h"
 #include "jaxlib/gpu/vendor.h"
 #include "xla/ffi/api/ffi.h"
@@ -142,13 +143,8 @@ ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream,
                                        gpuMemcpyDeviceToDevice, stream)));
   }
 
-  FFI_ASSIGN_OR_RETURN(
-      auto a_ptrs_host,
-      MakeBatchPointers(stream, out_data, workspace, batch, sizeof(T) * n * n));
-  // TODO(phawkins, danfm): ideally we would not need to synchronize here, but
-  // to avoid it we need a way to keep the host-side buffer alive until the copy
-  // completes.
-  FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
+  MakeBatchPointersAsync(stream, out_data, workspace, batch, sizeof(T) * n * n);
+  FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
 
   auto batch_ptrs = static_cast<T**>(workspace);
   FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel<T>::Run(
@@ -58,6 +58,16 @@ cc_library(
     ]),
 )
 
+rocm_library(
+    name = "hip_make_batch_pointers",
+    srcs = ["//third_party/py/jax/jaxlib/gpu:make_batch_pointers.cu.cc"],
+    hdrs = ["//third_party/py/jax/jaxlib/gpu:make_batch_pointers.h"],
+    deps = [
+        ":hip_vendor",
+        "@local_config_rocm//rocm:rocm_headers",
+    ],
+)
+
 cc_library(
     name = "hip_blas_handle_pool",
     srcs = ["//jaxlib/gpu:blas_handle_pool.cc"],
@@ -80,6 +90,7 @@ cc_library(
     deps = [
         ":hip_blas_handle_pool",
         ":hip_gpu_kernel_helpers",
+        ":hip_make_batch_pointers",
        ":hip_vendor",
         "//jaxlib:kernel_helpers",
         "@com_google_absl//absl/algorithm:container",
@@ -160,6 +171,7 @@ cc_library(
     deps = [
         ":hip_blas_handle_pool",
        ":hip_gpu_kernel_helpers",
+        ":hip_make_batch_pointers",
        ":hip_solver_handle_pool",
         ":hip_vendor",
         "//jaxlib:ffi_helpers",