diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 57ceebd3d..ed7d129df 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -46,6 +46,15 @@ py_binary( ], ) +py_test( + name = "build_wheel_test", + srcs = ["build_wheel_test.py"], + data = [":build_wheel"], + deps = [ + "@bazel_tools//tools/python/runfiles", + ], +) + py_binary( name = "build_gpu_plugin_wheel", srcs = ["build_gpu_plugin_wheel.py"], diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 394b7d49a..7fd63f15f 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -76,7 +76,10 @@ pyext = "pyd" if build_utils.is_windows() else "so" def exists(src_file): - return r.Rlocation(src_file) is not None + path = r.Rlocation(src_file) + if path is None: + return False + return os.path.exists(path) def patch_copy_mlir_import(src_file, dst_dir): diff --git a/jaxlib/tools/build_wheel_test.py b/jaxlib/tools/build_wheel_test.py new file mode 100644 index 000000000..a33491f1c --- /dev/null +++ b/jaxlib/tools/build_wheel_test.py @@ -0,0 +1,32 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This test verifies that the build_wheel.py runs successfully. + +import platform +import subprocess +import sys +import tempfile + +from bazel_tools.tools.python.runfiles import runfiles + +r = runfiles.Create() + +with tempfile.TemporaryDirectory(prefix="jax_build_wheel_test") as tmpdir: + subprocess.run([ + sys.executable, r.Rlocation("__main__/jaxlib/tools/build_wheel.py"), + f"--cpu={platform.machine()}", + f"--output_path={tmpdir}", + "--jaxlib_git_hash=12345678" + ], check=True)