* Add conditional docker interactive mode

Interactive mode causes Bazel to emit more
useful output when running locally.

* Fix issue with rocm el8 repo urls

Work around a quirk in the ROCm repo URL scheme:
when the patch version is 0 (X.Y.0), the repo URL
uses only X.Y.

* Fix package name conflict

Ubuntu 22.04 and later have a package name conflict
between the Debian community ROCm packages and the
AMD-provided packages.

* [ROCm] Use clang env
This commit is contained in:
Mathew Odden 2024-08-19 14:32:35 -05:00 committed by Ruturaj4
parent 9c3f2dcefc
commit 5c2ffa893f
4 changed files with 72 additions and 13 deletions

View File

@ -34,7 +34,8 @@ def image_by_name(name):
def dist_wheels(
rocm_version, python_versions, xla_path, rocm_build_job="", rocm_build_num=""
rocm_version, python_versions, xla_path, rocm_build_job="", rocm_build_num="",
compiler="gcc"
):
if xla_path:
xla_path = os.path.abspath(xla_path)
@ -71,6 +72,8 @@ def dist_wheels(
rocm_version,
"--python-versions",
pyver_string,
"--compiler",
compiler,
]
if xla_path:
@ -92,6 +95,9 @@ def dist_wheels(
cmd.extend(mounts)
if os.isatty(sys.stdout.fileno()):
cmd.append("-it")
# NOTE(mrodden): bazel times out without --init, probably blocking on a zombie PID
cmd.extend(
[
@ -251,6 +257,12 @@ def parse_args():
help="Path to XLA source to use during jaxlib build, instead of builtin XLA",
)
p.add_argument(
"--compiler",
choices=["gcc", "clang"],
help="Compiler backend to use when compiling jax/jaxlib"
)
subp = p.add_subparsers(dest="action", required=True)
dwp = subp.add_parser("dist_wheels")
@ -288,6 +300,7 @@ def main():
args.xla_source_dir,
args.rocm_build_job,
args.rocm_build_num,
args.compiler,
)
dist_docker(
args.rocm_version,

View File

@ -50,6 +50,7 @@ ROCM_BUILD_JOB=""
ROCM_BUILD_NUM=""
BASE_DOCKER="ubuntu:20.04"
CUSTOM_INSTALL=""
JAX_USE_CLANG=""
POSITIONAL_ARGS=()
RUNTIME_FLAG=1
@ -89,6 +90,10 @@ while [[ $# -gt 0 ]]; do
ROCM_BUILD_NUM="$2"
shift 2
;;
--use_clang)
JAX_USE_CLANG="$2"
shift 2
;;
*)
POSITIONAL_ARGS+=("$1")
shift
@ -135,6 +140,12 @@ echo "Building (runtime) container (${DOCKER_IMG_NAME}) with Dockerfile($DOCKERF
export XLA_CLONE_DIR="${XLA_CLONE_DIR:-}"
# default to gcc
JAX_COMPILER="gcc"
if [ -n "$JAX_USE_CLANG" ]; then
JAX_COMPILER="clang"
fi
# ci_build.sh is mostly a compatibility wrapper for ci_build
# 'dist_docker' will run 'dist_wheels' followed by a Docker build to create the "JAX image",
@ -145,6 +156,7 @@ export XLA_CLONE_DIR="${XLA_CLONE_DIR:-}"
--xla-source-dir=$XLA_CLONE_DIR \
--rocm-build-job=$ROCM_BUILD_JOB \
--rocm-build-num=$ROCM_BUILD_NUM \
--compiler=$JAX_COMPILER \
dist_docker \
--dockerfile $DOCKERFILE_PATH \
--image-tag $DOCKER_IMG_NAME

View File

@ -56,7 +56,8 @@ def update_rocm_targets(rocm_path, targets):
open(version_fp, "a").close()
def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None):
def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None, compiler="gcc"):
use_clang = "true" if compiler == "clang" else "false"
cmd = [
"python",
"build/build.py",
@ -64,6 +65,7 @@ def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None):
"--build_gpu_plugin",
"--gpu_plugin_rocm_version=60",
"--rocm_path=%s" % rocm_path,
"--use_clang=%s" % use_clang,
]
if xla_path:
@ -194,6 +196,12 @@ def parse_args():
default=None,
help="Optional directory where XLA source is located to use instead of JAX builtin XLA",
)
p.add_argument(
"--compiler",
type=str,
default="gcc",
help="Compiler backend to use when compiling jax/jaxlib",
)
p.add_argument("jax_path", help="Directory where JAX source directory is located")
@ -225,7 +233,7 @@ def main():
update_rocm_targets(rocm_path, GPU_DEVICE_TARGETS)
for py in python_versions:
build_jaxlib_wheel(args.jax_path, rocm_path, py, args.xla_path)
build_jaxlib_wheel(args.jax_path, rocm_path, py, args.xla_path, args.compiler)
wheel_paths = find_wheels(os.path.join(args.jax_path, "dist"))
for wheel_path in wheel_paths:
# skip jax wheel since it is non-platform

