diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 63e300a64..59e1cebc3 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -73,13 +73,33 @@ cc_library( ) cc_library( - name = "cublas_kernels", - srcs = ["//jaxlib/gpu:blas_kernels.cc"], - hdrs = ["//jaxlib/gpu:blas_kernels.h"], + name = "cuda_blas_handle_pool", + srcs = [ + "//jaxlib/gpu:blas_handle_pool.cc", + ], + hdrs = [ + "//jaxlib/gpu:blas_handle_pool.h", + ], deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", "//jaxlib:handle_pool", + "@xla//xla/tsl/cuda:cublas", + "@xla//xla/tsl/cuda:cudart", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + "@local_config_cuda//cuda:cuda_headers", + ], +) + +cc_library( + name = "cublas_kernels", + srcs = ["//jaxlib/gpu:blas_kernels.cc"], + hdrs = ["//jaxlib/gpu:blas_kernels.h"], + deps = [ + ":cuda_blas_handle_pool", + ":cuda_gpu_kernel_helpers", + ":cuda_vendor", "//jaxlib:kernel_helpers", "@xla//xla/service:custom_call_status", "@xla//xla/tsl/cuda:cublas", @@ -90,9 +110,9 @@ cc_library( "@com_google_absl//absl/hash", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cublas_headers", "@local_config_cuda//cuda:cuda_headers", ], @@ -119,6 +139,7 @@ pybind_extension( ":cuda_vendor", "//jaxlib:kernel_nanobind_helpers", "@xla//xla/tsl/cuda:cublas", + "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format", @@ -165,20 +186,40 @@ pybind_extension( ], ) +cc_library( + name = "cuda_solver_handle_pool", + srcs = [ + "//jaxlib/gpu:solver_handle_pool.cc", + ], + hdrs = [ + "//jaxlib/gpu:solver_handle_pool.h", + ], + deps = [ + ":cuda_gpu_kernel_helpers", + ":cuda_vendor", + "//jaxlib:handle_pool", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cusolver", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + "@local_config_cuda//cuda:cuda_headers", + ], +) + cc_library( name = "cusolver_kernels", srcs = ["//jaxlib/gpu:solver_kernels.cc"], hdrs = ["//jaxlib/gpu:solver_kernels.h"], deps = [ ":cuda_gpu_kernel_helpers", + ":cuda_solver_handle_pool", ":cuda_vendor", - "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusolver", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", ], ) @@ -201,6 +242,7 @@ pybind_extension( module_name = "_solver", deps = [ ":cuda_gpu_kernel_helpers", + ":cuda_solver_handle_pool", ":cuda_vendor", ":cusolver_kernels", "//jaxlib:kernel_nanobind_helpers", diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 942dbe639..c4f4c1246 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -25,6 +25,8 @@ package( exports_files(srcs = [ "blas.cc", + "blas_handle_pool.cc", + "blas_handle_pool.h", "blas_kernels.cc", "blas_kernels.h", "gpu_kernel_helpers.cc", @@ -42,6 +44,8 @@ exports_files(srcs = [ "rnn_kernels.cc", "rnn_kernels.h", "solver.cc", + "solver_handle_pool.cc", + "solver_handle_pool.h", "solver_kernels.cc", "solver_kernels.h", "sparse.cc", diff --git a/jaxlib/gpu/blas_handle_pool.cc b/jaxlib/gpu/blas_handle_pool.cc new file mode 100644 index 000000000..2ce204453 --- /dev/null +++ b/jaxlib/gpu/blas_handle_pool.cc @@ -0,0 +1,44 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/gpu/blas_handle_pool.h" + +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/vendor.h" +#include "jaxlib/handle_pool.h" + +namespace jax { + +template <> +/*static*/ absl::StatusOr BlasHandlePool::Borrow( + gpuStream_t stream) { + BlasHandlePool* pool = Instance(); + absl::MutexLock lock(&pool->mu_); + gpublasHandle_t handle; + if (pool->handles_[stream].empty()) { + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasCreate(&handle))); + } else { + handle = pool->handles_[stream].back(); + pool->handles_[stream].pop_back(); + } + if (stream) { + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasSetStream(handle, stream))); + } + return Handle(pool, handle, stream); +} + +} // namespace jax diff --git a/jaxlib/gpu/blas_handle_pool.h b/jaxlib/gpu/blas_handle_pool.h new file mode 100644 index 000000000..b3cdbaa88 --- /dev/null +++ b/jaxlib/gpu/blas_handle_pool.h @@ -0,0 +1,33 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_GPU_BLAS_HANDLE_POOL_H_ +#define JAXLIB_GPU_BLAS_HANDLE_POOL_H_ + +#include "absl/status/statusor.h" +#include "jaxlib/gpu/vendor.h" +#include "jaxlib/handle_pool.h" + +namespace jax { + +using BlasHandlePool = HandlePool; + +template <> +absl::StatusOr BlasHandlePool::Borrow( + gpuStream_t stream); + +} // namespace jax + +#endif // JAXLIB_GPU_BLAS_HANDLE_POOL_H_ diff --git a/jaxlib/gpu/blas_kernels.cc b/jaxlib/gpu/blas_kernels.cc index 329051a0a..a963aa3fd 100644 --- a/jaxlib/gpu/blas_kernels.cc +++ b/jaxlib/gpu/blas_kernels.cc @@ -16,42 +16,22 @@ limitations under the License. #include "jaxlib/gpu/blas_kernels.h" #include -#include -#include +#include +#include +#include #include -#include "absl/base/casts.h" -#include "absl/base/thread_annotations.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" -#include "absl/synchronization/mutex.h" +#include "jaxlib/gpu/blas_handle_pool.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" #include "xla/service/custom_call_status.h" namespace jax { -using BlasHandlePool = HandlePool; - -template <> -/*static*/ absl::StatusOr BlasHandlePool::Borrow( - gpuStream_t stream) { - BlasHandlePool* pool = Instance(); - absl::MutexLock lock(&pool->mu_); - gpublasHandle_t handle; - if (pool->handles_[stream].empty()) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasCreate(&handle))); - } else { - handle = pool->handles_[stream].back(); - pool->handles_[stream].pop_back(); - } - if (stream) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasSetStream(handle, stream))); - } - return Handle(pool, handle, stream); -} - namespace JAX_GPU_NAMESPACE { namespace { diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index c040f3875..dea2b9b07 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -13,23 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include +#include #include #include -#include #include "nanobind/nanobind.h" #include "nanobind/stl/pair.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" +#include "third_party/gpus/cuda/include/cusolver_common.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/solver_handle_pool.h" #include "jaxlib/gpu/solver_kernels.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/tsl/python/lib/core/numpy.h" - namespace jax { namespace JAX_GPU_NAMESPACE { namespace { diff --git a/jaxlib/gpu/solver_handle_pool.cc b/jaxlib/gpu/solver_handle_pool.cc new file mode 100644 index 000000000..c55ea923b --- /dev/null +++ b/jaxlib/gpu/solver_handle_pool.cc @@ -0,0 +1,70 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/gpu/solver_handle_pool.h" + +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/vendor.h" +#include "jaxlib/handle_pool.h" + +#ifdef JAX_GPU_CUDA +#include "third_party/gpus/cuda/include/cusolverSp.h" +#endif // JAX_GPU_CUDA + +namespace jax { + +template <> +/*static*/ absl::StatusOr SolverHandlePool::Borrow( + gpuStream_t stream) { + SolverHandlePool* pool = Instance(); + absl::MutexLock lock(&pool->mu_); + gpusolverDnHandle_t handle; + if (pool->handles_[stream].empty()) { + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCreate(&handle))); + } else { + handle = pool->handles_[stream].back(); + pool->handles_[stream].pop_back(); + } + if (stream) { + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSetStream(handle, stream))); + } + return Handle(pool, handle, stream); +} + +#ifdef JAX_GPU_CUDA + +template <> +/*static*/ absl::StatusOr +SpSolverHandlePool::Borrow(gpuStream_t stream) { + SpSolverHandlePool* pool = Instance(); + absl::MutexLock lock(&pool->mu_); + cusolverSpHandle_t handle; + if (pool->handles_[stream].empty()) { + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpCreate(&handle))); + } else { + handle = pool->handles_[stream].back(); + pool->handles_[stream].pop_back(); + } + if (stream) { + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpSetStream(handle, stream))); + } + return Handle(pool, handle, stream); +} + +#endif // JAX_GPU_CUDA + +} // namespace jax diff --git a/jaxlib/gpu/solver_handle_pool.h b/jaxlib/gpu/solver_handle_pool.h new file mode 100644 index 000000000..c46c062b3 --- /dev/null +++ b/jaxlib/gpu/solver_handle_pool.h @@ -0,0 +1,45 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_GPU_SOLVER_HANDLE_POOL_H_ +#define JAXLIB_GPU_SOLVER_HANDLE_POOL_H_ + +#include "absl/status/statusor.h" +#include "jaxlib/gpu/vendor.h" +#include "jaxlib/handle_pool.h" + +#ifdef JAX_GPU_CUDA +#include "third_party/gpus/cuda/include/cusolverSp.h" +#endif // JAX_GPU_CUDA + +namespace jax { + +using SolverHandlePool = HandlePool; + +template <> +absl::StatusOr SolverHandlePool::Borrow( + gpuStream_t stream); + +#ifdef JAX_GPU_CUDA +using SpSolverHandlePool = HandlePool; + +template <> +absl::StatusOr SpSolverHandlePool::Borrow( + gpuStream_t stream); +#endif // JAX_GPU_CUDA + +} // namespace jax + +#endif // JAXLIB_GPU_SOLVER_HANDLE_POOL_H_ diff --git a/jaxlib/gpu/solver_kernels.cc b/jaxlib/gpu/solver_kernels.cc index ffd39af69..8d90c7053 100644 --- a/jaxlib/gpu/solver_kernels.cc +++ b/jaxlib/gpu/solver_kernels.cc @@ -16,16 +16,16 @@ limitations under the License. #include "jaxlib/gpu/solver_kernels.h" #include +#include #include -#include -#include -#include +#include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/synchronization/mutex.h" +#include "jaxlib/gpu/solver_handle_pool.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" #include "xla/service/custom_call_status.h" @@ -35,46 +35,6 @@ limitations under the License. namespace jax { -template <> -/*static*/ absl::StatusOr SolverHandlePool::Borrow( - gpuStream_t stream) { - SolverHandlePool* pool = Instance(); - absl::MutexLock lock(&pool->mu_); - gpusolverDnHandle_t handle; - if (pool->handles_[stream].empty()) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCreate(&handle))); - } else { - handle = pool->handles_[stream].back(); - pool->handles_[stream].pop_back(); - } - if (stream) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSetStream(handle, stream))); - } - return Handle(pool, handle, stream); -} - -#ifdef JAX_GPU_CUDA - -template <> -/*static*/ absl::StatusOr -SpSolverHandlePool::Borrow(gpuStream_t stream) { - SpSolverHandlePool* pool = Instance(); - absl::MutexLock lock(&pool->mu_); - cusolverSpHandle_t handle; - if (pool->handles_[stream].empty()) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpCreate(&handle))); - } else { - handle = pool->handles_[stream].back(); - pool->handles_[stream].pop_back(); - } - if (stream) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpSetStream(handle, stream))); - } - return Handle(pool, handle, stream); -} - -#endif // JAX_GPU_CUDA - namespace JAX_GPU_NAMESPACE { static int SizeOfSolverType(SolverType type) { diff --git a/jaxlib/gpu/solver_kernels.h b/jaxlib/gpu/solver_kernels.h index e4d9d84b5..51082f2fe 100644 --- a/jaxlib/gpu/solver_kernels.h +++ b/jaxlib/gpu/solver_kernels.h @@ -16,33 +16,13 @@ limitations under the License. #ifndef JAXLIB_CUSOLVER_KERNELS_H_ #define JAXLIB_CUSOLVER_KERNELS_H_ -#include "absl/status/statusor.h" +#include + #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" #include "xla/service/custom_call_status.h" -#ifdef JAX_GPU_CUDA -#include "third_party/gpus/cuda/include/cusolverSp.h" -#endif // JAX_GPU_CUDA - namespace jax { -using SolverHandlePool = HandlePool; - -template <> -absl::StatusOr SolverHandlePool::Borrow( - gpuStream_t stream); - -#ifdef JAX_GPU_CUDA - -using SpSolverHandlePool = HandlePool; - -template <> -absl::StatusOr SpSolverHandlePool::Borrow( - gpuStream_t stream); - -#endif // JAX_GPU_CUDA - namespace JAX_GPU_NAMESPACE { // Set of types known to Cusolver. diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD.bazel index 944a56215..36a6b7181 100644 --- a/jaxlib/rocm/BUILD.bazel +++ b/jaxlib/rocm/BUILD.bazel @@ -59,13 +59,32 @@ cc_library( ) cc_library( - name = "hipblas_kernels", - srcs = ["//jaxlib/gpu:blas_kernels.cc"], - hdrs = ["//jaxlib/gpu:blas_kernels.h"], + name = "hip_blas_handle_pool", + srcs = [ + "//jaxlib/gpu:blas_handle_pool.cc", + ], + hdrs = [ + "//jaxlib/gpu:blas_handle_pool.h", + ], deps = [ ":hip_gpu_kernel_helpers", ":hip_vendor", "//jaxlib:handle_pool", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + "@local_config_rocm//rocm:hipblas", + "@local_config_rocm//rocm:rocm_headers", + ], +) + +cc_library( + name = "hipblas_kernels", + srcs = ["//jaxlib/gpu:blas_kernels.cc"], + hdrs = ["//jaxlib/gpu:blas_kernels.h"], + deps = [ + ":hip_blas_handle_pool", + ":hip_gpu_kernel_helpers", + ":hip_vendor", "//jaxlib:kernel_helpers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -73,6 +92,7 @@ cc_library( "@com_google_absl//absl/hash", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", @@ -104,14 +124,33 @@ pybind_extension( ], ) +cc_library( + name = "hip_solver_handle_pool", + srcs = [ + "//jaxlib/gpu:solver_handle_pool.cc", + ], + hdrs = [ + "//jaxlib/gpu:solver_handle_pool.h", + ], + deps = [ + ":hip_gpu_kernel_helpers", + ":hip_vendor", + "//jaxlib:handle_pool", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + "@local_config_rocm//rocm:hipsolver", + "@local_config_rocm//rocm:rocm_headers", + ], +) + cc_library( name = "hipsolver_kernels", srcs = ["//jaxlib/gpu:solver_kernels.cc"], hdrs = ["//jaxlib/gpu:solver_kernels.h"], deps = [ ":hip_gpu_kernel_helpers", + ":hip_solver_handle_pool", ":hip_vendor", - "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -132,6 +171,7 @@ pybind_extension( features = ["-use_header_modules"], module_name = "_solver", deps = [ + ":hip_gpu_handle_pools", ":hip_gpu_kernel_helpers", ":hip_vendor", ":hipsolver_kernels", @@ -241,9 +281,9 @@ pybind_extension( ":hip_linalg_kernels", ":hip_vendor", "//jaxlib:kernel_nanobind_helpers", - "@xla//xla/tsl/python/lib/core:numpy", "@local_config_rocm//rocm:rocm_headers", "@nanobind", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -257,11 +297,11 @@ cc_library( ":hip_vendor", "//jaxlib:ffi_helpers", "//jaxlib:kernel_helpers", - "@xla//xla/ffi/api:c_api", - "@xla//xla/ffi/api:ffi", - "@local_config_rocm//rocm:rocm_headers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", + "@local_config_rocm//rocm:rocm_headers", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", ], ) @@ -274,8 +314,8 @@ rocm_library( ":hip_gpu_kernel_helpers", ":hip_vendor", "//jaxlib:kernel_helpers", - "@xla//xla/ffi/api:ffi", "@local_config_rocm//rocm:rocm_headers", + "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", ], )