mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Support bazel test without bazel build for CUDA PJRT plugin.
- Add build target for jax_plugins/ and jax_plugins/cuda for bazel test. - Update jax_plugins/cuda/__init__.py to fallback to local `.so` file path. - Add a flag --//jax:build_cuda_plugin to control whether to link in local cuda plugin. The following command will test with cuda plugin: ``` bazel test tests:python_callback_test_gpu --test_output=all --test_filter=PythonCallbackTest.test_send_zero_dim_arrays_pure --config=tensorflow_testing_rbe_linux --config=rbe_linux_cuda12.2_nvcc_py3.9 --//jax:build_cuda_plugin=false ``` Default behavior (without `--//jax:build_cuda_plugin=false`) remains unchanged. PiperOrigin-RevId: 582728477
This commit is contained in:
parent
118d85cd6c
commit
88685d8de0
15
jax/BUILD
15
jax/BUILD
@ -52,6 +52,21 @@ config_setting(
|
||||
},
|
||||
)
|
||||
|
||||
# When `build_cuda_plugin_for_tests` is true, it assumes `bazel build` for the cuda plugin will run
|
||||
# before `bazel test` for cuda plugin tests. Set it to false for the case of running `bazel test`
|
||||
# without `bazel build` for the cuda plugin.
|
||||
bool_flag(
|
||||
name = "build_cuda_plugin_for_tests",
|
||||
build_setting_default = True,
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "enable_cuda_plugin_build_for_tests",
|
||||
flag_values = {
|
||||
":build_cuda_plugin_for_tests": "True",
|
||||
},
|
||||
)
|
||||
|
||||
exports_files([
|
||||
"LICENSE",
|
||||
"version.py",
|
||||
|
34
jax_plugins/BUILD.bazel
Normal file
34
jax_plugins/BUILD.bazel
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"if_cuda_is_configured",
|
||||
"py_library_providing_imports_info",
|
||||
)
|
||||
|
||||
py_library_providing_imports_info(
|
||||
name = "jax_plugins",
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "gpu_plugin_only_test_deps",
|
||||
deps = [
|
||||
":jax_plugins",
|
||||
] + if_cuda_is_configured([
|
||||
"//jax_plugins/cuda:cuda_plugin",
|
||||
]),
|
||||
)
|
@ -14,6 +14,14 @@
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
load("//jaxlib:symlink_files.bzl", "symlink_files")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"if_windows",
|
||||
"py_library_providing_imports_info",
|
||||
"pytype_library",
|
||||
)
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//:__subpackages__"],
|
||||
@ -26,3 +34,22 @@ exports_files([
|
||||
"pyproject.toml",
|
||||
"setup.py",
|
||||
])
|
||||
|
||||
symlink_files(
|
||||
name = "pjrt_c_api_gpu_plugin",
|
||||
srcs = if_windows(
|
||||
["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"],
|
||||
["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"],
|
||||
),
|
||||
dst = ".",
|
||||
flatten = True,
|
||||
)
|
||||
|
||||
py_library_providing_imports_info(
|
||||
name = "cuda_plugin",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
],
|
||||
data = [":pjrt_c_api_gpu_plugin"],
|
||||
lib_rule = pytype_library,
|
||||
)
|
@ -19,24 +19,48 @@ import pathlib
|
||||
import platform
|
||||
import sys
|
||||
|
||||
import jax._src.xla_bridge as xb
|
||||
|
||||
from jax._src.lib import cuda_plugin_extension
|
||||
from jax._src.lib import xla_client
|
||||
|
||||
import jax._src.xla_bridge as xb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initialize():
|
||||
path = pathlib.Path(__file__).resolve().parent / "xla_cuda_plugin.so"
|
||||
if not path.exists():
|
||||
logger.warning(
|
||||
"WARNING: Native library %s does not exist. This most likely indicates"
|
||||
" an issue with how %s was built or installed.",
|
||||
path,
|
||||
def _get_library_path():
|
||||
installed_path = (
|
||||
pathlib.Path(__file__).resolve().parent / 'xla_cuda_plugin.so'
|
||||
)
|
||||
if installed_path.exists():
|
||||
return installed_path
|
||||
|
||||
local_path = os.path.join(
|
||||
os.path.dirname(__file__), 'pjrt_c_api_gpu_plugin.so'
|
||||
)
|
||||
if os.path.exists(local_path):
|
||||
logger.debug(
|
||||
'Native library %s does not exist. This most likely indicates an issue'
|
||||
' with how %s was built or installed. Fallback to local test'
|
||||
' library %s',
|
||||
installed_path,
|
||||
__package__,
|
||||
local_path,
|
||||
)
|
||||
return local_path
|
||||
|
||||
logger.debug(
|
||||
'WARNING: Native library %s and local test library path %s do not'
|
||||
' exist. This most likely indicates an issue with how %s was built or'
|
||||
' installed or missing src files.',
|
||||
installed_path,
|
||||
local_path,
|
||||
__package__,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def initialize():
|
||||
path = _get_library_path()
|
||||
if path is None:
|
||||
return
|
||||
|
||||
# TODO(b/300099402): use the util method when it is ready.
|
||||
|
@ -201,6 +201,9 @@ def jax_test(
|
||||
] + deps + select({
|
||||
"//jax:enable_jaxlib_build": ["//jaxlib/cuda:gpu_only_test_deps"],
|
||||
"//conditions:default": [],
|
||||
}) + select({
|
||||
"//jax:enable_cuda_plugin_build_for_tests": [],
|
||||
"//conditions:default": ["//jax_plugins:gpu_plugin_only_test_deps"],
|
||||
}),
|
||||
shard_count = test_shards,
|
||||
tags = test_tags,
|
||||
|
Loading…
x
Reference in New Issue
Block a user