Merge pull request #22909 from ROCm:ci_fix_solver_paths

PiperOrigin-RevId: 660515208
This commit is contained in:
jax authors 2024-08-07 13:26:17 -07:00
commit de02988e94
3 changed files with 8 additions and 7 deletions

View File

@ -22,7 +22,6 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/status/statusor.h"
#include "third_party/gpus/cuda/include/cusolver_common.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/solver_handle_pool.h"
#include "jaxlib/gpu/solver_kernels.h"

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "third_party/gpus/cuda/include/cuda_runtime_api.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cufft.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cusolverDn.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cusolver_common.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cusparse.h" // IWYU pragma: export
#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: export

View File

@ -100,10 +100,10 @@ cc_library(
cc_library(
name = "hipblas_kernels_ffi",
srcs = ["//third_party/py/jax/jaxlib/gpu:blas_kernels_ffi.cc"],
hdrs = ["//third_party/py/jax/jaxlib/gpu:blas_kernels_ffi.h"],
srcs = ["//jaxlib/gpu:blas_kernels_ffi.cc"],
hdrs = ["//jaxlib/gpu:blas_kernels_ffi.h"],
deps = [
":hip_gpu_handle_pools",
":hip_blas_handle_pool",
":hip_gpu_kernel_helpers",
":hip_vendor",
"//jaxlib:ffi_helpers",
@ -173,8 +173,8 @@ cc_library(
cc_library(
name = "hipsolver_kernels_ffi",
srcs = ["//third_party/py/jax/jaxlib/gpu:solver_kernels_ffi.cc"],
hdrs = ["//third_party/py/jax/jaxlib/gpu:solver_kernels_ffi.h"],
srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"],
hdrs = ["//jaxlib/gpu:solver_kernels_ffi.h"],
deps = [
":hip_gpu_kernel_helpers",
":hip_solver_handle_pool",
@ -199,10 +199,11 @@ pybind_extension(
features = ["-use_header_modules"],
module_name = "_solver",
deps = [
":hip_gpu_handle_pools",
":hip_solver_handle_pool",
":hip_gpu_kernel_helpers",
":hip_vendor",
":hipsolver_kernels",
":hipsolver_kernels_ffi",
"//jaxlib:kernel_nanobind_helpers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings:str_format",