mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Move jaxlib GPU handlers to separate build target.
In anticipation of refactoring the jaxlib GPU custom calls into FFI calls, this change moves the implementation of `BlasHandlePool`, `SolverHandlePool`, and `SpSolverHandlePool` into new target. PiperOrigin-RevId: 658497960
This commit is contained in:
parent
b3924da2a1
commit
f20efc630f
@ -73,13 +73,33 @@ cc_library(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cublas_kernels",
|
||||
srcs = ["//jaxlib/gpu:blas_kernels.cc"],
|
||||
hdrs = ["//jaxlib/gpu:blas_kernels.h"],
|
||||
name = "cuda_blas_handle_pool",
|
||||
srcs = [
|
||||
"//jaxlib/gpu:blas_handle_pool.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//jaxlib/gpu:blas_handle_pool.h",
|
||||
],
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cublas_kernels",
|
||||
srcs = ["//jaxlib/gpu:blas_kernels.cc"],
|
||||
hdrs = ["//jaxlib/gpu:blas_kernels.h"],
|
||||
deps = [
|
||||
":cuda_blas_handle_pool",
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
@ -90,9 +110,9 @@ cc_library(
|
||||
"@com_google_absl//absl/hash",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@local_config_cuda//cuda:cublas_headers",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
],
|
||||
@ -119,6 +139,7 @@ pybind_extension(
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
@ -165,20 +186,40 @@ pybind_extension(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cuda_solver_handle_pool",
|
||||
srcs = [
|
||||
"//jaxlib/gpu:solver_handle_pool.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//jaxlib/gpu:solver_handle_pool.h",
|
||||
],
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cusolver",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cusolver_kernels",
|
||||
srcs = ["//jaxlib/gpu:solver_kernels.cc"],
|
||||
hdrs = ["//jaxlib/gpu:solver_kernels.h"],
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_solver_handle_pool",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cusolver",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
],
|
||||
)
|
||||
@ -201,6 +242,7 @@ pybind_extension(
|
||||
module_name = "_solver",
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_solver_handle_pool",
|
||||
":cuda_vendor",
|
||||
":cusolver_kernels",
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
|
@ -25,6 +25,8 @@ package(
|
||||
|
||||
exports_files(srcs = [
|
||||
"blas.cc",
|
||||
"blas_handle_pool.cc",
|
||||
"blas_handle_pool.h",
|
||||
"blas_kernels.cc",
|
||||
"blas_kernels.h",
|
||||
"gpu_kernel_helpers.cc",
|
||||
@ -42,6 +44,8 @@ exports_files(srcs = [
|
||||
"rnn_kernels.cc",
|
||||
"rnn_kernels.h",
|
||||
"solver.cc",
|
||||
"solver_handle_pool.cc",
|
||||
"solver_handle_pool.h",
|
||||
"solver_kernels.cc",
|
||||
"solver_kernels.h",
|
||||
"sparse.cc",
|
||||
|
44
jaxlib/gpu/blas_handle_pool.cc
Normal file
44
jaxlib/gpu/blas_handle_pool.cc
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/gpu/blas_handle_pool.h"
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
template <>
|
||||
/*static*/ absl::StatusOr<BlasHandlePool::Handle> BlasHandlePool::Borrow(
|
||||
gpuStream_t stream) {
|
||||
BlasHandlePool* pool = Instance();
|
||||
absl::MutexLock lock(&pool->mu_);
|
||||
gpublasHandle_t handle;
|
||||
if (pool->handles_[stream].empty()) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasCreate(&handle)));
|
||||
} else {
|
||||
handle = pool->handles_[stream].back();
|
||||
pool->handles_[stream].pop_back();
|
||||
}
|
||||
if (stream) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasSetStream(handle, stream)));
|
||||
}
|
||||
return Handle(pool, handle, stream);
|
||||
}
|
||||
|
||||
} // namespace jax
|
33
jaxlib/gpu/blas_handle_pool.h
Normal file
33
jaxlib/gpu/blas_handle_pool.h
Normal file
@ -0,0 +1,33 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_GPU_BLAS_HANDLE_POOL_H_
|
||||
#define JAXLIB_GPU_BLAS_HANDLE_POOL_H_
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
using BlasHandlePool = HandlePool<gpublasHandle_t, gpuStream_t>;
|
||||
|
||||
template <>
|
||||
absl::StatusOr<BlasHandlePool::Handle> BlasHandlePool::Borrow(
|
||||
gpuStream_t stream);
|
||||
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_GPU_BLAS_HANDLE_POOL_H_
|
@ -16,42 +16,22 @@ limitations under the License.
|
||||
#include "jaxlib/gpu/blas_kernels.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
#include <cstddef>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/base/casts.h"
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "jaxlib/gpu/blas_handle_pool.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
using BlasHandlePool = HandlePool<gpublasHandle_t, gpuStream_t>;
|
||||
|
||||
template <>
|
||||
/*static*/ absl::StatusOr<BlasHandlePool::Handle> BlasHandlePool::Borrow(
|
||||
gpuStream_t stream) {
|
||||
BlasHandlePool* pool = Instance();
|
||||
absl::MutexLock lock(&pool->mu_);
|
||||
gpublasHandle_t handle;
|
||||
if (pool->handles_[stream].empty()) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasCreate(&handle)));
|
||||
} else {
|
||||
handle = pool->handles_[stream].back();
|
||||
pool->handles_[stream].pop_back();
|
||||
}
|
||||
if (stream) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasSetStream(handle, stream)));
|
||||
}
|
||||
return Handle(pool, handle, stream);
|
||||
}
|
||||
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
||||
namespace {
|
||||
|
@ -13,23 +13,22 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "nanobind/nanobind.h"
|
||||
#include "nanobind/stl/pair.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "third_party/gpus/cuda/include/cusolver_common.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/solver_handle_pool.h"
|
||||
#include "jaxlib/gpu/solver_kernels.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/kernel_nanobind_helpers.h"
|
||||
#include "xla/tsl/python/lib/core/numpy.h"
|
||||
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
namespace {
|
||||
|
70
jaxlib/gpu/solver_handle_pool.cc
Normal file
70
jaxlib/gpu/solver_handle_pool.cc
Normal file
@ -0,0 +1,70 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/gpu/solver_handle_pool.h"
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
#include "third_party/gpus/cuda/include/cusolverSp.h"
|
||||
#endif // JAX_GPU_CUDA
|
||||
|
||||
namespace jax {
|
||||
|
||||
template <>
|
||||
/*static*/ absl::StatusOr<SolverHandlePool::Handle> SolverHandlePool::Borrow(
|
||||
gpuStream_t stream) {
|
||||
SolverHandlePool* pool = Instance();
|
||||
absl::MutexLock lock(&pool->mu_);
|
||||
gpusolverDnHandle_t handle;
|
||||
if (pool->handles_[stream].empty()) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCreate(&handle)));
|
||||
} else {
|
||||
handle = pool->handles_[stream].back();
|
||||
pool->handles_[stream].pop_back();
|
||||
}
|
||||
if (stream) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSetStream(handle, stream)));
|
||||
}
|
||||
return Handle(pool, handle, stream);
|
||||
}
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
|
||||
template <>
|
||||
/*static*/ absl::StatusOr<SpSolverHandlePool::Handle>
|
||||
SpSolverHandlePool::Borrow(gpuStream_t stream) {
|
||||
SpSolverHandlePool* pool = Instance();
|
||||
absl::MutexLock lock(&pool->mu_);
|
||||
cusolverSpHandle_t handle;
|
||||
if (pool->handles_[stream].empty()) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpCreate(&handle)));
|
||||
} else {
|
||||
handle = pool->handles_[stream].back();
|
||||
pool->handles_[stream].pop_back();
|
||||
}
|
||||
if (stream) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpSetStream(handle, stream)));
|
||||
}
|
||||
return Handle(pool, handle, stream);
|
||||
}
|
||||
|
||||
#endif // JAX_GPU_CUDA
|
||||
|
||||
} // namespace jax
|
45
jaxlib/gpu/solver_handle_pool.h
Normal file
45
jaxlib/gpu/solver_handle_pool.h
Normal file
@ -0,0 +1,45 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_GPU_SOLVER_HANDLE_POOL_H_
|
||||
#define JAXLIB_GPU_SOLVER_HANDLE_POOL_H_
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
#include "third_party/gpus/cuda/include/cusolverSp.h"
|
||||
#endif // JAX_GPU_CUDA
|
||||
|
||||
namespace jax {
|
||||
|
||||
using SolverHandlePool = HandlePool<gpusolverDnHandle_t, gpuStream_t>;
|
||||
|
||||
template <>
|
||||
absl::StatusOr<SolverHandlePool::Handle> SolverHandlePool::Borrow(
|
||||
gpuStream_t stream);
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
using SpSolverHandlePool = HandlePool<cusolverSpHandle_t, gpuStream_t>;
|
||||
|
||||
template <>
|
||||
absl::StatusOr<SpSolverHandlePool::Handle> SpSolverHandlePool::Borrow(
|
||||
gpuStream_t stream);
|
||||
#endif // JAX_GPU_CUDA
|
||||
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_GPU_SOLVER_HANDLE_POOL_H_
|
@ -16,16 +16,16 @@ limitations under the License.
|
||||
#include "jaxlib/gpu/solver_kernels.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "jaxlib/gpu/solver_handle_pool.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
@ -35,46 +35,6 @@ limitations under the License.
|
||||
|
||||
namespace jax {
|
||||
|
||||
template <>
|
||||
/*static*/ absl::StatusOr<SolverHandlePool::Handle> SolverHandlePool::Borrow(
|
||||
gpuStream_t stream) {
|
||||
SolverHandlePool* pool = Instance();
|
||||
absl::MutexLock lock(&pool->mu_);
|
||||
gpusolverDnHandle_t handle;
|
||||
if (pool->handles_[stream].empty()) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCreate(&handle)));
|
||||
} else {
|
||||
handle = pool->handles_[stream].back();
|
||||
pool->handles_[stream].pop_back();
|
||||
}
|
||||
if (stream) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSetStream(handle, stream)));
|
||||
}
|
||||
return Handle(pool, handle, stream);
|
||||
}
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
|
||||
template <>
|
||||
/*static*/ absl::StatusOr<SpSolverHandlePool::Handle>
|
||||
SpSolverHandlePool::Borrow(gpuStream_t stream) {
|
||||
SpSolverHandlePool* pool = Instance();
|
||||
absl::MutexLock lock(&pool->mu_);
|
||||
cusolverSpHandle_t handle;
|
||||
if (pool->handles_[stream].empty()) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpCreate(&handle)));
|
||||
} else {
|
||||
handle = pool->handles_[stream].back();
|
||||
pool->handles_[stream].pop_back();
|
||||
}
|
||||
if (stream) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpSetStream(handle, stream)));
|
||||
}
|
||||
return Handle(pool, handle, stream);
|
||||
}
|
||||
|
||||
#endif // JAX_GPU_CUDA
|
||||
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
||||
static int SizeOfSolverType(SolverType type) {
|
||||
|
@ -16,33 +16,13 @@ limitations under the License.
|
||||
#ifndef JAXLIB_CUSOLVER_KERNELS_H_
|
||||
#define JAXLIB_CUSOLVER_KERNELS_H_
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include <cstddef>
|
||||
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
#include "third_party/gpus/cuda/include/cusolverSp.h"
|
||||
#endif // JAX_GPU_CUDA
|
||||
|
||||
namespace jax {
|
||||
|
||||
using SolverHandlePool = HandlePool<gpusolverDnHandle_t, gpuStream_t>;
|
||||
|
||||
template <>
|
||||
absl::StatusOr<SolverHandlePool::Handle> SolverHandlePool::Borrow(
|
||||
gpuStream_t stream);
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
|
||||
using SpSolverHandlePool = HandlePool<cusolverSpHandle_t, gpuStream_t>;
|
||||
|
||||
template <>
|
||||
absl::StatusOr<SpSolverHandlePool::Handle> SpSolverHandlePool::Borrow(
|
||||
gpuStream_t stream);
|
||||
|
||||
#endif // JAX_GPU_CUDA
|
||||
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
||||
// Set of types known to Cusolver.
|
||||
|
@ -59,13 +59,32 @@ cc_library(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hipblas_kernels",
|
||||
srcs = ["//jaxlib/gpu:blas_kernels.cc"],
|
||||
hdrs = ["//jaxlib/gpu:blas_kernels.h"],
|
||||
name = "hip_blas_handle_pool",
|
||||
srcs = [
|
||||
"//jaxlib/gpu:blas_handle_pool.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//jaxlib/gpu:blas_handle_pool.h",
|
||||
],
|
||||
deps = [
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@local_config_rocm//rocm:hipblas",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hipblas_kernels",
|
||||
srcs = ["//jaxlib/gpu:blas_kernels.cc"],
|
||||
hdrs = ["//jaxlib/gpu:blas_kernels.h"],
|
||||
deps = [
|
||||
":hip_blas_handle_pool",
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base",
|
||||
@ -73,6 +92,7 @@ cc_library(
|
||||
"@com_google_absl//absl/hash",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
@ -104,14 +124,33 @@ pybind_extension(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hip_solver_handle_pool",
|
||||
srcs = [
|
||||
"//jaxlib/gpu:solver_handle_pool.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//jaxlib/gpu:solver_handle_pool.h",
|
||||
],
|
||||
deps = [
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@local_config_rocm//rocm:hipsolver",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hipsolver_kernels",
|
||||
srcs = ["//jaxlib/gpu:solver_kernels.cc"],
|
||||
hdrs = ["//jaxlib/gpu:solver_kernels.h"],
|
||||
deps = [
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_solver_handle_pool",
|
||||
":hip_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
@ -132,6 +171,7 @@ pybind_extension(
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_solver",
|
||||
deps = [
|
||||
":hip_gpu_handle_pools",
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_vendor",
|
||||
":hipsolver_kernels",
|
||||
@ -241,9 +281,9 @@ pybind_extension(
|
||||
":hip_linalg_kernels",
|
||||
":hip_vendor",
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@nanobind",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
],
|
||||
)
|
||||
|
||||
@ -257,11 +297,11 @@ cc_library(
|
||||
":hip_vendor",
|
||||
"//jaxlib:ffi_helpers",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/status",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
],
|
||||
)
|
||||
@ -274,8 +314,8 @@ rocm_library(
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
],
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user