[ROCm] add build file.

This commit is contained in:
Ruturaj4 2024-08-23 09:36:50 -05:00
parent c6c701e6a7
commit 9ce8de5fb0

165
build/rocm/dev_build_rocm.py Executable file
View File

@ -0,0 +1,165 @@
# !/usr/bin/env python3
#
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE(ruturaj4): This script automates the build process for JAX and XLA on ROCm,
# allowing for optional uninstallation of existing packages, and custom paths for ROCm and XLA repositories.
import argparse
import os
import shutil
import subprocess
import sys
def get_rocm_version():
try:
version = subprocess.check_output(
"cat /opt/rocm/.info/version | cut -d '-' -f 1", shell=True
)
return version.decode("utf-8").strip()
except subprocess.CalledProcessError as e:
print(f"Error fetching ROCm version: {e}")
return None
def get_rocm_target():
try:
target_info = subprocess.check_output(
"rocminfo | grep gfx | head -n 1", shell=True
)
target = target_info.decode("utf-8").split()[1]
return target
except subprocess.CalledProcessError as e:
print(f"Error fetching ROCm target: {e}")
return None
def uninstall_existing_packages(packages):
cmd = ["python3", "-m", "pip", "uninstall", "-y"]
cmd.extend(packages)
try:
subprocess.run(cmd, check=True)
print(f"Successfully uninstalled {packages}")
except subprocess.CalledProcessError as e:
print(f"Failed to uninstall {packages}: {e}")
def clean_dist_directory():
try:
shutil.rmtree("dist")
print("Cleaned dist directory.")
except FileNotFoundError:
print("dist directory not found, skipping cleanup.")
except Exception as e:
print(f"Failed to clean dist directory: {e}")
sys.exit(1)
def build_jax_xla(xla_path, rocm_version, rocm_target, use_clang, clang_path):
bazel_options = (
f"--bazel_options=--override_repository=xla={xla_path}" if xla_path else ""
)
clang_option = f"--clang_path={clang_path}" if clang_path else ""
build_command = [
"python3",
"./build/build.py",
"--enable_rocm",
"--build_gpu_plugin",
"--gpu_plugin_rocm_version=60",
f"--use_clang={str(use_clang).lower()}",
f"--rocm_amdgpu_targets={rocm_target}",
f"--rocm_path=/opt/rocm-{rocm_version}/",
bazel_options,
]
if clang_option:
build_command.append(clang_option)
print("Executing build command:")
print(" ".join(build_command))
try:
subprocess.run(build_command, check=True)
print("Build completed successfully.")
except subprocess.CalledProcessError as e:
print(f"Build failed: {e}")
sys.exit(1)
def install_wheel():
try:
subprocess.run(
["python3", "-m", "pip", "install", "dist/*.whl"], check=True, shell=True
)
print("Packages installed successfully.")
except subprocess.CalledProcessError as e:
print(f"Failed to install packages: {e}")
sys.exit(1)
def main():
parser = argparse.ArgumentParser(description="Script to build JAX and XLA on ROCm.")
parser.add_argument(
"--clang-path", type=str, default="", help="Specify the Clang compiler path"
)
parser.add_argument(
"--skip-uninstall",
action="store_true",
help="Skip uninstall of old versions during package install",
)
parser.add_argument(
"--use-clang", default="false", help="Use Clang compiler if set"
)
parser.add_argument(
"--xla-path", type=str, default="", help="Specify the XLA repository path"
)
args = parser.parse_args()
if args.xla_path:
args.xla_path = os.path.abspath(args.xla_path)
print(f"Converted XLA path to absolute: {args.xla_path}")
rocm_version = get_rocm_version()
if not rocm_version:
print("Could not determine ROCm version. Exiting.")
sys.exit(1)
rocm_target = get_rocm_target()
if not rocm_target:
print("Could not determine ROCm target. Exiting.")
sys.exit(1)
if not args.skip_uninstall:
print("Uninstalling existing packages...")
packages = ["jax", "jaxlib", "jax-rocm60-pjrt", "jax-rocm60-plugin"]
uninstall_existing_packages(packages)
clean_dist_directory()
print(
f"Building JAX and XLA with ROCm version: {rocm_version}, Target: {rocm_target}"
)
build_jax_xla(
args.xla_path, rocm_version, rocm_target, args.use_clang, args.clang_path
)
install_wheel()
if __name__ == "__main__":
main()