From ad1bd387901af49fc8cf5576cd8af6abe3faf988 Mon Sep 17 00:00:00 2001
From: Dan Foreman-Mackey
Date: Wed, 14 Aug 2024 09:20:07 -0700
Subject: [PATCH] Move the logic for deciding when to dispatch to the batched
 LU decomposition algorithm on GPU into the kernel.

This simplifies the lowering logic, and means that we no longer pay a
performance penalty when exporting with shape polymorphism.

PiperOrigin-RevId: 662945116
---
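Note: the heuristic itself is unchanged; it only moves from trace time (in
gpu_solver.py) to run time (in the FFI kernel). Previously the lowering chose
between the "blas_getrf_batched_ffi" and "solver_getrf_ffi" custom-call
targets while building the HLO, so with shape-polymorphic (symbolic)
dimensions it could not establish that the batched path applied and we paid a
performance penalty. Now the lowering always emits "solver_getrf_ffi" and the
kernel picks an implementation from the runtime shape. As a minimal sketch
(simplified from GetrfDispatch in solver_kernels_ffi.cc below; the per-dtype
switch and error handling are elided, and T stands for one of the four
supported element types):

    // Flatten all leading dimensions of the input into one batch dimension.
    auto [batch, rows, cols] = SplitBatch2D(a.dimensions());
    if (batch > 1 && rows == cols && rows / batch <= 128) {
      // Many small square matrices: the batched cuBLAS/hipBLAS kernel
      // (gpublas*getrfBatched) amortizes launch overhead across the batch.
      return GetrfBatchedImpl<T>(batch, cols, stream, scratch, a, out, ipiv,
                                 info);
    } else {
      // Few or large matrices: the cuSOLVER/hipSOLVER kernel
      // (gpusolverDn*getrf) is generally faster per matrix.
      return GetrfImpl<T>(batch, rows, cols, stream, scratch, a, out, ipiv,
                          info);
    }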
#include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "jaxlib/gpu/blas_kernels.h" -#include "jaxlib/gpu/blas_kernels_ffi.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/tsl/python/lib/core/numpy.h" @@ -70,9 +69,6 @@ nb::dict Registrations() { nb::dict dict; dict[JAX_GPU_PREFIX "blas_getrf_batched"] = EncapsulateFunction(GetrfBatched); dict[JAX_GPU_PREFIX "blas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched); - - dict[JAX_GPU_PREFIX "blas_getrf_batched_ffi"] = - EncapsulateFfiHandler(GetrfBatchedFfi); return dict; } diff --git a/jaxlib/gpu/blas_kernels_ffi.cc b/jaxlib/gpu/blas_kernels_ffi.cc deleted file mode 100644 index 610ce1052..000000000 --- a/jaxlib/gpu/blas_kernels_ffi.cc +++ /dev/null @@ -1,133 +0,0 @@ -/* Copyright 2024 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "jaxlib/gpu/blas_kernels_ffi.h" - -#include "absl/status/status.h" -#include "jaxlib/ffi_helpers.h" -#include "jaxlib/gpu/blas_handle_pool.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/gpu/vendor.h" -#include "xla/ffi/api/ffi.h" - -namespace jax { -namespace JAX_GPU_NAMESPACE { - -namespace ffi = ::xla::ffi; - -namespace { -#define GETRF_BATCHED_KERNEL_IMPL(type, name) \ - template <> \ - struct GetrfBatchedKernel { \ - static absl::Status Run(gpublasHandle_t handle, int n, type** a, int lda, \ - int* ipiv, int* info, int batch) { \ - return JAX_AS_STATUS(name(handle, n, a, lda, ipiv, info, batch)); \ - } \ - } - -template -struct GetrfBatchedKernel; -GETRF_BATCHED_KERNEL_IMPL(float, gpublasSgetrfBatched); -GETRF_BATCHED_KERNEL_IMPL(double, gpublasDgetrfBatched); -GETRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgetrfBatched); -GETRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgetrfBatched); -#undef GETRF_BATCHED_KERNEL_IMPL - -template -ffi::Error GetrfBatchedImpl(gpuStream_t stream, ffi::ScratchAllocator& scratch, - ffi::AnyBuffer a, ffi::Result out, - ffi::Result> ipiv, - ffi::Result> info) { - FFI_RETURN_IF_ERROR(CheckMatrixDimensions(a.dimensions())); - auto [batch, rows, cols] = SplitBatch2D(a.dimensions()); - if (rows != cols) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "getrf_batched only supports square matrices"); - } - FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); - FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream)); - - auto maybe_workspace = scratch.Allocate(sizeof(void*) * batch); - if (!maybe_workspace.has_value()) { - return ffi::Error(ffi::ErrorCode::kUnknown, - "Unable to allocate workspace for batched getrf"); - } - auto workspace = maybe_workspace.value(); - - auto a_data = a.untyped_data(); - auto out_data = out->untyped_data(); - auto ipiv_data = ipiv->typed_data(); - auto info_data = info->typed_data(); - if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( - gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * cols * cols, - 
-                       gpuMemcpyDeviceToDevice, stream)));
-  }
-
-  FFI_ASSIGN_OR_RETURN(
-      auto a_ptrs_host,
-      MakeBatchPointers(stream, out_data, workspace, batch, sizeof(T) * n * n));
-  // TODO(phawkins, danfm): ideally we would not need to synchronize here, but
-  // to avoid it we need a way to keep the host-side buffer alive until the copy
-  // completes.
-  FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
-
-  auto batch_ptrs = static_cast<T**>(workspace);
-  FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel<T>::Run(
-      handle.get(), n, batch_ptrs, n, ipiv_data, info_data, batch));
-
-  return ffi::Error::Success();
-}
-
-ffi::Error GetrfBatchedDispatch(
-    gpuStream_t stream, ffi::ScratchAllocator scratch, ffi::AnyBuffer a,
-    ffi::Result<ffi::AnyBuffer> out,
-    ffi::Result<ffi::Buffer<ffi::DataType::S32>> ipiv,
-    ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
-  auto dataType = a.element_type();
-  if (dataType != out->element_type()) {
-    return ffi::Error(
-        ffi::ErrorCode::kInvalidArgument,
-        "Input and output to getrf_batched must have the same element type");
-  }
-  if (dataType == ffi::DataType::F32) {
-    return GetrfBatchedImpl<float>(stream, scratch, a, out, ipiv, info);
-  } else if (dataType == ffi::DataType::F64) {
-    return GetrfBatchedImpl<double>(stream, scratch, a, out, ipiv, info);
-  } else if (dataType == ffi::DataType::C64) {
-    return GetrfBatchedImpl<gpublasComplex>(stream, scratch, a, out, ipiv,
-                                            info);
-  } else if (dataType == ffi::DataType::C128) {
-    return GetrfBatchedImpl<gpublasDoubleComplex>(stream, scratch, a, out,
-                                                  ipiv, info);
-  }
-  return ffi::Error(ffi::ErrorCode::kInvalidArgument,
-                    "Unsupported element type for getrf");
-}
-}  // namespace
-
-XLA_FFI_DEFINE_HANDLER_SYMBOL(
-    GetrfBatchedFfi, GetrfBatchedDispatch,
-    ffi::Ffi::Bind()
-        .Ctx<ffi::PlatformStream<gpuStream_t>>()
-        .Ctx<ffi::ScratchAllocator>()
-        .Arg<ffi::AnyBuffer>()                   // a
-        .Ret<ffi::AnyBuffer>()                   // out
-        .Ret<ffi::Buffer<ffi::DataType::S32>>()  // ipiv
-        .Ret<ffi::Buffer<ffi::DataType::S32>>()  // info
-);
-
-}  // namespace JAX_GPU_NAMESPACE
-}  // namespace jax
diff --git a/jaxlib/gpu/blas_kernels_ffi.h b/jaxlib/gpu/blas_kernels_ffi.h
deleted file mode 100644
index ad3bf9012..000000000
--- a/jaxlib/gpu/blas_kernels_ffi.h
+++ /dev/null
@@ -1,30 +0,0 @@
-/* Copyright 2024 The JAX Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef JAXLIB_GPU_BLAS_KERNELS_FFI_H_
-#define JAXLIB_GPU_BLAS_KERNELS_FFI_H_
-
-#include "jaxlib/gpu/vendor.h"
-#include "xla/ffi/api/ffi.h"
-
-namespace jax {
-namespace JAX_GPU_NAMESPACE {
-
-XLA_FFI_DECLARE_HANDLER_SYMBOL(GetrfBatchedFfi);
-
-}  // namespace JAX_GPU_NAMESPACE
-}  // namespace jax
-
-#endif  // JAXLIB_GPU_BLAS_KERNELS_FFI_H_
diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc
index ccca8e157..b76cea19e 100644
--- a/jaxlib/gpu/gpu_kernels.cc
+++ b/jaxlib/gpu/gpu_kernels.cc
@@ -17,7 +17,6 @@ limitations under the License.
 // JAX-generated HLO code from outside of JAX.
 
#include "jaxlib/gpu/blas_kernels.h" -#include "jaxlib/gpu/blas_kernels_ffi.h" #include "jaxlib/gpu/linalg_kernels.h" #include "jaxlib/gpu/prng_kernels.h" #include "jaxlib/gpu/rnn_kernels.h" @@ -36,8 +35,6 @@ namespace { XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched, "CUDA"); -XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cublas_getrf_batched_ffi", "CUDA", - GetrfBatchedFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_geqrf_batched", GeqrfBatched, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn", RNNForward, "CUDA"); diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 414e159b2..051b9fd03 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -16,10 +16,12 @@ limitations under the License. #include "jaxlib/gpu/solver_kernels_ffi.h" #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "jaxlib/ffi_helpers.h" +#include "jaxlib/gpu/blas_handle_pool.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/solver_handle_pool.h" #include "jaxlib/gpu/vendor.h" @@ -58,12 +60,11 @@ GETRF_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZgetrf); #undef GETRF_KERNEL_IMPL template -ffi::Error GetrfImpl(gpuStream_t stream, ffi::ScratchAllocator& scratch, +ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols, + gpuStream_t stream, ffi::ScratchAllocator& scratch, ffi::AnyBuffer a, ffi::Result out, ffi::Result> ipiv, ffi::Result> info) { - FFI_RETURN_IF_ERROR(CheckMatrixDimensions(a.dimensions())); - auto [batch, rows, cols] = SplitBatch2D(a.dimensions()); FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); @@ -98,6 +99,64 @@ ffi::Error GetrfImpl(gpuStream_t stream, ffi::ScratchAllocator& scratch, return ffi::Error::Success(); } +#define GETRF_BATCHED_KERNEL_IMPL(type, name) \ + template <> \ + struct GetrfBatchedKernel { \ + static absl::Status Run(gpublasHandle_t handle, int n, type** a, int lda, \ + int* ipiv, int* info, int batch) { \ + return JAX_AS_STATUS(name(handle, n, a, lda, ipiv, info, batch)); \ + } \ + } + +template +struct GetrfBatchedKernel; +GETRF_BATCHED_KERNEL_IMPL(float, gpublasSgetrfBatched); +GETRF_BATCHED_KERNEL_IMPL(double, gpublasDgetrfBatched); +GETRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgetrfBatched); +GETRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgetrfBatched); +#undef GETRF_BATCHED_KERNEL_IMPL + +template +ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream, + ffi::ScratchAllocator& scratch, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result> ipiv, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream)); + + auto maybe_workspace = scratch.Allocate(sizeof(void*) * batch); + if (!maybe_workspace.has_value()) { + return ffi::Error(ffi::ErrorCode::kUnknown, + "Unable to allocate workspace for batched getrf"); + } + auto workspace = maybe_workspace.value(); + + auto a_data = a.untyped_data(); + auto out_data = out->untyped_data(); + auto ipiv_data = ipiv->typed_data(); + auto info_data = info->typed_data(); + if (a_data != out_data) { + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( + gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * cols * cols, + gpuMemcpyDeviceToDevice, stream))); + } + + FFI_ASSIGN_OR_RETURN( + auto a_ptrs_host, + MakeBatchPointers(stream, out_data, workspace, batch, sizeof(T) * n * n)); + // 
+  // TODO(phawkins, danfm): ideally we would not need to synchronize here, but
+  // to avoid it we need a way to keep the host-side buffer alive until the copy
+  // completes.
+  FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
+
+  auto batch_ptrs = static_cast<T**>(workspace);
+  FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel<T>::Run(
+      handle.get(), n, batch_ptrs, n, ipiv_data, info_data, batch));
+
+  return ffi::Error::Success();
+}
+
 ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
                          ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
                          ffi::Result<ffi::Buffer<ffi::DataType::S32>> ipiv,
                          ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
@@ -108,14 +167,36 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
         ffi::ErrorCode::kInvalidArgument,
         "The input and output to getrf must have the same element type");
   }
-  if (dataType == ffi::DataType::F32) {
-    return GetrfImpl<float>(stream, scratch, a, out, ipiv, info);
-  } else if (dataType == ffi::DataType::F64) {
-    return GetrfImpl<double>(stream, scratch, a, out, ipiv, info);
-  } else if (dataType == ffi::DataType::C64) {
-    return GetrfImpl<gpuComplex>(stream, scratch, a, out, ipiv, info);
-  } else if (dataType == ffi::DataType::C128) {
-    return GetrfImpl<gpuDoubleComplex>(stream, scratch, a, out, ipiv, info);
+  FFI_RETURN_IF_ERROR(CheckMatrixDimensions(a.dimensions()));
+  auto [batch, rows, cols] = SplitBatch2D(a.dimensions());
+  if (batch > 1 && rows == cols && rows / batch <= 128) {
+    if (dataType == ffi::DataType::F32) {
+      return GetrfBatchedImpl<float>(batch, cols, stream, scratch, a, out,
+                                     ipiv, info);
+    } else if (dataType == ffi::DataType::F64) {
+      return GetrfBatchedImpl<double>(batch, cols, stream, scratch, a, out,
+                                      ipiv, info);
+    } else if (dataType == ffi::DataType::C64) {
+      return GetrfBatchedImpl<gpublasComplex>(batch, cols, stream, scratch, a,
+                                              out, ipiv, info);
+    } else if (dataType == ffi::DataType::C128) {
+      return GetrfBatchedImpl<gpublasDoubleComplex>(
+          batch, cols, stream, scratch, a, out, ipiv, info);
+    }
+  } else {
+    if (dataType == ffi::DataType::F32) {
+      return GetrfImpl<float>(batch, rows, cols, stream, scratch, a, out,
+                              ipiv, info);
+    } else if (dataType == ffi::DataType::F64) {
+      return GetrfImpl<double>(batch, rows, cols, stream, scratch, a, out,
+                               ipiv, info);
+    } else if (dataType == ffi::DataType::C64) {
+      return GetrfImpl<gpuComplex>(batch, rows, cols, stream, scratch, a,
+                                   out, ipiv, info);
+    } else if (dataType == ffi::DataType::C128) {
+      return GetrfImpl<gpuDoubleComplex>(batch, rows, cols, stream, scratch,
+                                         a, out, ipiv, info);
+    }
   }
   return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                     "Unsupported element type for getrf");
diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py
index 87171fdb4..baa84e8eb 100644
--- a/jaxlib/gpu_solver.py
+++ b/jaxlib/gpu_solver.py
@@ -43,10 +43,7 @@ except ImportError:
 
 if _cublas:
   for _name, _value in _cublas.registrations().items():
-    # TODO(danfm): Clean up after all legacy custom calls are ported.
-    api_version = 1 if _name.endswith("_ffi") else 0
-    xla_client.register_custom_call_target(_name, _value, platform="CUDA",
-                                           api_version=api_version)
+    xla_client.register_custom_call_target(_name, _value, platform="CUDA")
 
 for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
   try:
@@ -78,10 +75,7 @@ except ImportError:
 
 if _hipblas:
   for _name, _value in _hipblas.registrations().items():
-    # TODO(danfm): Clean up after all legacy custom calls are ported.
-    api_version = 1 if _name.endswith("_ffi") else 0
-    xla_client.register_custom_call_target(_name, _value, platform="ROCM",
-                                           api_version=api_version)
+    xla_client.register_custom_call_target(_name, _value, platform="ROCM")
 
 for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
   try:
@@ -115,15 +109,14 @@ def _getrf_hlo(platform, gpu_blas, gpu_solver, ctx, dtype, a):
   num_bd = len(batch_dims)
   i32_type = ir.IntegerType.get_signless(32)
   layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
-  batch = math.prod(batch_dims)
-  use_batched = batch > 1 and m == n and m // batch <= 128
 
   # TODO(b/357034884): Remove after 3 week forward compatibility window.
   if ctx.is_forward_compat():
     if not gpu_blas:
       raise GpuLibNotLinkedError()
-    if use_batched:
+    batch = math.prod(batch_dims)
+    if batch > 1 and m == n and m // batch <= 128:
       lwork, opaque = gpu_blas.build_getrf_batched_descriptor(
           np.dtype(dtype), batch, m)
       workspace = ir.RankedTensorType.get([lwork],
                                           ir.IntegerType.get_signless(8))
@@ -154,9 +147,8 @@ def _getrf_hlo(platform, gpu_blas, gpu_solver, ctx, dtype, a):
         operand_output_aliases={0: 0}).results
     return out[:3]
 
-  target = "blas_getrf_batched_ffi" if use_batched else "solver_getrf_ffi"
   return custom_call(
-      f"{platform}{target}",
+      f"{platform}solver_getrf_ffi",
      result_types=[
          a.type,
          ir.RankedTensorType.get(batch_dims + (min(m, n),), i32_type),
diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD.bazel
index 1ec36fd30..ba9ceb4c3 100644
--- a/jaxlib/rocm/BUILD.bazel
+++ b/jaxlib/rocm/BUILD.bazel
@@ -98,23 +98,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "hipblas_kernels_ffi",
-    srcs = ["//jaxlib/gpu:blas_kernels_ffi.cc"],
-    hdrs = ["//jaxlib/gpu:blas_kernels_ffi.h"],
-    deps = [
-        ":hip_blas_handle_pool",
-        ":hip_gpu_kernel_helpers",
-        ":hip_vendor",
-        "//jaxlib:ffi_helpers",
-        "@com_google_absl//absl/status",
-        "@local_config_rocm//rocm:hipblas",
-        "@local_config_rocm//rocm:rocm_headers",
-        "@xla//xla/ffi/api:ffi",
-        "@xla//xla/service:custom_call_status",
-    ],
-)
-
 pybind_extension(
     name = "_blas",
     srcs = ["//jaxlib/gpu:blas.cc"],
@@ -127,7 +110,6 @@ pybind_extension(
     deps = [
         ":hip_vendor",
         ":hipblas_kernels",
-        ":hipblas_kernels_ffi",
         "//jaxlib:kernel_nanobind_helpers",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings:str_format",
@@ -176,12 +158,14 @@ cc_library(
     srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"],
     hdrs = ["//jaxlib/gpu:solver_kernels_ffi.h"],
     deps = [
+        ":hip_blas_handle_pool",
         ":hip_gpu_kernel_helpers",
         ":hip_solver_handle_pool",
         ":hip_vendor",
         "//jaxlib:ffi_helpers",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
+        "@local_config_rocm//rocm:hipblas",
         "@local_config_rocm//rocm:hipsolver",
         "@local_config_rocm//rocm:rocm_headers",
         "@xla//xla/ffi/api:ffi",
@@ -199,14 +183,15 @@ pybind_extension(
     features = ["-use_header_modules"],
     module_name = "_solver",
     deps = [
-        ":hip_solver_handle_pool",
         ":hip_gpu_kernel_helpers",
+        ":hip_solver_handle_pool",
         ":hip_vendor",
         ":hipsolver_kernels",
         ":hipsolver_kernels_ffi",
         "//jaxlib:kernel_nanobind_helpers",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings:str_format",
+        "@local_config_rocm//rocm:hipblas",
         "@local_config_rocm//rocm:hipsolver",
         "@local_config_rocm//rocm:rocm_headers",
         "@nanobind",