mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[ROCm] add build file.
This commit is contained in:
parent
c6c701e6a7
commit
9ce8de5fb0
165
build/rocm/dev_build_rocm.py
Executable file
165
build/rocm/dev_build_rocm.py
Executable file
@ -0,0 +1,165 @@
|
||||
# !/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# NOTE(ruturaj4): This script automates the build process for JAX and XLA on ROCm,
|
||||
# allowing for optional uninstallation of existing packages, and custom paths for ROCm and XLA repositories.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def get_rocm_version():
|
||||
try:
|
||||
version = subprocess.check_output(
|
||||
"cat /opt/rocm/.info/version | cut -d '-' -f 1", shell=True
|
||||
)
|
||||
return version.decode("utf-8").strip()
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error fetching ROCm version: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_rocm_target():
|
||||
try:
|
||||
target_info = subprocess.check_output(
|
||||
"rocminfo | grep gfx | head -n 1", shell=True
|
||||
)
|
||||
target = target_info.decode("utf-8").split()[1]
|
||||
return target
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error fetching ROCm target: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def uninstall_existing_packages(packages):
|
||||
cmd = ["python3", "-m", "pip", "uninstall", "-y"]
|
||||
cmd.extend(packages)
|
||||
|
||||
try:
|
||||
subprocess.run(cmd, check=True)
|
||||
print(f"Successfully uninstalled {packages}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Failed to uninstall {packages}: {e}")
|
||||
|
||||
|
||||
def clean_dist_directory():
|
||||
try:
|
||||
shutil.rmtree("dist")
|
||||
print("Cleaned dist directory.")
|
||||
except FileNotFoundError:
|
||||
print("dist directory not found, skipping cleanup.")
|
||||
except Exception as e:
|
||||
print(f"Failed to clean dist directory: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def build_jax_xla(xla_path, rocm_version, rocm_target, use_clang, clang_path):
|
||||
bazel_options = (
|
||||
f"--bazel_options=--override_repository=xla={xla_path}" if xla_path else ""
|
||||
)
|
||||
clang_option = f"--clang_path={clang_path}" if clang_path else ""
|
||||
build_command = [
|
||||
"python3",
|
||||
"./build/build.py",
|
||||
"--enable_rocm",
|
||||
"--build_gpu_plugin",
|
||||
"--gpu_plugin_rocm_version=60",
|
||||
f"--use_clang={str(use_clang).lower()}",
|
||||
f"--rocm_amdgpu_targets={rocm_target}",
|
||||
f"--rocm_path=/opt/rocm-{rocm_version}/",
|
||||
bazel_options,
|
||||
]
|
||||
|
||||
if clang_option:
|
||||
build_command.append(clang_option)
|
||||
|
||||
print("Executing build command:")
|
||||
print(" ".join(build_command))
|
||||
|
||||
try:
|
||||
subprocess.run(build_command, check=True)
|
||||
print("Build completed successfully.")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Build failed: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def install_wheel():
|
||||
try:
|
||||
subprocess.run(
|
||||
["python3", "-m", "pip", "install", "dist/*.whl"], check=True, shell=True
|
||||
)
|
||||
print("Packages installed successfully.")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Failed to install packages: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Script to build JAX and XLA on ROCm.")
|
||||
parser.add_argument(
|
||||
"--clang-path", type=str, default="", help="Specify the Clang compiler path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-uninstall",
|
||||
action="store_true",
|
||||
help="Skip uninstall of old versions during package install",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-clang", default="false", help="Use Clang compiler if set"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--xla-path", type=str, default="", help="Specify the XLA repository path"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.xla_path:
|
||||
args.xla_path = os.path.abspath(args.xla_path)
|
||||
print(f"Converted XLA path to absolute: {args.xla_path}")
|
||||
|
||||
rocm_version = get_rocm_version()
|
||||
if not rocm_version:
|
||||
print("Could not determine ROCm version. Exiting.")
|
||||
sys.exit(1)
|
||||
|
||||
rocm_target = get_rocm_target()
|
||||
if not rocm_target:
|
||||
print("Could not determine ROCm target. Exiting.")
|
||||
sys.exit(1)
|
||||
|
||||
if not args.skip_uninstall:
|
||||
print("Uninstalling existing packages...")
|
||||
packages = ["jax", "jaxlib", "jax-rocm60-pjrt", "jax-rocm60-plugin"]
|
||||
uninstall_existing_packages(packages)
|
||||
|
||||
clean_dist_directory()
|
||||
|
||||
print(
|
||||
f"Building JAX and XLA with ROCm version: {rocm_version}, Target: {rocm_target}"
|
||||
)
|
||||
build_jax_xla(
|
||||
args.xla_path, rocm_version, rocm_target, args.use_clang, args.clang_path
|
||||
)
|
||||
|
||||
install_wheel()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user