From 3b4a7b029b226fbf298546857358d8344490ab9c Mon Sep 17 00:00:00 2001
From: Charles Hofer <Charles.Hofer@amd.com>
Date: Fri, 11 Apr 2025 19:18:23 +0000
Subject: [PATCH] Make Clang use manylinux C++ standard library

---
 .../build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm   | 2 +-
 build/rocm/ci_build                                      | 9 ++++++++-
 build/rocm/tools/build_wheels.py                         | 5 +++++
 build/rocm/tools/fixwheel.py                             | 2 +-
 4 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm
index 14bf6fd60..c32d18dbc 100644
--- a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm
+++ b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm
@@ -9,7 +9,7 @@ ARG ROCM_BUILD_NUM
 #  manylinux base image. However, adding this does fix an issue where Bazel isn't able
 #  to find them.
 RUN --mount=type=cache,target=/var/cache/dnf \
-    dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64 numactl-devel
+    dnf install -y numactl-devel
 
 RUN --mount=type=cache,target=/var/cache/dnf \
     --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
diff --git a/build/rocm/ci_build b/build/rocm/ci_build
index ee2c8698d..8620fbf1a 100755
--- a/build/rocm/ci_build
+++ b/build/rocm/ci_build
@@ -98,7 +98,14 @@ def dist_wheels(
 
     bw_cmd.append("/jax")
 
-    cmd = ["docker", "run"]
+    cmd = [
+        "docker",
+        "run",
+        "-e",
+        "HIPCC_COMPILE_FLAGS_APPEND=--gcc-toolchain=/opt/rh/gcc-toolset-14/root/usr/", 
+        "-e",
+        "HIPCC_LINK_FLAGS_APPEND=--gcc-toolchain=/opt/rh/gcc-toolset-14/root/usr/"
+    ]
 
     mounts = [
         "-v",
diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py
index 139a1fdd8..ebeadcd88 100644
--- a/build/rocm/tools/build_wheels.py
+++ b/build/rocm/tools/build_wheels.py
@@ -105,6 +105,11 @@ def build_jaxlib_wheel(
         "python",
         "build/build.py",
         "build",
+        "--bazel_options=--host_linkopt=--gcc-toolchain=/opt/rh/gcc-toolset-14/root/usr/",
+        "--bazel_options=--linkopt=--gcc-toolchain=/opt/rh/gcc-toolset-14/root/usr/",
+        "--bazel_options=--host_cxxopt=--gcc-toolchain=/opt/rh/gcc-toolset-14/root/usr/",
+        "--bazel_options=--cxxopt=--gcc-toolchain=/opt/rh/gcc-toolset-14/root/usr/",
+        #"--bazel_options=--subcommands",
         "--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt",
         "--rocm_path=%s" % rocm_path,
         "--rocm_version=60",
diff --git a/build/rocm/tools/fixwheel.py b/build/rocm/tools/fixwheel.py
index ea7716272..7d8c1fcce 100644
--- a/build/rocm/tools/fixwheel.py
+++ b/build/rocm/tools/fixwheel.py
@@ -87,7 +87,7 @@ def fix_wheel(path):
     exclude = list(ext_libs.keys())
 
     # call auditwheel repair with excludes
-    cmd = ["auditwheel", "repair", "--plat", plat, "--only-plat"]
+    cmd = ["auditwheel", "-v", "repair", "--plat", plat, "--only-plat"]
 
     for ex in exclude:
         cmd.append("--exclude")