mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Add a bazel test that verifies that the jaxlib wheel builds.
This commit is contained in:
parent
8a8cd6d01a
commit
dedd69f323
@ -46,6 +46,15 @@ py_binary(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "build_wheel_test",
|
||||
srcs = ["build_wheel_test.py"],
|
||||
data = [":build_wheel"],
|
||||
deps = [
|
||||
"@bazel_tools//tools/python/runfiles",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "build_gpu_plugin_wheel",
|
||||
srcs = ["build_gpu_plugin_wheel.py"],
|
||||
|
@ -76,7 +76,10 @@ pyext = "pyd" if build_utils.is_windows() else "so"
|
||||
|
||||
|
||||
def exists(src_file):
|
||||
return r.Rlocation(src_file) is not None
|
||||
path = r.Rlocation(src_file)
|
||||
if path is None:
|
||||
return False
|
||||
return os.path.exists(path)
|
||||
|
||||
|
||||
def patch_copy_mlir_import(src_file, dst_dir):
|
||||
|
32
jaxlib/tools/build_wheel_test.py
Normal file
32
jaxlib/tools/build_wheel_test.py
Normal file
@ -0,0 +1,32 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This test verifies that the build_wheel.py runs successfully.
|
||||
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from bazel_tools.tools.python.runfiles import runfiles
|
||||
|
||||
r = runfiles.Create()
|
||||
|
||||
with tempfile.TemporaryDirectory(prefix="jax_build_wheel_test") as tmpdir:
|
||||
subprocess.run([
|
||||
sys.executable, r.Rlocation("__main__/jaxlib/tools/build_wheel.py"),
|
||||
f"--cpu={platform.machine()}",
|
||||
f"--output_path={tmpdir}",
|
||||
"--jaxlib_git_hash=12345678"
|
||||
], check=True)
|
Loading…
x
Reference in New Issue
Block a user