diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 8121a1058..b31ed78e3 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -72,6 +72,16 @@ cc_library( ], ) +cuda_library( + name = "cuda_make_batch_pointers", + srcs = ["//jaxlib/gpu:make_batch_pointers.cu.cc"], + hdrs = ["//jaxlib/gpu:make_batch_pointers.h"], + deps = [ + ":cuda_vendor", + "@local_config_cuda//cuda:cuda_headers", + ], +) + cc_library( name = "cuda_blas_handle_pool", srcs = ["//jaxlib/gpu:blas_handle_pool.cc"], @@ -95,6 +105,7 @@ cc_library( deps = [ ":cuda_blas_handle_pool", ":cuda_gpu_kernel_helpers", + ":cuda_make_batch_pointers", ":cuda_vendor", "//jaxlib:kernel_helpers", "@xla//xla/service:custom_call_status", @@ -223,6 +234,7 @@ cc_library( deps = [ ":cuda_blas_handle_pool", ":cuda_gpu_kernel_helpers", + ":cuda_make_batch_pointers", ":cuda_solver_handle_pool", ":cuda_vendor", "//jaxlib:ffi_helpers", diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 6bdaf4ef1..706cac6b4 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -36,6 +36,8 @@ exports_files(srcs = [ "linalg_kernels.cc", "linalg_kernels.cu.cc", "linalg_kernels.h", + "make_batch_pointers.cu.cc", + "make_batch_pointers.h", "prng.cc", "prng_kernels.cc", "prng_kernels.cu.cc", diff --git a/jaxlib/gpu/blas_kernels.cc b/jaxlib/gpu/blas_kernels.cc index a963aa3fd..ac30aa9cc 100644 --- a/jaxlib/gpu/blas_kernels.cc +++ b/jaxlib/gpu/blas_kernels.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "absl/strings/str_format.h" #include "jaxlib/gpu/blas_handle_pool.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/make_batch_pointers.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" #include "xla/service/custom_call_status.h" @@ -69,13 +70,9 @@ static absl::Status GetrfBatched_(gpuStream_t stream, void** buffers, int* ipiv = static_cast(buffers[2]); int* info = static_cast(buffers[3]); - auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[4], d.batch, - SizeOfBlasType(d.type) * d.n * d.n); - JAX_RETURN_IF_ERROR(a_ptrs_host.status()); - // TODO(phawkins): ideally we would not need to synchronize here, but to - // avoid it we need a way to keep the host-side buffer alive until the copy - // completes. - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + MakeBatchPointersAsync(stream, buffers[1], buffers[4], d.batch, + SizeOfBlasType(d.type) * d.n * d.n); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); switch (d.type) { case BlasType::F32: { float** batch_ptrs = static_cast(buffers[4]); @@ -132,17 +129,12 @@ static absl::Status GeqrfBatched_(gpuStream_t stream, void** buffers, } std::vector info(d.batch); - auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch, - SizeOfBlasType(d.type) * d.m * d.n); - JAX_RETURN_IF_ERROR(a_ptrs_host.status()); - auto tau_ptrs_host = - MakeBatchPointers(stream, buffers[2], buffers[4], d.batch, - SizeOfBlasType(d.type) * std::min(d.m, d.n)); - JAX_RETURN_IF_ERROR(tau_ptrs_host.status()); - // TODO(phawkins): ideally we would not need to synchronize here, but to - // avoid it we need a way to keep the host-side buffer alive until the copy - // completes. 
- JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + MakeBatchPointersAsync(stream, buffers[1], buffers[3], d.batch, + SizeOfBlasType(d.type) * d.m * d.n); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); + MakeBatchPointersAsync(stream, buffers[2], buffers[4], d.batch, + SizeOfBlasType(d.type) * std::min(d.m, d.n)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); switch (d.type) { case BlasType::F32: { float** a_batch_ptrs = static_cast(buffers[3]); diff --git a/jaxlib/gpu/gpu_kernel_helpers.cc b/jaxlib/gpu/gpu_kernel_helpers.cc index f43122f2e..5a434f4b6 100644 --- a/jaxlib/gpu/gpu_kernel_helpers.cc +++ b/jaxlib/gpu/gpu_kernel_helpers.cc @@ -313,20 +313,5 @@ absl::Status AsStatus(cufftResult error, const char* file, std::int64_t line, } #endif -absl::StatusOr> MakeBatchPointers( - gpuStream_t stream, void* buffer, void* dev_ptrs, int batch, - int batch_elem_size) { - char* ptr = static_cast(buffer); - auto host_ptrs = absl::make_unique(batch); - for (int i = 0; i < batch; ++i) { - host_ptrs[i] = ptr; - ptr += batch_elem_size; - } - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpuMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch, - gpuMemcpyHostToDevice, stream))); - return std::move(host_ptrs); -} - } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/gpu_kernel_helpers.h b/jaxlib/gpu/gpu_kernel_helpers.h index 46fca7bc4..aecb8a4fd 100644 --- a/jaxlib/gpu/gpu_kernel_helpers.h +++ b/jaxlib/gpu/gpu_kernel_helpers.h @@ -67,16 +67,6 @@ absl::Status AsStatus(cufftResult error, const char* file, std::int64_t line, const char* expr); #endif -// Builds an array of pointers to each array in a batch, in device memory. -// Caution: the return value must be kept alive (e.g., via a stream -// synchronization) until the copy enqueued by MakeBatchPointers on `stream` -// completes. 
-absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(gpuStream_t stream,
-                                                           void* buffer,
-                                                           void* dev_ptrs,
-                                                           int batch,
-                                                           int batch_elem_size);
-
 
 }  // namespace JAX_GPU_NAMESPACE
 }  // namespace jax
diff --git a/jaxlib/gpu/make_batch_pointers.cu.cc b/jaxlib/gpu/make_batch_pointers.cu.cc
new file mode 100644
index 000000000..b10655645
--- /dev/null
+++ b/jaxlib/gpu/make_batch_pointers.cu.cc
@@ -0,0 +1,65 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/gpu/make_batch_pointers.h"
+
+#include <algorithm>
+#include <cstdint>
+
+#include "jaxlib/gpu/vendor.h"
+
+namespace jax {
+namespace JAX_GPU_NAMESPACE {
+
+namespace {
+// Writes the address of each batch element into `buffer_out`:
+//   buffer_out[i] = buffer_in + i * batch_elem_size,  for i in [0, batch).
+// Each thread starts at its global index and advances by the total number of
+// launched threads, so a 1D grid of any size covers the whole batch.
+// The index and the byte offset are computed in 64 bits because
+// `i * batch_elem_size` can exceed the range of a 32-bit int.
+__global__ void MakeBatchPointersAsyncKernel(char* buffer_in, void** buffer_out,
+                                             std::int64_t batch,
+                                             std::int64_t batch_elem_size) {
+  const std::int64_t stride =
+      static_cast<std::int64_t>(blockDim.x) * gridDim.x;
+  for (std::int64_t idx =
+           static_cast<std::int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+       idx < batch; idx += stride) {
+    buffer_out[idx] = buffer_in + idx * batch_elem_size;
+  }
+}
+}  // namespace
+
+// Launches a kernel on `stream` that fills `buffer_out` with `batch` pointers
+// into `buffer_in`, spaced `batch_elem_size` bytes apart. Does not synchronize
+// the stream; a failed launch is observable through gpuGetLastError().
+void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in,
+                            void* buffer_out, int batch, int batch_elem_size) {
+  // Nothing to write for an empty batch, and a grid with zero blocks is an
+  // invalid launch configuration.
+  if (batch <= 0) return;
+  const int block_dim = 128;
+  const std::int64_t num_blocks =
+      (static_cast<std::int64_t>(batch) + block_dim - 1) / block_dim;
+  // Cap the grid at 1024 blocks; the kernel's strided loop handles the rest.
+  const std::size_t grid_dim = std::min<std::int64_t>(1024, num_blocks);
+  MakeBatchPointersAsyncKernel<<<grid_dim, block_dim, 0, stream>>>(
+      static_cast<char*>(buffer_in), static_cast<void**>(buffer_out), batch,
+      batch_elem_size);
+}
+
+}  // namespace JAX_GPU_NAMESPACE
+}  // namespace jax
diff --git a/jaxlib/gpu/make_batch_pointers.h b/jaxlib/gpu/make_batch_pointers.h new
file mode 100644
index 000000000..f2fd06496
--- /dev/null
+++ b/jaxlib/gpu/make_batch_pointers.h
@@ -0,0 +1,37 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_GPU_MAKE_BATCH_POINTERS_H_
+#define JAXLIB_GPU_MAKE_BATCH_POINTERS_H_
+
+#include "jaxlib/gpu/vendor.h"
+
+namespace jax {
+namespace JAX_GPU_NAMESPACE {
+
+// Fills `buffer_out` with `batch` pointers into `buffer_in`:
+//   buffer_out[i] = buffer_in + i * batch_elem_size   (offset in bytes)
+// for every i in [0, batch). The pointers are written by a GPU kernel, so
+// `buffer_out` must be device-accessible and hold at least `batch` pointers.
+//
+// Returns nothing: callers check for a failed kernel launch by calling
+// gpuGetLastError() right after this function.
+void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in,
+                            void* buffer_out, int batch, int batch_elem_size);
+
+}  // namespace JAX_GPU_NAMESPACE
+}  // namespace jax
+
+#endif  // JAXLIB_GPU_MAKE_BATCH_POINTERS_H_
diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc
index 051b9fd03..6deb89144 100644
--- a/jaxlib/gpu/solver_kernels_ffi.cc
+++ b/jaxlib/gpu/solver_kernels_ffi.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "jaxlib/ffi_helpers.h"
 #include "jaxlib/gpu/blas_handle_pool.h"
 #include "jaxlib/gpu/gpu_kernel_helpers.h"
+#include "jaxlib/gpu/make_batch_pointers.h"
 #include "jaxlib/gpu/solver_handle_pool.h"
 #include "jaxlib/gpu/vendor.h"
 #include "xla/ffi/api/ffi.h"
@@ -142,13 +143,8 @@ ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream,
                        gpuMemcpyDeviceToDevice, stream)));
   }
 
