mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
* Add conditional docker interactive mode
  Interactive mode causes bazel to output more useful info when running locally.
* Fix issue with ROCm el8 repo URLs
  Work around a quirk with the ROCm version when it ends with 0.
* Fix package name conflict
  Ubuntu 22 and higher have a package name conflict between the Debian versions and the AMD-provided versions.
* [ROCm] Use clang env
This commit is contained in:
parent
9c3f2dcefc
commit
5c2ffa893f
@ -34,7 +34,8 @@ def image_by_name(name):
|
||||
|
||||
|
||||
def dist_wheels(
|
||||
rocm_version, python_versions, xla_path, rocm_build_job="", rocm_build_num=""
|
||||
rocm_version, python_versions, xla_path, rocm_build_job="", rocm_build_num="",
|
||||
compiler="gcc"
|
||||
):
|
||||
if xla_path:
|
||||
xla_path = os.path.abspath(xla_path)
|
||||
@ -71,6 +72,8 @@ def dist_wheels(
|
||||
rocm_version,
|
||||
"--python-versions",
|
||||
pyver_string,
|
||||
"--compiler",
|
||||
compiler,
|
||||
]
|
||||
|
||||
if xla_path:
|
||||
@ -92,6 +95,9 @@ def dist_wheels(
|
||||
|
||||
cmd.extend(mounts)
|
||||
|
||||
if os.isatty(sys.stdout.fileno()):
|
||||
cmd.append("-it")
|
||||
|
||||
# NOTE(mrodden): bazel times out without --init, probably blocking on a zombie PID
|
||||
cmd.extend(
|
||||
[
|
||||
@ -251,6 +257,12 @@ def parse_args():
|
||||
help="Path to XLA source to use during jaxlib build, instead of builtin XLA",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--compiler",
|
||||
choices=["gcc", "clang"],
|
||||
help="Compiler backend to use when compiling jax/jaxlib"
|
||||
)
|
||||
|
||||
subp = p.add_subparsers(dest="action", required=True)
|
||||
|
||||
dwp = subp.add_parser("dist_wheels")
|
||||
@ -288,6 +300,7 @@ def main():
|
||||
args.xla_source_dir,
|
||||
args.rocm_build_job,
|
||||
args.rocm_build_num,
|
||||
args.compiler,
|
||||
)
|
||||
dist_docker(
|
||||
args.rocm_version,
|
||||
|
@ -50,6 +50,7 @@ ROCM_BUILD_JOB=""
|
||||
ROCM_BUILD_NUM=""
|
||||
BASE_DOCKER="ubuntu:20.04"
|
||||
CUSTOM_INSTALL=""
|
||||
JAX_USE_CLANG=""
|
||||
POSITIONAL_ARGS=()
|
||||
|
||||
RUNTIME_FLAG=1
|
||||
@ -89,6 +90,10 @@ while [[ $# -gt 0 ]]; do
|
||||
ROCM_BUILD_NUM="$2"
|
||||
shift 2
|
||||
;;
|
||||
--use_clang)
|
||||
JAX_USE_CLANG="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
POSITIONAL_ARGS+=("$1")
|
||||
shift
|
||||
@ -135,6 +140,12 @@ echo "Building (runtime) container (${DOCKER_IMG_NAME}) with Dockerfile($DOCKERF
|
||||
|
||||
export XLA_CLONE_DIR="${XLA_CLONE_DIR:-}"
|
||||
|
||||
# default to gcc
|
||||
JAX_COMPILER="gcc"
|
||||
if [ -n "$JAX_USE_CLANG" ]; then
|
||||
JAX_COMPILER="clang"
|
||||
fi
|
||||
|
||||
# ci_build.sh is mostly a compatibility wrapper for ci_build
|
||||
|
||||
# 'dist_docker' will run 'dist_wheels' followed by a Docker build to create the "JAX image",
|
||||
@ -145,6 +156,7 @@ export XLA_CLONE_DIR="${XLA_CLONE_DIR:-}"
|
||||
--xla-source-dir=$XLA_CLONE_DIR \
|
||||
--rocm-build-job=$ROCM_BUILD_JOB \
|
||||
--rocm-build-num=$ROCM_BUILD_NUM \
|
||||
--compiler=$JAX_COMPILER \
|
||||
dist_docker \
|
||||
--dockerfile $DOCKERFILE_PATH \
|
||||
--image-tag $DOCKER_IMG_NAME
|
||||
|
@ -56,7 +56,8 @@ def update_rocm_targets(rocm_path, targets):
|
||||
open(version_fp, "a").close()
|
||||
|
||||
|
||||
def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None):
|
||||
def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None, compiler="gcc"):
|
||||
use_clang = "true" if compiler == "clang" else "false"
|
||||
cmd = [
|
||||
"python",
|
||||
"build/build.py",
|
||||
@ -64,6 +65,7 @@ def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None):
|
||||
"--build_gpu_plugin",
|
||||
"--gpu_plugin_rocm_version=60",
|
||||
"--rocm_path=%s" % rocm_path,
|
||||
"--use_clang=%s" % use_clang,
|
||||
]
|
||||
|
||||
if xla_path:
|
||||
@ -194,6 +196,12 @@ def parse_args():
|
||||
default=None,
|
||||
help="Optional directory where XLA source is located to use instead of JAX builtin XLA",
|
||||
)
|
||||
p.add_argument(
|
||||
"--compiler",
|
||||
type=str,
|
||||
default="gcc",
|
||||
help="Compiler backend to use when compiling jax/jaxlib",
|
||||
)
|
||||
|
||||
p.add_argument("jax_path", help="Directory where JAX source directory is located")
|
||||
|
||||
@ -225,7 +233,7 @@ def main():
|
||||
update_rocm_targets(rocm_path, GPU_DEVICE_TARGETS)
|
||||
|
||||
for py in python_versions:
|
||||
build_jaxlib_wheel(args.jax_path, rocm_path, py, args.xla_path)
|
||||
build_jaxlib_wheel(args.jax_path, rocm_path, py, args.xla_path, args.compiler)
|
||||
wheel_paths = find_wheels(os.path.join(args.jax_path, "dist"))
|
||||
for wheel_path in wheel_paths:
|
||||
# skip jax wheel since it is non-platform
|
||||
|
@ -115,6 +115,23 @@ RHEL8 = System(
|
||||
)
|
||||
|
||||
|
||||
def parse_version(version_str):
    """Parse a dotted version string into an object with numeric fields.

    Args:
        version_str: A string of the form "MAJOR.MINOR" or
            "MAJOR.MINOR.REV" (whitespace around each component is
            ignored; components after the third are ignored). Any
            non-string value is assumed to already be a parsed version
            and is returned unchanged.

    Returns:
        An object exposing integer ``major`` and ``minor`` attributes and
        a ``rev`` attribute that is an int, or None when the string has
        no third component.

    Raises:
        ValueError: If the string has fewer than two components or a
            component is not an integer.
    """
    # Already-parsed versions pass straight through.
    if not isinstance(version_str, str):
        return version_str

    # Local import so this function does not depend on the file's import block.
    from types import SimpleNamespace

    parts = [part.strip() for part in version_str.split(".")]
    if len(parts) < 2:
        # Fail with a clear message instead of an opaque IndexError.
        raise ValueError(
            "expected a version of the form MAJOR.MINOR[.REV], got %r"
            % (version_str,)
        )

    return SimpleNamespace(
        major=int(parts[0]),
        minor=int(parts[1]),
        # rev stays None when absent so callers can tell "6.1" from "6.1.0".
        rev=int(parts[2]) if len(parts) > 2 else None,
    )
