From 88685d8de0b802f50c73f6dae1bf9bb5357cddc7 Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Wed, 15 Nov 2023 10:37:28 -0800 Subject: [PATCH] Support bazel test without bazel build for CUDA PJRT plugin. - Add build target for jax_plugins/ and jax_plugins/cuda for bazel test. - Update jax_plugins/cuda/__init__.py to fallback to local `.so` file path. - Add a flag --//jax:build_cuda_plugin to control whether to link in local cuda plugin. The following command will test with cuda plugin: ``` bazel test tests:python_callback_test_gpu --test_output=all --test_filter=PythonCallbackTest.test_send_zero_dim_arrays_pure --config=tensorflow_testing_rbe_linux --config=rbe_linux_cuda12.2_nvcc_py3.9 --//jax:build_cuda_plugin=false ``` Default behavior (without `--//jax:build_cuda_plugin=false`) remains unchanged. PiperOrigin-RevId: 582728477 --- jax/BUILD | 15 ++++++++++++ jax_plugins/BUILD.bazel | 34 ++++++++++++++++++++++++++++ jax_plugins/cuda/BUILD.bazel | 27 ++++++++++++++++++++++ jax_plugins/cuda/__init__.py | 44 ++++++++++++++++++++++++++++-------- jaxlib/jax.bzl | 3 +++ 5 files changed, 113 insertions(+), 10 deletions(-) create mode 100644 jax_plugins/BUILD.bazel diff --git a/jax/BUILD b/jax/BUILD index 2ceea2c4d..c6513d453 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -52,6 +52,21 @@ config_setting( }, ) +# When `build_cuda_plugin_for_tests` is true, it assumes `bazel build` for the cuda plugin will run +# before `bazel test` for cuda plugin tests. Set it to false for the case of running `bazel test` +# without `bazel build` for the cuda plugin. +bool_flag( + name = "build_cuda_plugin_for_tests", + build_setting_default = True, +) + +config_setting( + name = "enable_cuda_plugin_build_for_tests", + flag_values = { + ":build_cuda_plugin_for_tests": "True", + }, +) + exports_files([ "LICENSE", "version.py", diff --git a/jax_plugins/BUILD.bazel b/jax_plugins/BUILD.bazel new file mode 100644 index 000000000..2102c6404 --- /dev/null +++ b/jax_plugins/BUILD.bazel @@ -0,0 +1,34 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +load( + "//jaxlib:jax.bzl", + "if_cuda_is_configured", + "py_library_providing_imports_info", +) + +py_library_providing_imports_info( + name = "jax_plugins", +) + +py_library( + name = "gpu_plugin_only_test_deps", + deps = [ + ":jax_plugins", + ] + if_cuda_is_configured([ + "//jax_plugins/cuda:cuda_plugin", + ]), +) \ No newline at end of file diff --git a/jax_plugins/cuda/BUILD.bazel b/jax_plugins/cuda/BUILD.bazel index 0718ea858..47e73db8d 100644 --- a/jax_plugins/cuda/BUILD.bazel +++ b/jax_plugins/cuda/BUILD.bazel @@ -14,6 +14,14 @@ licenses(["notice"]) +load("//jaxlib:symlink_files.bzl", "symlink_files") +load( + "//jaxlib:jax.bzl", + "if_windows", + "py_library_providing_imports_info", + "pytype_library", +) + package( default_applicable_licenses = [], default_visibility = ["//:__subpackages__"], @@ -26,3 +34,22 @@ exports_files([ "pyproject.toml", "setup.py", ]) + +symlink_files( + name = "pjrt_c_api_gpu_plugin", + srcs = if_windows( + ["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"], + ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"], + ), + dst = ".", + flatten = True, +) + +py_library_providing_imports_info( + name = "cuda_plugin", + srcs = [ + "__init__.py", + ], + data = [":pjrt_c_api_gpu_plugin"], + lib_rule = pytype_library, +) \ No newline at end of file diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index 431e71caf..168d8d474 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -19,24 +19,48 @@ import pathlib import platform import sys -import jax._src.xla_bridge as xb - from jax._src.lib import cuda_plugin_extension from jax._src.lib import xla_client - +import jax._src.xla_bridge as xb logger = logging.getLogger(__name__) -def initialize(): - path = pathlib.Path(__file__).resolve().parent / "xla_cuda_plugin.so" - if not path.exists(): - logger.warning( - "WARNING: Native library %s does not exist. This most likely indicates" - " an issue with how %s was built or installed.", - path, +def _get_library_path(): + installed_path = ( + pathlib.Path(__file__).resolve().parent / 'xla_cuda_plugin.so' + ) + if installed_path.exists(): + return installed_path + + local_path = os.path.join( + os.path.dirname(__file__), 'pjrt_c_api_gpu_plugin.so' + ) + if os.path.exists(local_path): + logger.debug( + 'Native library %s does not exist. This most likely indicates an issue' + ' with how %s was built or installed. Fallback to local test' + ' library %s', + installed_path, __package__, + local_path, ) + return local_path + + logger.debug( + 'WARNING: Native library %s and local test library path %s do not' + ' exist. This most likely indicates an issue with how %s was built or' + ' installed or missing src files.', + installed_path, + local_path, + __package__, + ) + return None + + +def initialize(): + path = _get_library_path() + if path is None: return # TODO(b/300099402): use the util method when it is ready. diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 392c4cd33..7d30a75ed 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -201,6 +201,9 @@ def jax_test( ] + deps + select({ "//jax:enable_jaxlib_build": ["//jaxlib/cuda:gpu_only_test_deps"], "//conditions:default": [], + }) + select({ + "//jax:enable_cuda_plugin_build_for_tests": [], + "//conditions:default": ["//jax_plugins:gpu_plugin_only_test_deps"], }), shard_count = test_shards, tags = test_tags,