Support bazel test without bazel build for CUDA PJRT plugin.

- Add build target for jax_plugins/ and jax_plugins/cuda for bazel test.
- Update jax_plugins/cuda/__init__.py to fallback to local `.so` file path.
- Add a flag --//jax:build_cuda_plugin to control whether to link in local cuda plugin.

The following command will test with cuda plugin:
```
bazel test tests:python_callback_test_gpu --test_output=all --test_filter=PythonCallbackTest.test_send_zero_dim_arrays_pure --config=tensorflow_testing_rbe_linux --config=rbe_linux_cuda12.2_nvcc_py3.9 --//jax:build_cuda_plugin=false
```

Default behavior (without `--//jax:build_cuda_plugin=false`) remains unchanged.

PiperOrigin-RevId: 582728477
This commit is contained in:
Jieying Luo 2023-11-15 10:37:28 -08:00 committed by jax authors
parent 118d85cd6c
commit 88685d8de0
5 changed files with 113 additions and 10 deletions

View File

@ -52,6 +52,21 @@ config_setting(
},
)
# When `build_cuda_plugin_for_tests` is true, it assumes `bazel build` for the cuda plugin will run
# before `bazel test` for cuda plugin tests. Set it to false for the case of running `bazel test`
# without `bazel build` for the cuda plugin.
bool_flag(
name = "build_cuda_plugin_for_tests",
build_setting_default = True,
)
config_setting(
name = "enable_cuda_plugin_build_for_tests",
flag_values = {
":build_cuda_plugin_for_tests": "True",
},
)
exports_files([
"LICENSE",
"version.py",

34
jax_plugins/BUILD.bazel Normal file
View File

@ -0,0 +1,34 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
load(
"//jaxlib:jax.bzl",
"if_cuda_is_configured",
"py_library_providing_imports_info",
)
py_library_providing_imports_info(
name = "jax_plugins",
)
py_library(
name = "gpu_plugin_only_test_deps",
deps = [
":jax_plugins",
] + if_cuda_is_configured([
"//jax_plugins/cuda:cuda_plugin",
]),
)

View File

@ -14,6 +14,14 @@
licenses(["notice"])
load("//jaxlib:symlink_files.bzl", "symlink_files")
load(
"//jaxlib:jax.bzl",
"if_windows",
"py_library_providing_imports_info",
"pytype_library",
)
package(
default_applicable_licenses = [],
default_visibility = ["//:__subpackages__"],
@ -26,3 +34,22 @@ exports_files([
"pyproject.toml",
"setup.py",
])
symlink_files(
name = "pjrt_c_api_gpu_plugin",
srcs = if_windows(
["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"],
["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"],
),
dst = ".",
flatten = True,
)
py_library_providing_imports_info(
name = "cuda_plugin",
srcs = [
"__init__.py",
],
data = [":pjrt_c_api_gpu_plugin"],
lib_rule = pytype_library,
)

View File

@ -19,24 +19,48 @@ import pathlib
import platform
import sys
import jax._src.xla_bridge as xb
from jax._src.lib import cuda_plugin_extension
from jax._src.lib import xla_client
import jax._src.xla_bridge as xb
logger = logging.getLogger(__name__)
def initialize():
path = pathlib.Path(__file__).resolve().parent / "xla_cuda_plugin.so"
if not path.exists():
logger.warning(
"WARNING: Native library %s does not exist. This most likely indicates"
" an issue with how %s was built or installed.",
path,
def _get_library_path():
installed_path = (
pathlib.Path(__file__).resolve().parent / 'xla_cuda_plugin.so'
)
if installed_path.exists():
return installed_path
local_path = os.path.join(
os.path.dirname(__file__), 'pjrt_c_api_gpu_plugin.so'
)
if os.path.exists(local_path):
logger.debug(
'Native library %s does not exist. This most likely indicates an issue'
' with how %s was built or installed. Fallback to local test'
' library %s',
installed_path,
__package__,
local_path,
)
return local_path
logger.debug(
'WARNING: Native library %s and local test library path %s do not'
' exist. This most likely indicates an issue with how %s was built or'
' installed or missing src files.',
installed_path,
local_path,
__package__,
)
return None
def initialize():
path = _get_library_path()
if path is None:
return
# TODO(b/300099402): use the util method when it is ready.

View File

@ -201,6 +201,9 @@ def jax_test(
] + deps + select({
"//jax:enable_jaxlib_build": ["//jaxlib/cuda:gpu_only_test_deps"],
"//conditions:default": [],
}) + select({
"//jax:enable_cuda_plugin_build_for_tests": [],
"//conditions:default": ["//jax_plugins:gpu_plugin_only_test_deps"],
}),
shard_count = test_shards,
tags = test_tags,