2024-08-23 09:36:50 -05:00
|
|
|
# !/usr/bin/env python3
|
|
|
|
#
|
|
|
|
# Copyright 2024 The JAX Authors.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
# NOTE(ruturaj4): This script automates the build process for JAX and XLA on ROCm,
|
|
|
|
# allowing for optional uninstallation of existing packages, and custom paths for ROCm and XLA repositories.
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import os
|
|
|
|
import shutil
|
|
|
|
import subprocess
|
|
|
|
import sys
|
|
|
|
|
|
|
|
|
|
|
|
def get_rocm_version():
|
|
|
|
try:
|
|
|
|
version = subprocess.check_output(
|
|
|
|
"cat /opt/rocm/.info/version | cut -d '-' -f 1", shell=True
|
|
|
|
)
|
|
|
|
return version.decode("utf-8").strip()
|
|
|
|
except subprocess.CalledProcessError as e:
|
|
|
|
print(f"Error fetching ROCm version: {e}")
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
def get_rocm_target():
|
|
|
|
try:
|
|
|
|
target_info = subprocess.check_output(
|
|
|
|
"rocminfo | grep gfx | head -n 1", shell=True
|
|
|
|
)
|
|
|
|
target = target_info.decode("utf-8").split()[1]
|
|
|
|
return target
|
|
|
|
except subprocess.CalledProcessError as e:
|
|
|
|
print(f"Error fetching ROCm target: {e}")
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
def uninstall_existing_packages(packages):
|
|
|
|
cmd = ["python3", "-m", "pip", "uninstall", "-y"]
|
|
|
|
cmd.extend(packages)
|
|
|
|
|
|
|
|
try:
|
|
|
|
subprocess.run(cmd, check=True)
|
|
|
|
print(f"Successfully uninstalled {packages}")
|
|
|
|
except subprocess.CalledProcessError as e:
|
|
|
|
print(f"Failed to uninstall {packages}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
def clean_dist_directory():
|
|
|
|
try:
|
|
|
|
shutil.rmtree("dist")
|
|
|
|
print("Cleaned dist directory.")
|
|
|
|
except FileNotFoundError:
|
|
|
|
print("dist directory not found, skipping cleanup.")
|
|
|
|
except Exception as e:
|
|
|
|
print(f"Failed to clean dist directory: {e}")
|
|
|
|
sys.exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
def build_jax_xla(xla_path, rocm_version, rocm_target, use_clang, clang_path):
|
|
|
|
bazel_options = (
|
|
|
|
f"--bazel_options=--override_repository=xla={xla_path}" if xla_path else ""
|
|
|
|
)
|
|
|
|
clang_option = f"--clang_path={clang_path}" if clang_path else ""
|
|
|
|
build_command = [
|
|
|
|
"python3",
|
|
|
|
"./build/build.py",
|
2024-11-25 19:25:08 +00:00
|
|
|
"build"
|
2024-08-23 09:36:50 -05:00
|
|
|
f"--use_clang={str(use_clang).lower()}",
|
2024-11-25 19:25:08 +00:00
|
|
|
"--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt"
|
|
|
|
"--rocm_path=%/opt/rocm-{rocm_version}/",
|
|
|
|
"--rocm_version=60",
|
2024-08-23 09:36:50 -05:00
|
|
|
f"--rocm_amdgpu_targets={rocm_target}",
|
|
|
|
bazel_options,
|
2024-11-25 19:25:08 +00:00
|
|
|
"--verbose"
|
2024-08-23 09:36:50 -05:00
|
|
|
]
|
|
|
|
|
|
|
|
if clang_option:
|
|
|
|
build_command.append(clang_option)
|
|
|
|
|
|
|
|
print("Executing build command:")
|
|
|
|
print(" ".join(build_command))
|
|
|
|
|
|
|
|
try:
|
|
|
|
subprocess.run(build_command, check=True)
|
|
|
|
print("Build completed successfully.")
|
|
|
|
except subprocess.CalledProcessError as e:
|
|
|
|
print(f"Build failed: {e}")
|
|
|
|
sys.exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
def install_wheel():
|
|
|
|
try:
|
|
|
|
subprocess.run(
|
|
|
|
["python3", "-m", "pip", "install", "dist/*.whl"], check=True, shell=True
|
|
|
|
)
|
|
|
|
print("Packages installed successfully.")
|
|
|
|
except subprocess.CalledProcessError as e:
|
|
|
|
print(f"Failed to install packages: {e}")
|
|
|
|
sys.exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
parser = argparse.ArgumentParser(description="Script to build JAX and XLA on ROCm.")
|
|
|
|
parser.add_argument(
|
|
|
|
"--clang-path", type=str, default="", help="Specify the Clang compiler path"
|
|
|
|
)
|
|
|
|
parser.add_argument(
|
|
|
|
"--skip-uninstall",
|
|
|
|
action="store_true",
|
|
|
|
help="Skip uninstall of old versions during package install",
|
|
|
|
)
|
|
|
|
parser.add_argument(
|
|
|
|
"--use-clang", default="false", help="Use Clang compiler if set"
|
|
|
|
)
|
|
|
|
parser.add_argument(
|
|
|
|
"--xla-path", type=str, default="", help="Specify the XLA repository path"
|
|
|
|
)
|
|
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
if args.xla_path:
|
|
|
|
args.xla_path = os.path.abspath(args.xla_path)
|
|
|
|
print(f"Converted XLA path to absolute: {args.xla_path}")
|
|
|
|
|
|
|
|
rocm_version = get_rocm_version()
|
|
|
|
if not rocm_version:
|
|
|
|
print("Could not determine ROCm version. Exiting.")
|
|
|
|
sys.exit(1)
|
|
|
|
|
|
|
|
rocm_target = get_rocm_target()
|
|
|
|
if not rocm_target:
|
|
|
|
print("Could not determine ROCm target. Exiting.")
|
|
|
|
sys.exit(1)
|
|
|
|
|
|
|
|
if not args.skip_uninstall:
|
|
|
|
print("Uninstalling existing packages...")
|
|
|
|
packages = ["jax", "jaxlib", "jax-rocm60-pjrt", "jax-rocm60-plugin"]
|
|
|
|
uninstall_existing_packages(packages)
|
|
|
|
|
|
|
|
clean_dist_directory()
|
|
|
|
|
|
|
|
print(
|
|
|
|
f"Building JAX and XLA with ROCm version: {rocm_version}, Target: {rocm_target}"
|
|
|
|
)
|
|
|
|
build_jax_xla(
|
|
|
|
args.xla_path, rocm_version, rocm_target, args.use_clang, args.clang_path
|
|
|
|
)
|
|
|
|
|
|
|
|
install_wheel()
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
main()
|