From 9ce8de5fb0df7a4ef82b87a42f6d1ee0fb58e4c5 Mon Sep 17 00:00:00 2001 From: Ruturaj4 Date: Fri, 23 Aug 2024 09:36:50 -0500 Subject: [PATCH] [ROCm] add build file. --- build/rocm/dev_build_rocm.py | 165 +++++++++++++++++++++++++++++++++++ 1 file changed, 165 insertions(+) create mode 100755 build/rocm/dev_build_rocm.py diff --git a/build/rocm/dev_build_rocm.py b/build/rocm/dev_build_rocm.py new file mode 100755 index 000000000..2be64152f --- /dev/null +++ b/build/rocm/dev_build_rocm.py @@ -0,0 +1,165 @@ +# !/usr/bin/env python3 +# +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NOTE(ruturaj4): This script automates the build process for JAX and XLA on ROCm, +# allowing for optional uninstallation of existing packages, and custom paths for ROCm and XLA repositories. + +import argparse +import os +import shutil +import subprocess +import sys + + +def get_rocm_version(): + try: + version = subprocess.check_output( + "cat /opt/rocm/.info/version | cut -d '-' -f 1", shell=True + ) + return version.decode("utf-8").strip() + except subprocess.CalledProcessError as e: + print(f"Error fetching ROCm version: {e}") + return None + + +def get_rocm_target(): + try: + target_info = subprocess.check_output( + "rocminfo | grep gfx | head -n 1", shell=True + ) + target = target_info.decode("utf-8").split()[1] + return target + except subprocess.CalledProcessError as e: + print(f"Error fetching ROCm target: {e}") + return None + + +def uninstall_existing_packages(packages): + cmd = ["python3", "-m", "pip", "uninstall", "-y"] + cmd.extend(packages) + + try: + subprocess.run(cmd, check=True) + print(f"Successfully uninstalled {packages}") + except subprocess.CalledProcessError as e: + print(f"Failed to uninstall {packages}: {e}") + + +def clean_dist_directory(): + try: + shutil.rmtree("dist") + print("Cleaned dist directory.") + except FileNotFoundError: + print("dist directory not found, skipping cleanup.") + except Exception as e: + print(f"Failed to clean dist directory: {e}") + sys.exit(1) + + +def build_jax_xla(xla_path, rocm_version, rocm_target, use_clang, clang_path): + bazel_options = ( + f"--bazel_options=--override_repository=xla={xla_path}" if xla_path else "" + ) + clang_option = f"--clang_path={clang_path}" if clang_path else "" + build_command = [ + "python3", + "./build/build.py", + "--enable_rocm", + "--build_gpu_plugin", + "--gpu_plugin_rocm_version=60", + f"--use_clang={str(use_clang).lower()}", + f"--rocm_amdgpu_targets={rocm_target}", + f"--rocm_path=/opt/rocm-{rocm_version}/", + bazel_options, + ] + + if clang_option: + build_command.append(clang_option) + + print("Executing build command:") + print(" ".join(build_command)) + + try: + subprocess.run(build_command, check=True) + print("Build completed successfully.") + except subprocess.CalledProcessError as e: + print(f"Build failed: {e}") + sys.exit(1) + + +def install_wheel(): + try: + subprocess.run( + ["python3", "-m", "pip", "install", "dist/*.whl"], check=True, shell=True + ) + print("Packages installed successfully.") + except subprocess.CalledProcessError as e: + print(f"Failed to install packages: {e}") + sys.exit(1) + + +def main(): + parser = argparse.ArgumentParser(description="Script to build JAX and XLA on ROCm.") + parser.add_argument( + "--clang-path", type=str, default="", help="Specify the Clang compiler path" + ) + parser.add_argument( + "--skip-uninstall", + action="store_true", + help="Skip uninstall of old versions during package install", + ) + parser.add_argument( + "--use-clang", default="false", help="Use Clang compiler if set" + ) + parser.add_argument( + "--xla-path", type=str, default="", help="Specify the XLA repository path" + ) + + args = parser.parse_args() + + if args.xla_path: + args.xla_path = os.path.abspath(args.xla_path) + print(f"Converted XLA path to absolute: {args.xla_path}") + + rocm_version = get_rocm_version() + if not rocm_version: + print("Could not determine ROCm version. Exiting.") + sys.exit(1) + + rocm_target = get_rocm_target() + if not rocm_target: + print("Could not determine ROCm target. Exiting.") + sys.exit(1) + + if not args.skip_uninstall: + print("Uninstalling existing packages...") + packages = ["jax", "jaxlib", "jax-rocm60-pjrt", "jax-rocm60-plugin"] + uninstall_existing_packages(packages) + + clean_dist_directory() + + print( + f"Building JAX and XLA with ROCm version: {rocm_version}, Target: {rocm_target}" + ) + build_jax_xla( + args.xla_path, rocm_version, rocm_target, args.use_clang, args.clang_path + ) + + install_wheel() + + +if __name__ == "__main__": + main()