|
||||
|
||||
|
||||
def get_system():
|
||||
md = os_release_meta()
|
||||
|
||||
@ -210,16 +227,7 @@ def install_amdgpu_installer_internal(rocm_version):
|
||||
def _build_installer_url(rocm_version, metadata):
|
||||
md = metadata
|
||||
|
||||
if isinstance(rocm_version, str):
|
||||
parts = rocm_version.split(".")
|
||||
rv = type("Version", (), {})()
|
||||
rv.major = parts[0]
|
||||
rv.minor = parts[1]
|
||||
|
||||
if len(parts) > 2:
|
||||
rv.rev = parts[2]
|
||||
else:
|
||||
rv = rocm_version
|
||||
rv = parse_version(rocm_version)
|
||||
|
||||
base_url = "http://artifactory-cdn.amd.com/artifactory/list"
|
||||
|
||||
@ -247,8 +255,21 @@ def _build_installer_url(rocm_version, metadata):
|
||||
return url, package_name
|
||||
|
||||
|
||||
APT_RADEON_PIN_CONTENT = """
|
||||
Package: *
|
||||
Pin: release o=repo.radeon.com
|
||||
Pin-Priority: 600
|
||||
"""
|
||||
|
||||
|
||||
def setup_repos_ubuntu(rocm_version_str):
|
||||
|
||||
rv = parse_version(rocm_version_str)
|
||||
|
||||
# if X.Y.0 -> repo url version should be X.Y
|
||||
if rv.rev == 0:
|
||||
rocm_version_str = "%d.%d" % (rv.major, rv.minor)
|
||||
|
||||
s = get_system()
|
||||
s.install_packages(["wget", "sudo", "gnupg"])
|
||||
|
||||
@ -270,6 +291,11 @@ def setup_repos_ubuntu(rocm_version_str):
|
||||
% (rocm_version_str, codename)
|
||||
)
|
||||
|
||||
# on ubuntu 22 or greater, debian community rocm packages
|
||||
# conflict with repo.radeon.com packages
|
||||
with open("/etc/apt/preferences.d/rocm-pin-600", "w") as fd:
|
||||
fd.write(APT_RADEON_PIN_CONTENT)
|
||||
|
||||
# update indexes
|
||||
subprocess.check_call(["apt-get", "update"])
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user