mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #22909 from ROCm:ci_fix_solver_paths
PiperOrigin-RevId: 660515208
This commit is contained in:
commit
de02988e94
@ -22,7 +22,6 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "third_party/gpus/cuda/include/cusolver_common.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/solver_handle_pool.h"
|
||||
#include "jaxlib/gpu/solver_kernels.h"
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h" // IWYU pragma: export
|
||||
#include "third_party/gpus/cuda/include/cufft.h" // IWYU pragma: export
|
||||
#include "third_party/gpus/cuda/include/cusolverDn.h" // IWYU pragma: export
|
||||
#include "third_party/gpus/cuda/include/cusolver_common.h" // IWYU pragma: export
|
||||
#include "third_party/gpus/cuda/include/cusparse.h" // IWYU pragma: export
|
||||
#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: export
|
||||
|
||||
|
@ -100,10 +100,10 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "hipblas_kernels_ffi",
|
||||
srcs = ["//third_party/py/jax/jaxlib/gpu:blas_kernels_ffi.cc"],
|
||||
hdrs = ["//third_party/py/jax/jaxlib/gpu:blas_kernels_ffi.h"],
|
||||
srcs = ["//jaxlib/gpu:blas_kernels_ffi.cc"],
|
||||
hdrs = ["//jaxlib/gpu:blas_kernels_ffi.h"],
|
||||
deps = [
|
||||
":hip_gpu_handle_pools",
|
||||
":hip_blas_handle_pool",
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_vendor",
|
||||
"//jaxlib:ffi_helpers",
|
||||
@ -173,8 +173,8 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "hipsolver_kernels_ffi",
|
||||
srcs = ["//third_party/py/jax/jaxlib/gpu:solver_kernels_ffi.cc"],
|
||||
hdrs = ["//third_party/py/jax/jaxlib/gpu:solver_kernels_ffi.h"],
|
||||
srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"],
|
||||
hdrs = ["//jaxlib/gpu:solver_kernels_ffi.h"],
|
||||
deps = [
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_solver_handle_pool",
|
||||
@ -199,10 +199,11 @@ pybind_extension(
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_solver",
|
||||
deps = [
|
||||
":hip_gpu_handle_pools",
|
||||
":hip_solver_handle_pool",
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_vendor",
|
||||
":hipsolver_kernels",
|
||||
":hipsolver_kernels_ffi",
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
|
Loading…
x
Reference in New Issue
Block a user