Move cuda .py files to :gpu_support so that if :gpu_support is not present, then internal jaxlib will act like a CPU jaxlib even if --config=cuda is specified.

PiperOrigin-RevId: 403170945
This commit is contained in:
Yash Katariya 2021-10-14 13:19:11 -07:00 committed by jax authors
parent 1bafdb6d7e
commit ac0796048f
2 changed files with 8 additions and 7 deletions

View File

@ -38,6 +38,7 @@ py_binary(
] + if_not_windows([
"@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client",
]) + if_cuda([
"//jaxlib:gpu_support",
"//jaxlib:_cublas",
"//jaxlib:_cusolver",
"//jaxlib:_cusparse",

View File

@ -19,7 +19,6 @@ load(
"cuda_library",
"flatbuffer_cc_library",
"flatbuffer_py_library",
"if_cuda_is_configured",
"if_rocm_is_configured",
"pybind_extension",
)
@ -118,12 +117,7 @@ py_library(
"lapack.py",
"pocketfft.py",
"version.py",
] + if_cuda_is_configured([
"cuda_linalg.py",
"cuda_prng.py",
"cusolver.py",
"cusparse.py",
]) + if_rocm_is_configured([
] + if_rocm_is_configured([
"rocsolver.py",
]),
deps = [":pocketfft_flatbuffers_py"],
@ -234,6 +228,12 @@ cc_library(
py_library(
name = "gpu_support",
srcs = [
"cuda_linalg.py",
"cuda_prng.py",
"cusolver.py",
"cusparse.py",
],
deps = [
":_cublas",
":_cuda_linalg",