View File

@ -115,6 +115,23 @@ RHEL8 = System(
)
def parse_version(version_str):
    """Parse a ROCm version string into an object with major/minor/rev fields.

    Args:
        version_str: A version string such as "6.1.0", "6.1", or "6".
            If it is not a string, it is assumed to already be a parsed
            version object and is returned unchanged (pass-through).

    Returns:
        An object with integer attributes ``major`` and ``minor``, and
        ``rev`` which is an int when a third component is present,
        otherwise ``None``.

    Raises:
        ValueError: If a component is present but not an integer.
    """
    if not isinstance(version_str, str):
        # Already a parsed version object; pass through untouched.
        return version_str

    parts = version_str.split(".")
    # Lightweight anonymous attribute holder (matches existing callers
    # that access .major/.minor/.rev).
    rv = type("Version", (), {})()
    rv.major = int(parts[0].strip())
    # A bare major version like "6" has no minor component; treat it as X.0
    # instead of raising IndexError.
    rv.minor = int(parts[1].strip()) if len(parts) > 1 else 0
    # rev is None (not 0) when absent, so callers can distinguish
    # "X.Y" from "X.Y.0".
    rv.rev = int(parts[2].strip()) if len(parts) > 2 else None
    return rv
def get_system():
md = os_release_meta()
@ -210,16 +227,7 @@ def install_amdgpu_installer_internal(rocm_version):
def _build_installer_url(rocm_version, metadata):
md = metadata
if isinstance(rocm_version, str):
parts = rocm_version.split(".")
rv = type("Version", (), {})()
rv.major = parts[0]
rv.minor = parts[1]
if len(parts) > 2:
rv.rev = parts[2]
else:
rv = rocm_version
rv = parse_version(rocm_version)
base_url = "http://artifactory-cdn.amd.com/artifactory/list"
@ -247,8 +255,21 @@ def _build_installer_url(rocm_version, metadata):
return url, package_name
APT_RADEON_PIN_CONTENT = """
Package: *
Pin: release o=repo.radeon.com
Pin-Priority: 600
"""
def setup_repos_ubuntu(rocm_version_str):
rv = parse_version(rocm_version_str)
# if X.Y.0 -> repo url version should be X.Y
if rv.rev == 0:
rocm_version_str = "%d.%d" % (rv.major, rv.minor)
s = get_system()
s.install_packages(["wget", "sudo", "gnupg"])
@ -270,6 +291,11 @@ def setup_repos_ubuntu(rocm_version_str):
% (rocm_version_str, codename)
)
# on ubuntu 22 or greater, debian community rocm packages
# conflict with repo.radeon.com packages
with open("/etc/apt/preferences.d/rocm-pin-600", "w") as fd:
fd.write(APT_RADEON_PIN_CONTENT)
# update indexes
subprocess.check_call(["apt-get", "update"])