-  FFI_ASSIGN_OR_RETURN(
-      auto a_ptrs_host,
-      MakeBatchPointers(stream, out_data, workspace, batch, sizeof(T) * n * n));
-  // TODO(phawkins, danfm): ideally we would not need to synchronize here, but
-  // to avoid it we need a way to keep the host-side buffer alive until the copy
-  // completes.
-  FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
+  MakeBatchPointersAsync(stream, out_data, workspace, batch, sizeof(T) * n * n);
+  FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
 
   auto batch_ptrs = static_cast<T**>(workspace);
   FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel<T>::Run(
diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD.bazel
index ba9ceb4c3..ce733d827 100644
--- a/jaxlib/rocm/BUILD.bazel
+++ b/jaxlib/rocm/BUILD.bazel
@@ -58,6 +58,16 @@ cc_library(
     ]),
 )
 
+rocm_library(
+    name = "hip_make_batch_pointers",
+    srcs = ["//jaxlib/gpu:make_batch_pointers.cu.cc"],
+    hdrs = ["//jaxlib/gpu:make_batch_pointers.h"],
+    deps = [
+        ":hip_vendor",
+        "@local_config_rocm//rocm:rocm_headers",
+    ],
+)
+
 cc_library(
     name = "hip_blas_handle_pool",
     srcs = ["//jaxlib/gpu:blas_handle_pool.cc"],
@@ -80,6 +90,7 @@ cc_library(
     deps = [
         ":hip_blas_handle_pool",
         ":hip_gpu_kernel_helpers",
+        ":hip_make_batch_pointers",
         ":hip_vendor",
         "//jaxlib:kernel_helpers",
         "@com_google_absl//absl/algorithm:container",
@@ -160,6 +171,7 @@ cc_library(
     deps = [
         ":hip_blas_handle_pool",
         ":hip_gpu_kernel_helpers",
+        ":hip_make_batch_pointers",
         ":hip_solver_handle_pool",
         ":hip_vendor",
         "//jaxlib:ffi_helpers",